Radix/runtimes/neuron/BUILD.bazel
Tarry Singh 1427286716 runtimes/neuron: fix neuron runtime
This PR fixes the neuron runtime with the following:

Proxy the PJRT Api method to enforce the client struct sizes since the
neuron PJRT plugin doesn't use `>=` but `==` to assert them, breaking
PJRT compatibility guarantees.
Fixes https://github.com/aws-neuron/aws-neuron-sdk/issues/1095

Reimplement `libneuronxla` in Zig to control neuronx-cc sandboxing and
invocation.

Implement a Python bootstrapper in Zig to create a full-blown
`neuronx-cc` executable, avoiding the infamous chicken-and-egg problem
of Python executables bootstrapping when sandboxed (due to fixed-path
shebangs).

---------

Co-authored-by: Corentin Kerisit <corentin.kerisit@gmail.com>
2025-07-15 15:26:03 +00:00

132 lines
3.0 KiB
Python

load("@com_google_protobuf//bazel:upb_proto_library.bzl", "upb_c_proto_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_python//python/uv:lock.bzl", uv_lock = "lock")
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_shared_library")
load("@zml//bazel:runfiles.bzl", "runfiles_to_default")
# Proxy PJRT plugin, built as a shared library from `libpjrt_neuron.zig`.
#
# It loads the real Neuron PJRT plugin and hands back the instance obtained
# from that plugin's own (nested) GetPjrtApi.
#
# It is also the hook used to pull in the implicit transitive dependencies of
# neuronx-cc (see `add_needed` on the patchelf target).
# NOTE(review): no patchelf target is defined in this file as shown; it
# presumably lives in the @libpjrt_neuron repository -- confirm.
#
# Only packages of the @libpjrt_neuron repository may depend on this target.
zig_shared_library(
    name = "libpjrt_neuron",
    # NOTE(review): the reason for -fno-stack-check is not recorded here.
    copts = ["-fno-stack-check"],
    main = "libpjrt_neuron.zig",
    visibility = ["@libpjrt_neuron//:__subpackages__"],
    deps = [
        ":libpython",
        "//stdx",
        "@rules_zig//zig/runfiles",
        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
    ],
)
# `neuronx-cc` executable built from `neuronx-cc.zig` and linked against
# :libpython. The `@neuron_py_deps//neuronx_cc` package is attached as
# runtime data.
# NOTE(review): presumably a Zig launcher that boots the Python neuronx-cc
# entry point -- confirm against neuronx-cc.zig.
zig_binary(
    name = "neuronx-cc",
    data = ["@neuron_py_deps//neuronx_cc"],
    # Adds `../lib`, relative to the executable's own location, to the
    # runtime library search path.
    linkopts = ["-Wl,-rpath,$ORIGIN/../lib"],
    main = "neuronx-cc.zig",
    # "manual": not built by wildcard patterns such as `//...`; it must be
    # requested explicitly or reached through a dependency.
    tags = ["manual"],
    deps = [":libpython"],
)
# Wraps :neuronx-cc with the `runfiles_to_default` rule loaded from
# @zml//bazel:runfiles.bzl.
# NOTE(review): judging by the rule name, this presumably re-exports the
# binary's runfiles as default outputs -- verify in runfiles.bzl.
#
# Only packages of the @libpjrt_neuron repository may depend on this target.
runfiles_to_default(
    name = "neuronx-cc_files",
    visibility = ["@libpjrt_neuron//:__subpackages__"],
    deps = [":neuronx-cc"],
)
# C/C++ view of the Python runtime: exposes `libpython.h` together with the
# headers and libraries of the currently selected rules_python toolchain.
# The Zig targets in this package depend on it.
cc_library(
    name = "libpython",
    hdrs = ["libpython.h"],
    deps = [
        "@rules_python//python/cc:current_py_cc_headers",
        "@rules_python//python/cc:current_py_cc_libs",
    ],
)
# Placeholder library with no sources, headers or deps. It is the
# `//conditions:default` branch of the `select()` in :neuron, i.e. what gets
# linked when the Neuron runtime is not enabled.
cc_library(
    name = "empty",
)
# Source-less library whose only effect is to define ZML_RUNTIME_NEURON.
# `defines` (unlike `local_defines`) propagates to every target that depends
# on this one; :neuron pulls it in when the Neuron runtime is enabled.
cc_library(
    name = "zmlxneuron",
    defines = ["ZML_RUNTIME_NEURON"],
)
# Header-only library exposing the local `nrt.h` on top of the
# `libnrt_headers` target provided by the @libpjrt_neuron repository.
cc_library(
    name = "libnrt_headers",
    hdrs = ["nrt.h"],
    deps = ["@libpjrt_neuron//:libnrt_headers"],
)
# Publicly visible filegroup that currently contains no files.
# NOTE(review): presumably kept so that every runtime package offers a
# `:layers` target -- confirm against the other //runtimes packages.
filegroup(
    name = "layers",
    srcs = [],
    visibility = ["//visibility:public"],
)
# upb C bindings generated from XLA's `xla_data_proto`; consumed by
# :libneuronxla.
upb_c_proto_library(
    name = "xla_data_upb",
    deps = ["@xla//xla:xla_data_proto"],
)
# upb C bindings generated from XLA's `hlo_proto`; consumed by
# :libneuronxla.
upb_c_proto_library(
    name = "hlo_proto_upb",
    deps = ["@xla//xla/service:hlo_proto"],
)
# Shared library built from `libneuronxla.zig` and emitted under the exact
# file name `libneuronxla.so`.
#
# It links the upb bindings for the XLA data and HLO protos, the Python
# runtime (:libpython), and the in-repo //stdx and //upb packages.
#
# Only packages of the @libpjrt_neuron repository may depend on this target.
zig_shared_library(
    name = "libneuronxla",
    copts = [
        # NOTE(review): the reason for -fno-stack-check is not recorded here.
        "-fno-stack-check",
        # Position-independent code.
        "-fPIC",
    ],
    main = "libneuronxla.zig",
    shared_lib_name = "libneuronxla.so",
    visibility = ["@libpjrt_neuron//:__subpackages__"],
    deps = [
        ":hlo_proto_upb",
        ":libpython",
        ":xla_data_upb",
        "//stdx",
        "//upb",
    ],
)
# Locks the Python dependencies with uv: reads `requirements.in` and writes
# `requirements.lock.txt`, resolving for Python 3.11.
uv_lock(
    name = "requirements",
    srcs = ["requirements.in"],
    out = "requirements.lock.txt",
    # Extra flags forwarded to the uv invocation.
    args = [
        # Record the index URL and find-links locations in the lock file.
        "--emit-index-url",
        "--emit-find-links",
        # Consider candidate versions from all configured indexes.
        "--index-strategy=unsafe-best-match",
        # Allow upgrades instead of sticking to previously locked versions.
        "--upgrade",
    ],
    python_version = "3.11",
    # "manual": not built by wildcard patterns such as `//...`.
    tags = ["manual"],
)
# Public Zig library for the Neuron runtime, imported from Zig code as
# `runtimes/neuron` and rooted at `neuron.zig`.
#
# `//pjrt` is always a dependency. The Neuron-specific dependencies are only
# added when `//runtimes:neuron.enabled` matches; in every other
# configuration the source-less :empty library is used instead.
zig_library(
    name = "neuron",
    import_name = "runtimes/neuron",
    main = "neuron.zig",
    visibility = ["//visibility:public"],
    deps = [
        "//pjrt",
    ] + select({
        # Neuron runtime enabled: pull in the headers, the Python runtime,
        # the ZML_RUNTIME_NEURON define and the PJRT plugin repository.
        "//runtimes:neuron.enabled": [
            ":libnrt_headers",
            ":libpython",
            ":zmlxneuron",
            "//async",
            "//stdx",
            "@libpjrt_neuron",
            "@rules_zig//zig/runfiles",
        ],
        # Neuron runtime disabled: no Neuron-specific dependencies.
        "//conditions:default": [":empty"],
    }),
)