This PR fixes the neuron runtime with the following changes: Proxy the PJRT API methods to enforce the client struct sizes, since the neuron PJRT plugin asserts them with `==` instead of `>=`, breaking PJRT compatibility guarantees. Fixes https://github.com/aws-neuron/aws-neuron-sdk/issues/1095 Reimplement `libneuronxla` in Zig to control neuronx-cc sandboxing and invocation. Implement a Python bootstrapper in Zig to create a full-blown `neuronx-cc` executable, avoiding the infamous chicken-and-egg problem of bootstrapping Python executables when sandboxed (due to fixed-path shebangs). --------- Co-authored-by: Corentin Kerisit <corentin.kerisit@gmail.com>
203 lines
6.6 KiB
Zig
203 lines
6.6 KiB
Zig
const std = @import("std");
|
|
|
|
const bazel_builtin = @import("bazel_builtin");
|
|
const c = @import("c");
|
|
const runfiles = @import("runfiles");
|
|
const stdx = @import("stdx");
|
|
|
|
/// Logger for this module; messages are scoped under `zml/runtimes/neuron`.
const log = std.log.scoped(.@"zml/runtimes/neuron");
|
|
|
|
/// Asks the kernel for an unused TCP port on the IPv4 loopback interface.
///
/// A throwaway stream socket is bound to 127.0.0.1:0, the port the kernel
/// picked is read back with getsockname, and the socket is closed again.
/// Nothing reserves the port afterwards, so another process may still take
/// it before the caller gets to use it.
fn findFreeTcpPort() !u16 {
    var addr: std.net.Address = .initIp4(.{ 127, 0, 0, 1 }, 0);
    var addr_len = addr.getOsSockLen();

    const fd = try std.posix.socket(std.posix.AF.INET, std.posix.SOCK.STREAM, std.posix.IPPROTO.TCP);
    defer std.posix.close(fd);

    // Port 0 lets the kernel choose; getsockname reports its choice.
    try std.posix.bind(fd, &addr.any, addr_len);
    try std.posix.getsockname(fd, &addr.any, &addr_len);

    return addr.getPort();
}
|
|
|
|
/// `dlopen` replacement exported for the neuron libraries.
///
/// If the basename of `filename` is one of the unversioned names listed in
/// `replacements`, the versioned library located in the same directory as
/// this shared object is opened instead. Every other request (including a
/// null `filename`) is forwarded to `dlopen` unchanged. `flags` are passed
/// through bit-for-bit.
pub export fn zmlxneuron_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque {
    const replacements: std.StaticStringMap([:0]const u8) = .initComptime(.{
        .{ "libnccom.so", "libnccom.so.2" },
        .{ "libnrt.so", "libnrt.so.1" },
        .{ "libncfw.so", "libncfw.so.2" },
    });

    // Must outlive the dlopen call below, since `new_filename` may point into it.
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const new_filename: [*c]const u8 = if (filename) |f| blk: {
        const replacement = replacements.get(std.fs.path.basename(std.mem.span(f))) orelse break :blk f;
        // Joining can fail (e.g. if the combined path does not fit in `buf`).
        // `catch unreachable` would be undefined behaviour in release builds,
        // so degrade gracefully to the caller's original name instead.
        break :blk stdx.fs.path.bufJoinZ(&buf, &.{
            stdx.fs.selfSharedObjectDirPath(),
            replacement,
        }) catch |err| {
            log.warn("Unable to build path for {s}, falling back to {s}: {s}", .{
                replacement,
                std.mem.span(f),
                @errorName(err),
            });
            break :blk f;
        };
    } else null;

    return std.c.dlopen(new_filename, @bitCast(flags));
}
|
|
|
|
extern fn setenv(name: [*:0]const u8, value: [*:0]const u8, overwrite: c_int) c_int;

/// Sets environment variable `name` to `value`, replacing any existing value.
///
/// Returns `error.SetEnvFailed` when libc `setenv` reports a failure (non-zero
/// return) instead of silently ignoring it.
fn setEnvOverwrite(name: [*:0]const u8, value: [*:0]const u8) !void {
    if (setenv(name, value, 1) != 0) {
        log.err("Unable to set environment variable {s}", .{std.mem.span(name)});
        return error.SetEnvFailed;
    }
}

/// Exports the environment variables read by the neuron runtime/plugin.
///
/// Any value already present in the environment is overwritten:
/// - `NEURON_RT_ROOT_COMM_ID`: `127.0.0.1:<port>`, with a port obtained from
///   `findFreeTcpPort`.
/// - `NEURON_INTERNAL_PJRT_C_API_VERSION`: `<major>.<minor>` of the PJRT C API
///   headers this file was compiled against.
/// - `NEURON_RT_STOCHASTIC_ROUNDING_EN`: `1` (presumably enables stochastic
///   rounding in the runtime — TODO confirm against neuron docs).
///
/// Errors from port discovery, formatting, or `setenv` are propagated.
fn setupNeuronEnv() !void {
    // setenv copies its arguments, so a stack buffer is sufficient here.
    var buf: [256]u8 = undefined;
    const root_comm_id = try std.fmt.bufPrintZ(&buf, "127.0.0.1:{d}", .{try findFreeTcpPort()});
    try setEnvOverwrite("NEURON_RT_ROOT_COMM_ID", root_comm_id);

    try setEnvOverwrite(
        "NEURON_INTERNAL_PJRT_C_API_VERSION",
        std.fmt.comptimePrint("{d}.{d}", .{
            c.PJRT_API_MAJOR,
            c.PJRT_API_MINOR,
        }),
    );

    try setEnvOverwrite("NEURON_RT_STOCHASTIC_ROUNDING_EN", "1");
}
|
|
|
|
/// Ends the process if `status` (from a CPython initialization call) reports
/// a failure; returns normally when the status is a success.
///
/// An exit request (`PyStatus_IsExit`) terminates the process with the
/// requested exit code. Any other failure is handed to
/// `Py_ExitStatusException`, which reports it and exits.
fn pyStatusCheck(status: c.PyStatus) void {
    if (c.PyStatus_Exception(status) == 0) return;

    if (c.PyStatus_IsExit(status) != 0) {
        // `exitcode` is a C int and may be negative or above 255; `@intCast`
        // to `u8` would panic (or be UB in release builds). Keep the low byte,
        // which is what a POSIX `exit()` status is reduced to anyway.
        const code: c_uint = @bitCast(status.exitcode);
        std.process.exit(@truncate(code));
    }

    c.Py_ExitStatusException(status);
}
|
|
|
|
/// Converts `file_path` into a NUL-terminated wide-character string using the
/// current C locale's multibyte conversion (`mbstowcs`).
///
/// Returns `error.NameTooLong` if `file_path` does not fit in `PATH_MAX - 1`
/// bytes. Returns `error.BadPathName` if `mbstowcs` rejects the input as an
/// invalid multibyte sequence.
fn toPosixPathW(file_path: []const u8) error{ NameTooLong, BadPathName }![std.posix.PATH_MAX - 1:0]c.wchar_t {
    // mbstowcs reads its source until a NUL byte, but `file_path` is a slice
    // that is not guaranteed to be terminated (callers may pass the result of
    // a non-Z join). Make a terminated copy so the read stays in bounds.
    const path_z = try std.posix.toPosixPath(file_path);

    var path_w: [std.posix.PATH_MAX - 1:0]c.wchar_t = undefined;
    // Each wide character consumes at least one input byte, so at most
    // `file_path.len <= path_w.len` characters are produced.
    const len = c.mbstowcs(&path_w, &path_z, path_w.len);
    // mbstowcs signals an invalid sequence with (size_t)-1; using that as an
    // index below would be out of bounds.
    if (len == std.math.maxInt(usize)) return error.BadPathName;
    path_w[len] = 0;
    return path_w;
}
|
|
|
|
/// Boots an embedded, isolated CPython interpreter whose files live under
/// `sandbox_path`.
///
/// Steps, in order:
/// 1. Pre-initialize in isolated mode with UTF-8 mode on.
/// 2. Configure an isolated `PyConfig` with an explicit module search path,
///    optimization level 2, and bytecode writing disabled.
/// 3. Set `home` to `<sandbox_path>/lib/pythonX.Y` and add it to the search
///    path, then add `<sandbox_path>/site-packages`.
/// 4. Initialize the interpreter and release the GIL.
///
/// Any CPython status failure terminates the process via `pyStatusCheck`;
/// path formatting/conversion errors are returned.
fn setupPythonEnv(sandbox_path: []const u8) !void {
    const Static = struct {
        var py_config: c.PyConfig = undefined;
    };
    const config = &Static.py_config;

    // Step 1: pre-initialization.
    {
        var pre: c.PyPreConfig = undefined;
        c.PyPreConfig_InitIsolatedConfig(&pre);
        pre.utf8_mode = 1;
        pyStatusCheck(c.Py_PreInitialize(&pre));
    }

    // Step 2: main configuration.
    c.PyConfig_InitIsolatedConfig(config);
    config.module_search_paths_set = 1;
    config.optimization_level = 2;
    config.write_bytecode = 0;

    // Step 3a: standard library directory, also used as the interpreter home.
    {
        var home_buf: [std.fs.max_path_bytes]u8 = undefined;
        const home = try std.fmt.bufPrintZ(&home_buf, "{f}{d}.{d}", .{
            std.fs.path.fmtJoin(&.{ sandbox_path, "lib", "python" }),
            c.PY_MAJOR_VERSION,
            c.PY_MINOR_VERSION,
        });
        pyStatusCheck(c.PyConfig_SetBytesString(config, &config.home, home));
        pyStatusCheck(c.PyWideStringList_Append(&config.module_search_paths, &try toPosixPathW(home)));
    }

    // Step 3b: third-party packages shipped in the sandbox.
    {
        var site_buf: [std.fs.max_path_bytes]u8 = undefined;
        const site_packages = try stdx.fs.path.bufJoin(&site_buf, &.{ sandbox_path, "site-packages" });
        pyStatusCheck(c.PyWideStringList_Append(&config.module_search_paths, &try toPosixPathW(site_packages)));
    }

    // Step 4: start the interpreter.
    pyStatusCheck(c.Py_InitializeFromConfig(config));

    // Hand back the GIL taken during initialization so other threads can use Python.
    _ = c.PyEval_SaveThread();
}
|
|
|
|
/// Returns a copy of the PJRT API table `api`, tolerating a plugin whose
/// `PJRT_Api` struct is smaller or larger than the one this file was compiled
/// against.
///
/// Only the common prefix (`min(@sizeOf(c.PJRT_Api), api.struct_size)` bytes)
/// is copied:
/// - Fields the plugin does not provide are zeroed (null function pointers)
///   rather than left undefined.
/// - The copy's `struct_size` is set to the number of bytes actually copied,
///   so it never advertises more fields than the returned struct holds.
fn dupePjrtApi(api: *const c.PJRT_Api) c.PJRT_Api {
    var ret = std.mem.zeroes(c.PJRT_Api);
    const struct_size = @min(@sizeOf(c.PJRT_Api), api.struct_size);
    @memcpy(
        std.mem.asBytes(&ret)[0..struct_size],
        std.mem.asBytes(api)[0..struct_size],
    );
    // If the plugin's struct is larger than ours, the copied struct_size would
    // claim bytes past the end of `ret`; clamp it to what was really copied.
    ret.struct_size = struct_size;
    return ret;
}
|
|
|
|
/// Loads `<sandbox_path>/lib/libneuronpjrt.so` and returns the API table from
/// its `GetPjrtApi` entry point.
///
/// The library handle is intentionally never closed.
fn loadNeuronPjrtApi(sandbox_path: []const u8) !*c.PJRT_Api {
    var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
    const library = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libneuronpjrt.so" });

    // NODELETE keeps the plugin mapped even if something dlcloses it, since the
    // returned API table's function pointers point into it.
    const handle = std.c.dlopen(library, .{ .LAZY = true, .NODELETE = true }) orelse {
        log.err("Unable to dlopen plugin: {?s}", .{std.mem.span(std.c.dlerror())});
        return error.DlOpenFailed;
    };

    var lib: std.DynLib = .{ .inner = .{ .handle = handle } };
    const get_pjrt_api = lib.lookup(*const fn () callconv(.c) *c.PJRT_Api, "GetPjrtApi") orelse {
        log.err("Unable to find symbol GetPjrtApi in plugin: {?s}", .{std.mem.span(std.c.dlerror())});
        return error.SymbolNotFound;
    };
    return get_pjrt_api();
}

/// Returns the proxied PJRT API table for the neuron plugin, building it on
/// the first successful call.
///
/// On first success:
/// - the neuron environment variables are exported and the embedded Python
///   interpreter is started from the sandbox next to this shared object;
/// - the real plugin is loaded and its API table copied into a proxy whose
///   `PJRT_Plugin_Attributes` entry works around the plugin's exact
///   struct-size assertion.
///
/// Later calls return the same proxy without redoing that process-wide setup.
/// It must not run twice: `setupPythonEnv` releases the GIL, and releasing it
/// again without holding it is a fatal error in CPython.
fn getPjrtApi() !*c.PJRT_Api {
    const Static = struct {
        var inner: *c.PJRT_Api = undefined;
        var proxy: c.PJRT_Api = undefined;
        var ready: bool = false;
    };

    if (Static.ready) return &Static.proxy;

    var sandbox_path_buf: [std.fs.max_path_bytes]u8 = undefined;
    const sandbox_path = try stdx.fs.path.bufJoin(&sandbox_path_buf, &.{
        stdx.fs.selfSharedObjectDirPath(),
        "..",
    });

    try setupNeuronEnv();
    try setupPythonEnv(sandbox_path);

    Static.inner = try loadNeuronPjrtApi(sandbox_path);
    Static.proxy = dupePjrtApi(Static.inner);

    // The neuron plugin checks client struct sizes with `==` instead of `>=`
    // (https://github.com/aws-neuron/aws-neuron-sdk/issues/1095), so clamp the
    // size advertised by the caller before forwarding.
    Static.proxy.PJRT_Plugin_Attributes = &struct {
        const STRUCT_SIZE = 24; // according to the failing assertion

        fn call(args: [*c]c.PJRT_Plugin_Attributes_Args) callconv(.c) ?*c.PJRT_Error {
            // Patch the caller's struct in place and restore struct_size after
            // the call. Forwarding a local copy (as before) made the plugin
            // write its outputs (`attributes`, `num_attributes`) into that copy,
            // so the caller never received them.
            const caller_size = args.*.struct_size;
            args.*.struct_size = @min(caller_size, STRUCT_SIZE);
            defer args.*.struct_size = caller_size;
            return Static.inner.PJRT_Plugin_Attributes.?(args);
        }
    }.call;

    Static.ready = true;
    return &Static.proxy;
}
|
|
|
|
/// Exported PJRT plugin entry point, under the same `GetPjrtApi` symbol name
/// that the wrapped neuron plugin exposes.
///
/// Returns the proxied API table. If setup or plugin loading fails, logs the
/// error and returns null.
pub export fn GetPjrtApi() ?*c.PJRT_Api {
    if (getPjrtApi()) |api| {
        return api;
    } else |err| {
        log.err("Failed to get PJRT API: {}", .{err});
        return null;
    }
}
|