Radix/runtimes/neuron/neuron.zig
Tarry Singh 1427286716 runtimes/neuron: fix neuron runtime
This PR fixes the neuron runtime with the following:

Proxy the PJRT Api method to enforce the client struct sizes, since the
Neuron PJRT plugin asserts them with `==` instead of `>=`, breaking
PJRT compatibility guarantees.
Fixes https://github.com/aws-neuron/aws-neuron-sdk/issues/1095

Reimplement `libneuronxla` in Zig to control neuronx-cc sandboxing and
invocation.

Implement a Python bootstrapper in Zig to create a full-blown
`neuronx-cc` executable, avoiding the infamous chicken-and-egg problem
of bootstrapping Python executables when sandboxed (due to fixed-path
shebangs).

---------

Co-authored-by: Corentin Kerisit <corentin.kerisit@gmail.com>
2025-07-15 15:26:03 +00:00

70 lines
2.0 KiB
Zig

const std = @import("std");
const builtin = @import("builtin");
const asynk = @import("async");
const bazel_builtin = @import("bazel_builtin");
const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles");
const stdx = @import("stdx");
const log = std.log.scoped(.@"zml/runtime/neuron");
/// Reports whether the Neuron runtime was compiled in.
///
/// This is true when the `c` import declares `ZML_RUNTIME_NEURON`. The check is
/// a declaration lookup, so the answer is known at compile time.
pub fn isEnabled() bool {
    const has_neuron_flag = @hasDecl(c, "ZML_RUNTIME_NEURON");
    return has_neuron_flag;
}
/// Reports whether the device node `/dev/neuron0` can be accessed for reading.
/// Any access error is reported as `false`.
fn hasNeuronDevice() bool {
    const accessible = probe: {
        asynk.File.access("/dev/neuron0", .{ .mode = .read_only }) catch break :probe false;
        break :probe true;
    };
    return accessible;
}
/// Reports whether the DMI system vendor string identifies an Amazon EC2 host.
///
/// Reads up to `"Amazon EC2".len` bytes from
/// `/sys/devices/virtual/dmi/id/sys_vendor` and compares them to "Amazon EC2".
/// Only that prefix is inspected, so trailing bytes (e.g. a newline) are ignored.
/// Returns an error if the file cannot be opened or read.
fn isRunningOnEC2() !bool {
    const expected_vendor = "Amazon EC2";
    var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only });
    // Best-effort close of a read-only file; a close failure does not affect the answer.
    defer f.close() catch {};
    var buf: [expected_vendor.len]u8 = undefined;
    // readAll returns fewer bytes than requested when the file is shorter than
    // the buffer. Compare only the bytes actually read, so uninitialized buffer
    // contents are never inspected.
    const n = try f.reader().readAll(&buf);
    return std.mem.eql(u8, buf[0..n], expected_vendor);
}
/// Loads the AWS Neuron PJRT plugin from the Bazel runfiles sandbox.
///
/// Returns `error.Unavailable` in any of these cases:
/// - the runtime is not compiled in (`isEnabled()` is false);
/// - the target OS is not Linux;
/// - the DMI vendor check does not identify an EC2 host (a failure to read the
///   vendor file counts as "not EC2");
/// - `/dev/neuron0` is not accessible.
///
/// Returns `error.FileNotFound` when the `libpjrt_neuron/sandbox` runfile
/// cannot be resolved, and panics if no runfiles can be located at all.
/// On success, loads `<sandbox>/lib/libpjrt_neuron.so` through `pjrt.Api.loadFrom`.
pub fn load() !*const pjrt.Api {
    // Compile-time guards: everything below only applies to Linux builds that
    // include the Neuron runtime.
    if (comptime !isEnabled()) return error.Unavailable;
    if (comptime builtin.os.tag != .linux) return error.Unavailable;

    // Runtime guards. The device probe is skipped unless the EC2 check passes.
    const on_ec2 = isRunningOnEC2() catch false;
    if (!on_ec2 or !hasNeuronDevice()) return error.Unavailable;

    var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer arena.deinit();

    var rfs = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
        stdx.debug.panic("Unable to find runfiles", .{});
    };
    const repo_rfs = rfs.withSourceRepo(bazel_builtin.current_repository);

    var sandbox_buf: [std.fs.max_path_bytes]u8 = undefined;
    const sandbox_dir = try repo_rfs.rlocation("libpjrt_neuron/sandbox", &sandbox_buf) orelse {
        log.err("Failed to find sandbox path for NEURON runtime", .{});
        return error.FileNotFound;
    };

    var lib_buf: [std.fs.max_path_bytes]u8 = undefined;
    const lib_path = try stdx.fs.path.bufJoinZ(&lib_buf, &.{ sandbox_dir, "lib", "libpjrt_neuron.so" });
    // The library load (presumably blocking) is handed off to asynk.callBlocking.
    return asynk.callBlocking(pjrt.Api.loadFrom, .{lib_path});
}