From 78d7b672e77c8390f3edc451188f0029dfe6d496 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 3 Apr 2025 11:57:46 +0000 Subject: [PATCH] runtimes/cpu: sandbox CPU PJRT plugin, simplifying as there are no additional NEEDED dependencies. --- runtimes/cpu/BUILD.bazel | 18 ++++++++++++++---- runtimes/cpu/cpu.bzl | 10 ++++------ runtimes/cpu/cpu.zig | 37 ++++++++++++++++++++++++++++++++----- 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/runtimes/cpu/BUILD.bazel b/runtimes/cpu/BUILD.bazel index e16f1d9..fe803c2 100644 --- a/runtimes/cpu/BUILD.bazel +++ b/runtimes/cpu/BUILD.bazel @@ -1,3 +1,4 @@ +load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") load("@rules_zig//zig:defs.bzl", "zig_library") config_setting( @@ -20,14 +21,21 @@ cc_library( name = "empty", ) -cc_library( - name = "libpjrt_cpu", - defines = ["ZML_RUNTIME_CPU"], - deps = select({ +copy_to_directory( + name = "sandbox", + out = "sandbox/lib", + srcs = select({ ":darwin_arm64": ["@libpjrt_cpu_darwin_arm64//:libpjrt_cpu"], ":darwin_amd64": ["@libpjrt_cpu_darwin_amd64//:libpjrt_cpu"], "@platforms//os:linux": ["@libpjrt_cpu_linux_amd64//:libpjrt_cpu"], }), + include_external_repositories = ["**"], +) + +cc_library( + name = "libpjrt_cpu", + defines = ["ZML_RUNTIME_CPU"], + data = [":sandbox"], ) zig_library( @@ -41,6 +49,8 @@ zig_library( "//runtimes:cpu.enabled": [ ":libpjrt_cpu", "//async", + "//stdx", + "@rules_zig//zig/runfiles", ], "//conditions:default": [":empty"], }), diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index 4dc62b7..6d9808d 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -6,18 +6,16 @@ package(default_visibility = ["//visibility:public"]) """ _BUILD_LINUX = "\n".join([ - packages.load_("@zml//bazel:cc_import.bzl", "cc_import"), - packages.cc_import( + packages.filegroup( name = "libpjrt_cpu", - shared_library = "libpjrt_cpu.so", - soname = "libpjrt_cpu.so", + srcs = ["libpjrt_cpu.so"], visibility = ["@zml//runtimes/cpu:__subpackages__"], ), ]) -_BUILD_DARWIN = packages.cc_import( +_BUILD_DARWIN = packages.filegroup( name = "libpjrt_cpu", - shared_library = "libpjrt_cpu.dylib", + srcs = ["libpjrt_cpu.dylib"], visibility = ["@zml//runtimes/cpu:__subpackages__"], ) diff --git a/runtimes/cpu/cpu.zig b/runtimes/cpu/cpu.zig index 2270398..ff22cf5 100644 --- a/runtimes/cpu/cpu.zig +++ b/runtimes/cpu/cpu.zig @@ -3,6 +3,12 @@ const builtin = @import("builtin"); const asynk = @import("async"); const c = @import("c"); const pjrt = @import("pjrt"); +const bazel_builtin = @import("bazel_builtin"); +const std = @import("std"); +const stdx = @import("stdx"); +const runfiles = @import("runfiles"); + +const log = std.log.scoped(.@"zml/runtime/cpu"); pub fn isEnabled() bool { return @hasDecl(c, "ZML_RUNTIME_CPU"); @@ -13,10 +19,31 @@ pub fn load() !*const pjrt.Api { return error.Unavailable; } - const ext = switch (builtin.os.tag) { - .windows => ".dll", - .macos, .ios, .watchos => ".dylib", - else => ".so", + var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + defer arena.deinit(); + + var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { + stdx.debug.panic("Unable to find runfiles", .{}); + }; + + const source_repo = bazel_builtin.current_repository; + const r = r_.withSourceRepo(source_repo); + + var path_buf: [std.fs.max_path_bytes]u8 = undefined; + const sandbox_path = try r.rlocation("zml/runtimes/cpu/sandbox/lib", &path_buf) orelse { + log.err("Failed to find sandbox path for CPU runtime", .{}); + return error.FileNotFound; + }; + + return blk: { + const ext = switch (builtin.os.tag) { + .windows => ".dll", + .macos, .ios, .watchos => ".dylib", + else => ".so", + }; + + var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "libpjrt_cpu" ++ ext }); + break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path}); }; - return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cpu" ++ ext}); }