2023-05-15 09:36:41 +00:00
|
|
|
const builtin = @import("builtin");
|
2024-01-15 09:41:42 +00:00
|
|
|
const std = @import("std");
|
|
|
|
|
|
2023-05-15 09:36:41 +00:00
|
|
|
const asynk = @import("async");
|
2024-01-15 09:41:42 +00:00
|
|
|
const bazel_builtin = @import("bazel_builtin");
|
2023-05-15 09:36:41 +00:00
|
|
|
const c = @import("c");
|
2024-01-15 09:41:42 +00:00
|
|
|
const pjrt = @import("pjrt");
|
|
|
|
|
const runfiles = @import("runfiles");
|
|
|
|
|
const stdx = @import("stdx");
|
2023-05-15 09:36:41 +00:00
|
|
|
|
2024-01-26 13:02:23 +00:00
|
|
|
/// Describes one environment variable that `setupRocmEnv` exports before the
/// ROCm PJRT plugin is loaded.
const ROCmEnvEntry = struct {
    /// Name of the environment variable; null-terminated because it is handed
    /// directly to C `setenv`.
    name: [:0]const u8,

    /// Runfiles-relative path that is resolved to obtain the variable's value.
    rpath: []const u8,

    /// When true, the variable is set to the directory containing the resolved
    /// path instead of the path itself.
    dirname: bool,

    /// When true, failing to resolve `rpath` is fatal (panic); otherwise the
    /// entry is skipped.
    mandatory: bool,
};
|
|
|
|
|
|
|
|
|
|
/// Environment variables exported by `setupRocmEnv`, in the order they are set.
/// Each value is derived from a path looked up in the Bazel runfiles.
const rocm_env_entries: []const ROCmEnvEntry = &.{
    .{
        .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH",
        .rpath = "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat",
        .dirname = false,
        .mandatory = false,
    },
    .{
        .name = "HIPBLASLT_TENSILE_LIBPATH",
        .rpath = "hipblaslt/lib/hipblaslt/library/TensileManifest.txt",
        .dirname = true,
        .mandatory = false,
    },
    .{
        .name = "ROCBLAS_TENSILE_LIBPATH",
        .rpath = "rocblas/lib/rocblas/library/TensileManifest.txt",
        .dirname = true,
        .mandatory = true,
    },
    .{
        .name = "ROCM_PATH",
        .rpath = "libpjrt_rocm/sandbox",
        .dirname = false,
        .mandatory = true,
    },
};
|
|
|
|
|
|
2023-05-15 09:36:41 +00:00
|
|
|
/// Reports whether the imported C namespace declares `ZML_RUNTIME_ROCM`.
/// The answer is known at compile time, so callers may evaluate it in a
/// `comptime` context.
pub fn isEnabled() bool {
    const rocm_declared = @hasDecl(c, "ZML_RUNTIME_ROCM");
    return rocm_declared;
}
|
|
|
|
|
|
|
|
|
|
/// Returns true only when every probed device path can be accessed for
/// reading; the first inaccessible path makes the result false.
fn hasRocmDevices() bool {
    const device_paths = .{ "/dev/kfd", "/dev/dri" };

    // Unrolled at compile time, one access check per path, in order.
    inline for (&device_paths) |device_path| {
        asynk.File.access(device_path, .{ .mode = .read_only }) catch {
            return false;
        };
    }

    return true;
}
|
|
|
|
|
|
2024-01-15 09:41:42 +00:00
|
|
|
/// Exports every variable listed in `rocm_env_entries`, pointing each one at a
/// path resolved through the Bazel runfiles of the current repository.
///
/// Panics when the runfiles directory cannot be located, when a mandatory
/// entry cannot be resolved, or when a `dirname` entry has no parent
/// directory. Returns an error when creating the runfiles handle fails, when
/// a path cannot be duplicated, or when `setenv` itself fails.
fn setupRocmEnv() !void {
    // Scratch memory only: everything allocated here is released on return.
    var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    // Kept at function scope (rather than inside a nested block) so the
    // Runfiles value stays alive for as long as the derived handle is used.
    var all_runfiles = try runfiles.Runfiles.create(.{ .allocator = allocator }) orelse {
        stdx.debug.panic("Unable to find Runfiles directory", .{});
    };
    const r = all_runfiles.withSourceRepo(bazel_builtin.current_repository);

    for (rocm_env_entries) |entry| {
        // A lookup error is deliberately treated like a missing file: fatal
        // for mandatory entries, silently skipped for optional ones.
        var real_path = r.rlocationAlloc(allocator, entry.rpath) catch null orelse {
            if (entry.mandatory) {
                stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository });
            }
            continue;
        };

        if (entry.dirname) {
            real_path = std.fs.path.dirname(real_path) orelse {
                stdx.debug.panic("Unable to dirname on {s}", .{real_path});
            };
        }

        // setenv copies its arguments, so an arena-owned, null-terminated
        // duplicate is sufficient here.
        const value = try allocator.dupeZ(u8, real_path);

        // Do not ignore the result: a failed export would leave the plugin
        // with a partial environment. The names in the table are non-empty and
        // contain no '=', so the only documented failure left is ENOMEM.
        if (c.setenv(entry.name, value, 1) != 0) {
            return error.OutOfMemory;
        }
    }
}
|
|
|
|
|
|
2023-05-15 09:36:41 +00:00
|
|
|
/// Loads the ROCm PJRT plugin and returns its API table.
///
/// Returns `error.Unavailable` when ROCm support is not compiled in, when the
/// target OS is not Linux, or when the required device paths are not
/// accessible. Otherwise the ROCm environment variables are set up and
/// `libpjrt_rocm.so` is loaded through a blocking call; errors from either
/// step are propagated.
pub fn load() !*const pjrt.Api {
    // Both conditions are compile-time known; when either holds, nothing
    // below this guard is reached.
    if (comptime (!isEnabled() or builtin.os.tag != .linux)) {
        return error.Unavailable;
    }

    if (!hasRocmDevices()) return error.Unavailable;

    // The environment must be in place before the shared library is opened.
    try setupRocmEnv();

    const api = try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_rocm.so"});
    return api;
}
|