runtimes/tpu: sandbox TPU PJRT plugin; no external dependencies.

This commit is contained in:
Tarry Singh 2025-04-10 14:47:16 +00:00
parent 8073e45894
commit eba0e72532
3 changed files with 45 additions and 5 deletions

View File

@ -21,6 +21,8 @@ zig_library(
"//runtimes:tpu.enabled": [
":libpjrt_tpu",
"//async",
"//stdx",
"@rules_zig//zig/runfiles",
],
"//conditions:default": [":empty"],
}),

View File

@ -1,8 +1,21 @@
load("@zml//bazel:cc_import.bzl", "cc_import")
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
cc_import(
copy_file(
name = "libpjrt_tpu.so",
src = "libtpu/libtpu.so",
out = "lib/libpjrt_tpu.so",
)
copy_to_directory(
name = "sandbox",
srcs = [
"lib/libpjrt_tpu.so",
],
)
cc_library(
name = "libpjrt_tpu",
shared_library = "libtpu/libtpu.so",
soname = "libpjrt_tpu.so",
data = [":sandbox"],
visibility = ["@zml//runtimes/tpu:__subpackages__"],
)

View File

@ -4,6 +4,11 @@ const std = @import("std");
const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
const stdx = @import("stdx");
const bazel_builtin = @import("bazel_builtin");
const runfiles = @import("runfiles");
const log = std.log.scoped(.@"zml/runtime/tpu");
pub fn isEnabled() bool {
return @hasDecl(c, "ZML_RUNTIME_TPU");
@ -37,5 +42,25 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable;
}
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_tpu.so"});
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
stdx.debug.panic("Unable to find runfiles", .{});
};
const source_repo = bazel_builtin.current_repository;
const r = r_.withSourceRepo(source_repo);
var path_buf: [std.fs.max_path_bytes]u8 = undefined;
const sandbox_path = try r.rlocation("libpjrt_tpu/sandbox", &path_buf) orelse {
log.err("Failed to find sandbox path for TPU runtime", .{});
return error.FileNotFound;
};
return blk: {
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_tpu.so" });
break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path});
};
}