From edc2ac26f8863f3dfee050d561430350366dbd3b Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 26 Jan 2024 13:02:23 +0000 Subject: [PATCH] Adjust ROCm runtime sandboxing to hook only the PJRT plugin and make hipblaslt bytecodes optional. --- runtimes/rocm/libpjrt_rocm.BUILD.bazel | 4 +++ runtimes/rocm/rocm.bzl | 8 ------ runtimes/rocm/rocm.zig | 38 +++++++++++++++----------- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/runtimes/rocm/libpjrt_rocm.BUILD.bazel b/runtimes/rocm/libpjrt_rocm.BUILD.bazel index c337717..f19ab7f 100644 --- a/runtimes/rocm/libpjrt_rocm.BUILD.bazel +++ b/runtimes/rocm/libpjrt_rocm.BUILD.bazel @@ -60,6 +60,10 @@ cc_import( ":_hipblaslt": ["@hipblaslt//:runfiles"], "//conditions:default": [], }), + add_needed = ["libzmlxrocm.so.0"], + rename_dynamic_symbols = { + "dlopen": "zmlxrocm_dlopen", + }, shared_library = "libpjrt_rocm.so", visibility = ["//visibility:public"], deps = [ diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index 491f678..07242c7 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -119,10 +119,6 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select") cc_import( name = "rocblas", shared_library = "lib/librocblas.so.4", - add_needed = ["libzmlxrocm.so.0"], - rename_dynamic_symbols = { - "dlopen": "zmlxrocm_dlopen", - }, visibility = ["@libpjrt_rocm//:__subpackages__"], ) @@ -160,10 +156,6 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select") cc_import( name = "hipblaslt", shared_library = "lib/libhipblaslt.so.0", - add_needed = ["libzmlxrocm.so.0"], - rename_dynamic_symbols = { - "dlopen": "zmlxrocm_dlopen", - }, visibility = ["@libpjrt_rocm//:__subpackages__"], ) diff --git a/runtimes/rocm/rocm.zig b/runtimes/rocm/rocm.zig index d0bb66c..72acb7e 100644 --- a/runtimes/rocm/rocm.zig +++ b/runtimes/rocm/rocm.zig @@ -8,6 +8,20 @@ const pjrt = @import("pjrt"); const runfiles = @import("runfiles"); const stdx = @import("stdx"); +const ROCmEnvEntry = struct { + name: 
[:0]const u8, + rpath: []const u8, + dirname: bool, + mandatory: bool, +}; + +const rocm_env_entries: []const ROCmEnvEntry = &.{ + .{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false }, + .{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false }, + .{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "rocblas/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true }, + .{ .name = "ROCM_PATH", .rpath = "libpjrt_rocm/sandbox", .dirname = false, .mandatory = true }, +}; + pub fn isEnabled() bool { return @hasDecl(c, "ZML_RUNTIME_ROCM"); } @@ -23,13 +37,6 @@ fn setupRocmEnv() !void { var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); defer arena.deinit(); - const paths = .{ - .{ "HIPBLASLT_EXT_OP_LIBRARY_PATH", "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", false }, - .{ "HIPBLASLT_TENSILE_LIBPATH", "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", true }, - .{ "ROCBLAS_TENSILE_LIBPATH", "rocblas/lib/rocblas/library/TensileManifest.txt", true }, - .{ "ROCM_PATH", "libpjrt_rocm/sandbox", false }, - }; - const r = blk: { var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { stdx.debug.panic("Unable to find Runfiles directory", .{}); @@ -38,22 +45,21 @@ fn setupRocmEnv() !void { break :blk r_.withSourceRepo(source_repo); }; - inline for (paths) |path| { - const name = path[0]; - const rpath = path[1]; - const dirname = path[2]; - - var real_path = r.rlocationAlloc(arena.allocator(), rpath) catch null orelse { - stdx.debug.panic("Unable to find " ++ name ++ " in " ++ bazel_builtin.current_repository, .{}); + for (rocm_env_entries) |entry| { + var real_path = r.rlocationAlloc(arena.allocator(), entry.rpath) catch null orelse { + if (entry.mandatory) { + stdx.debug.panic("Unable to find {s} in {s}", .{ entry.name, 
bazel_builtin.current_repository }); + } + continue; }; - if (dirname) { + if (entry.dirname) { real_path = std.fs.path.dirname(real_path) orelse { stdx.debug.panic("Unable to dirname on {s}", .{real_path}); }; } - _ = c.setenv(name, try arena.allocator().dupeZ(u8, real_path), 1); + _ = c.setenv(entry.name, try arena.allocator().dupeZ(u8, real_path), 1); } }