Adjust ROCm runtime sandboxing to hook only the PJRT plugin and make hipblastlt bytecodes optional.
This commit is contained in:
parent
0ce36599da
commit
edc2ac26f8
@ -60,6 +60,10 @@ cc_import(
|
|||||||
":_hipblaslt": ["@hipblaslt//:runfiles"],
|
":_hipblaslt": ["@hipblaslt//:runfiles"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}),
|
}),
|
||||||
|
add_needed = ["libzmlxrocm.so.0"],
|
||||||
|
rename_dynamic_symbols = {
|
||||||
|
"dlopen": "zmlxrocm_dlopen",
|
||||||
|
},
|
||||||
shared_library = "libpjrt_rocm.so",
|
shared_library = "libpjrt_rocm.so",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
|||||||
@ -119,10 +119,6 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
|
|||||||
cc_import(
|
cc_import(
|
||||||
name = "rocblas",
|
name = "rocblas",
|
||||||
shared_library = "lib/librocblas.so.4",
|
shared_library = "lib/librocblas.so.4",
|
||||||
add_needed = ["libzmlxrocm.so.0"],
|
|
||||||
rename_dynamic_symbols = {
|
|
||||||
"dlopen": "zmlxrocm_dlopen",
|
|
||||||
},
|
|
||||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -160,10 +156,6 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
|
|||||||
cc_import(
|
cc_import(
|
||||||
name = "hipblaslt",
|
name = "hipblaslt",
|
||||||
shared_library = "lib/libhipblaslt.so.0",
|
shared_library = "lib/libhipblaslt.so.0",
|
||||||
add_needed = ["libzmlxrocm.so.0"],
|
|
||||||
rename_dynamic_symbols = {
|
|
||||||
"dlopen": "zmlxrocm_dlopen",
|
|
||||||
},
|
|
||||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,20 @@ const pjrt = @import("pjrt");
|
|||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const ROCmEnvEntry = struct {
|
||||||
|
name: [:0]const u8,
|
||||||
|
rpath: []const u8,
|
||||||
|
dirname: bool,
|
||||||
|
mandatory: bool,
|
||||||
|
};
|
||||||
|
|
||||||
|
const rocm_env_entries: []const ROCmEnvEntry = &.{
|
||||||
|
.{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false },
|
||||||
|
.{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false },
|
||||||
|
.{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "rocblas/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true },
|
||||||
|
.{ .name = "ROCM_PATH", .rpath = "libpjrt_rocm/sandbox", .dirname = false, .mandatory = true },
|
||||||
|
};
|
||||||
|
|
||||||
pub fn isEnabled() bool {
|
pub fn isEnabled() bool {
|
||||||
return @hasDecl(c, "ZML_RUNTIME_ROCM");
|
return @hasDecl(c, "ZML_RUNTIME_ROCM");
|
||||||
}
|
}
|
||||||
@ -23,13 +37,6 @@ fn setupRocmEnv() !void {
|
|||||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
|
|
||||||
const paths = .{
|
|
||||||
.{ "HIPBLASLT_EXT_OP_LIBRARY_PATH", "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", false },
|
|
||||||
.{ "HIPBLASLT_TENSILE_LIBPATH", "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", true },
|
|
||||||
.{ "ROCBLAS_TENSILE_LIBPATH", "rocblas/lib/rocblas/library/TensileManifest.txt", true },
|
|
||||||
.{ "ROCM_PATH", "libpjrt_rocm/sandbox", false },
|
|
||||||
};
|
|
||||||
|
|
||||||
const r = blk: {
|
const r = blk: {
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||||
stdx.debug.panic("Unable to find Runfiles directory", .{});
|
stdx.debug.panic("Unable to find Runfiles directory", .{});
|
||||||
@ -38,22 +45,21 @@ fn setupRocmEnv() !void {
|
|||||||
break :blk r_.withSourceRepo(source_repo);
|
break :blk r_.withSourceRepo(source_repo);
|
||||||
};
|
};
|
||||||
|
|
||||||
inline for (paths) |path| {
|
for (rocm_env_entries) |entry| {
|
||||||
const name = path[0];
|
var real_path = r.rlocationAlloc(arena.allocator(), entry.rpath) catch null orelse {
|
||||||
const rpath = path[1];
|
if (entry.mandatory) {
|
||||||
const dirname = path[2];
|
stdx.debug.panic("Unable to find {s} in {s}", .{ entry.name, bazel_builtin.current_repository });
|
||||||
|
}
|
||||||
var real_path = r.rlocationAlloc(arena.allocator(), rpath) catch null orelse {
|
continue;
|
||||||
stdx.debug.panic("Unable to find " ++ name ++ " in " ++ bazel_builtin.current_repository, .{});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (dirname) {
|
if (entry.dirname) {
|
||||||
real_path = std.fs.path.dirname(real_path) orelse {
|
real_path = std.fs.path.dirname(real_path) orelse {
|
||||||
stdx.debug.panic("Unable to dirname on {s}", .{real_path});
|
stdx.debug.panic("Unable to dirname on {s}", .{real_path});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = c.setenv(name, try arena.allocator().dupeZ(u8, real_path), 1);
|
_ = c.setenv(entry.name, try arena.allocator().dupeZ(u8, real_path), 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user