From eba0e725324f0389602acf8d19f060b9b74a0844 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 10 Apr 2025 14:47:16 +0000 Subject: [PATCH] runtimes/tpu: sandbox TPU PJRT plugin; no external dependencies. --- runtimes/tpu/BUILD.bazel | 2 ++ runtimes/tpu/libpjrt_tpu.BUILD.bazel | 21 +++++++++++++++++---- runtimes/tpu/tpu.zig | 27 ++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/runtimes/tpu/BUILD.bazel b/runtimes/tpu/BUILD.bazel index dd4686c..94ae63c 100644 --- a/runtimes/tpu/BUILD.bazel +++ b/runtimes/tpu/BUILD.bazel @@ -21,6 +21,8 @@ zig_library( "//runtimes:tpu.enabled": [ ":libpjrt_tpu", "//async", + "//stdx", + "@rules_zig//zig/runfiles", ], "//conditions:default": [":empty"], }), diff --git a/runtimes/tpu/libpjrt_tpu.BUILD.bazel b/runtimes/tpu/libpjrt_tpu.BUILD.bazel index 2fad364..338867b 100644 --- a/runtimes/tpu/libpjrt_tpu.BUILD.bazel +++ b/runtimes/tpu/libpjrt_tpu.BUILD.bazel @@ -1,8 +1,21 @@ -load("@zml//bazel:cc_import.bzl", "cc_import") +load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") +load("@bazel_skylib//rules:copy_file.bzl", "copy_file") -cc_import( +copy_file( + name = "libpjrt_tpu.so", + src = "libtpu/libtpu.so", + out = "lib/libpjrt_tpu.so", +) + +copy_to_directory( + name = "sandbox", + srcs = [ + "lib/libpjrt_tpu.so", + ], +) + +cc_library( name = "libpjrt_tpu", - shared_library = "libtpu/libtpu.so", - soname = "libpjrt_tpu.so", + data = [":sandbox"], visibility = ["@zml//runtimes/tpu:__subpackages__"], ) diff --git a/runtimes/tpu/tpu.zig b/runtimes/tpu/tpu.zig index 980f421..dc12b43 100644 --- a/runtimes/tpu/tpu.zig +++ b/runtimes/tpu/tpu.zig @@ -4,6 +4,11 @@ const std = @import("std"); const asynk = @import("async"); const pjrt = @import("pjrt"); const c = @import("c"); +const stdx = @import("stdx"); +const bazel_builtin = @import("bazel_builtin"); +const runfiles = @import("runfiles"); + +const log = std.log.scoped(.@"zml/runtime/tpu"); pub fn isEnabled() bool { return @hasDecl(c, "ZML_RUNTIME_TPU"); @@ -37,5 +42,25 @@ pub fn load() !*const pjrt.Api { return error.Unavailable; } - return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_tpu.so"}); + var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + defer arena.deinit(); + + var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { + stdx.debug.panic("Unable to find runfiles", .{}); + }; + + const source_repo = bazel_builtin.current_repository; + const r = r_.withSourceRepo(source_repo); + + var path_buf: [std.fs.max_path_bytes]u8 = undefined; + const sandbox_path = try r.rlocation("libpjrt_tpu/sandbox", &path_buf) orelse { + log.err("Failed to find sandbox path for TPU runtime", .{}); + return error.FileNotFound; + }; + + return blk: { + var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_tpu.so" }); + break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path}); + }; }