const std = @import("std"); const builtin = @import("builtin"); const async = @import("async"); const bazel_builtin = @import("bazel_builtin"); const c = @import("c"); const pjrt = @import("pjrt"); const runfiles = @import("runfiles"); const stdx = @import("stdx"); const log = std.log.scoped(.@"zml/runtime/rocm"); pub fn isEnabled() bool { return @hasDecl(c, "ZML_RUNTIME_ROCM"); } fn hasRocmDevices() bool { inline for (&.{ "/dev/kfd", "/dev/dri" }) |path| { async.File.access(path, .{ .mode = .read_only }) catch return false; } return true; } fn setupRocmEnv(rocm_data_dir: []const u8) !void { var buf: [std.fs.max_path_bytes]u8 = undefined; _ = c.setenv("ROCM_PATH", try stdx.fs.path.bufJoinZ(&buf, &.{rocm_data_dir}), 1); // must be zero terminated } pub fn load() !*const pjrt.Api { if (comptime !isEnabled()) { return error.Unavailable; } if (comptime builtin.os.tag != .linux) { return error.Unavailable; } if (!hasRocmDevices()) { return error.Unavailable; } var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); defer arena.deinit(); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { stdx.debug.panic("Unable to find runfiles", .{}); }; const source_repo = bazel_builtin.current_repository; const r = r_.withSourceRepo(source_repo); var path_buf: [std.fs.max_path_bytes]u8 = undefined; const sandbox_path = try r.rlocation("libpjrt_rocm/sandbox", &path_buf) orelse { log.err("Failed to find sandbox path for ROCm runtime", .{}); return error.FileNotFound; }; try setupRocmEnv(sandbox_path); var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; const lib_path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_rocm.so" }); // We must load the PJRT plugin from the main thread. // // This is because libamdhip64.so use thread local storage as part of the static destructors... // // This destructor accesses a thread-local variable. If the destructor is // executed in a different thread than the one that originally called dlopen() // on the library, the thread-local storage (TLS) offset may be resolved // relative to the TLS base of the main thread, rather than the thread actually // executing the destructor. Accessing this variable results in a segmentation fault... return try pjrt.Api.loadFrom(lib_path); }