72 lines
2.4 KiB
Zig
72 lines
2.4 KiB
Zig
const std = @import("std");
|
|
const builtin = @import("builtin");
|
|
|
|
const async = @import("async");
|
|
const bazel_builtin = @import("bazel_builtin");
|
|
const c = @import("c");
|
|
const pjrt = @import("pjrt");
|
|
const runfiles = @import("runfiles");
|
|
const stdx = @import("stdx");
|
|
|
|
const log = std.log.scoped(.@"zml/runtime/rocm");
|
|
|
|
/// Reports whether the imported `c` module declares `ZML_RUNTIME_ROCM`.
///
/// `@hasDecl` is resolved by the compiler, so the result is usable in
/// `comptime` conditions.
pub fn isEnabled() bool {
    const rocm_declared = @hasDecl(c, "ZML_RUNTIME_ROCM");
    return rocm_declared;
}
|
|
|
|
/// Returns `true` only when every required device node passes a read-only
/// access check; the first failing check makes the function return `false`.
fn hasRocmDevices() bool {
    const device_nodes = .{ "/dev/kfd", "/dev/dri" };
    inline for (device_nodes) |node| {
        async.File.access(node, .{ .mode = .read_only }) catch {
            // Any access error is treated as "no usable ROCm device".
            return false;
        };
    }
    return true;
}
|
|
|
|
/// Exports `rocm_data_dir` as the `ROCM_PATH` environment variable, replacing
/// any value that is already set.
///
/// Propagates the error from `stdx.fs.path.bufJoinZ` if the path cannot be
/// rendered into the stack buffer. A failing `setenv` call is reported through
/// the scoped logger instead of being silently discarded; it is not turned into
/// an error so that the set of errors seen by callers stays unchanged.
fn setupRocmEnv(rocm_data_dir: []const u8) !void {
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    // setenv() expects C strings, so the value must be zero terminated.
    const rocm_path = try stdx.fs.path.bufJoinZ(&buf, &.{rocm_data_dir});
    if (c.setenv("ROCM_PATH", rocm_path, 1) != 0) {
        log.warn("Failed to set ROCM_PATH to {s}", .{rocm_data_dir});
    }
}
|
|
|
|
/// Locates the bundled ROCm PJRT plugin in the Bazel runfiles and loads it.
///
/// Returns `error.Unavailable` when ROCm support is not compiled in, when the
/// target OS is not Linux, or when `hasRocmDevices()` reports `false`.
/// Returns `error.FileNotFound` when the `libpjrt_rocm/sandbox` runfile cannot
/// be resolved. Panics when no runfiles can be found at all. Errors from the
/// runfiles lookup, path joining and `pjrt.Api.loadFrom` are propagated.
pub fn load() !*const pjrt.Api {
    if (comptime !isEnabled()) return error.Unavailable;
    if (comptime builtin.os.tag != .linux) return error.Unavailable;
    if (!hasRocmDevices()) return error.Unavailable;

    // Scratch allocations for the runfiles lookup; released when this function returns.
    var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
    defer arena.deinit();

    var all_runfiles = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse
        stdx.debug.panic("Unable to find runfiles", .{});
    const repo_runfiles = all_runfiles.withSourceRepo(bazel_builtin.current_repository);

    var sandbox_buf: [std.fs.max_path_bytes]u8 = undefined;
    const sandbox_dir = try repo_runfiles.rlocation("libpjrt_rocm/sandbox", &sandbox_buf) orelse {
        log.err("Failed to find sandbox path for ROCm runtime", .{});
        return error.FileNotFound;
    };

    try setupRocmEnv(sandbox_dir);

    var plugin_buf: [std.fs.max_path_bytes]u8 = undefined;
    const plugin_path = try stdx.fs.path.bufJoinZ(
        &plugin_buf,
        &.{ sandbox_dir, "lib", "libpjrt_rocm.so" },
    );

    // The PJRT plugin has to be loaded from the main thread.
    //
    // libamdhip64.so uses thread-local storage from one of its static
    // destructors. If that destructor runs on a thread other than the one that
    // called dlopen() on the library, the thread-local storage (TLS) offset may
    // be resolved against the main thread's TLS base instead of the base of the
    // thread running the destructor, and touching the variable then results in
    // a segmentation fault.
    return try pjrt.Api.loadFrom(plugin_path);
}
|