Introduce a thin abstraction layer between ZML and PJRT that decides which plugins to load, enables compile‑time detection of the runtimes linked into the build, and handles special cases such as libtpu polling the metadata server (and hanging the process) when not running on Google Cloud.
This commit is contained in:
parent
74e90855ca
commit
54e7eb30b4
@ -13,7 +13,6 @@ zig_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":profiler_options_proto",
|
":profiler_options_proto",
|
||||||
"//runtimes",
|
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
RUNTIMES = {
|
RUNTIMES = {
|
||||||
"cpu": True,
|
"cpu": True,
|
||||||
@ -17,26 +18,21 @@ RUNTIMES = {
|
|||||||
|
|
||||||
[
|
[
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "_{}".format(runtime),
|
name = "{}.enabled".format(runtime),
|
||||||
flag_values = {":{}".format(runtime): "True"},
|
flag_values = {":{}".format(runtime): "True"},
|
||||||
|
visibility = ["//runtimes:__subpackages__"],
|
||||||
)
|
)
|
||||||
for runtime in RUNTIMES.keys()
|
for runtime in RUNTIMES.keys()
|
||||||
]
|
]
|
||||||
|
|
||||||
cc_library(
|
zig_library(
|
||||||
name = "runtimes",
|
name = "runtimes",
|
||||||
|
main = "runtimes.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = select({
|
deps = [
|
||||||
":_cpu": ["//runtimes/cpu"],
|
"//pjrt",
|
||||||
"//conditions:default": [],
|
] + [
|
||||||
}) + select({
|
"//runtimes/{}".format(runtime)
|
||||||
":_cuda": ["//runtimes/cuda"],
|
for runtime in RUNTIMES.keys()
|
||||||
"//conditions:default": [],
|
],
|
||||||
}) + select({
|
|
||||||
":_rocm": ["//runtimes/rocm"],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}) + select({
|
|
||||||
":_tpu": ["//runtimes/tpu"],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,8 +1,27 @@
|
|||||||
alias(
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
name = "cpu",
|
|
||||||
actual = select({
|
cc_library(
|
||||||
"@platforms//os:macos": "@libpjrt_cpu_darwin_arm64//:libpjrt_cpu",
|
name = "empty",
|
||||||
"@platforms//os:linux": "@libpjrt_cpu_linux_amd64//:libpjrt_cpu",
|
)
|
||||||
}),
|
|
||||||
visibility = ["//visibility:public"],
|
cc_library(
|
||||||
|
name = "libpjrt_cpu",
|
||||||
|
defines = ["ZML_RUNTIME_CPU"],
|
||||||
|
deps = select({
|
||||||
|
"@platforms//os:macos": ["@libpjrt_cpu_darwin_arm64//:libpjrt_cpu"],
|
||||||
|
"@platforms//os:linux": ["@libpjrt_cpu_linux_amd64//:libpjrt_cpu"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "cpu",
|
||||||
|
import_name = "runtimes/cpu",
|
||||||
|
main = "cpu.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//pjrt",
|
||||||
|
] + select({
|
||||||
|
"//runtimes:cpu.enabled": [":libpjrt_cpu"],
|
||||||
|
"//conditions:default": [":empty"],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
20
runtimes/cpu/cpu.zig
Normal file
20
runtimes/cpu/cpu.zig
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
const builtin = @import("builtin");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const c = @import("c");
|
||||||
|
|
||||||
|
/// Reports whether the CPU runtime marker is visible to this module.
///
/// Returns true exactly when the `c` module declares `ZML_RUNTIME_CPU`.
/// `load` evaluates this at comptime to gate plugin loading.
pub fn isEnabled() bool {
    const marker = "ZML_RUNTIME_CPU";
    return @hasDecl(c, marker);
}
|
||||||
|
|
||||||
|
/// Loads the CPU PJRT plugin.
///
/// Returns `error.Unavailable` when `isEnabled()` is false (checked at
/// comptime). Otherwise forwards the result of `pjrt.Api.loadFrom` for
/// `libpjrt_cpu` followed by the shared-library extension of the target OS.
pub fn load() !*const pjrt.Api {
    if (comptime !isEnabled()) return error.Unavailable;

    // The extension is picked from the compile target, not probed at runtime.
    const ext = switch (builtin.os.tag) {
        .windows => ".dll",
        .macos, .ios, .watchos => ".dylib",
        else => ".so",
    };
    const library_name = "libpjrt_cpu" ++ ext;

    return try pjrt.Api.loadFrom(library_name);
}
|
||||||
@ -1,5 +1,27 @@
|
|||||||
alias(
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
name = "cuda",
|
|
||||||
actual = "@libpjrt_cuda",
|
cc_library(
|
||||||
visibility = ["//visibility:public"],
|
name = "empty",
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "libpjrt_cuda",
|
||||||
|
defines = ["ZML_RUNTIME_CUDA"],
|
||||||
|
deps = ["@libpjrt_cuda"],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "cuda",
|
||||||
|
import_name = "runtimes/cuda",
|
||||||
|
main = "cuda.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//pjrt",
|
||||||
|
] + select({
|
||||||
|
"//runtimes:cuda.enabled": [
|
||||||
|
":libpjrt_cuda",
|
||||||
|
"//async",
|
||||||
|
],
|
||||||
|
"//conditions:default": [":empty"],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
27
runtimes/cuda/cuda.zig
Normal file
27
runtimes/cuda/cuda.zig
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
const builtin = @import("builtin");
|
||||||
|
const asynk = @import("async");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const c = @import("c");
|
||||||
|
|
||||||
|
/// Reports whether the CUDA runtime marker is visible to this module.
///
/// Returns true exactly when the `c` module declares `ZML_RUNTIME_CUDA`.
/// `load` evaluates this at comptime to gate plugin loading.
pub fn isEnabled() bool {
    const marker = "ZML_RUNTIME_CUDA";
    return @hasDecl(c, marker);
}
|
||||||
|
|
||||||
|
/// Returns true when `/dev/nvidia0` passes a read-only access check.
///
/// Any error reported by the access check is treated as "no device".
fn hasNvidiaDevice() bool {
    const device_node = "/dev/nvidia0";
    if (asynk.File.access(device_node, .{ .mode = .read_only })) |_| {
        return true;
    } else |_| {
        return false;
    }
}
|
||||||
|
|
||||||
|
/// Loads the CUDA PJRT plugin.
///
/// Returns `error.Unavailable` when:
///   - `isEnabled()` is false (comptime),
///   - the compile target is not Linux (comptime), or
///   - `hasNvidiaDevice()` reports no device at runtime.
/// Otherwise forwards the result of `pjrt.Api.loadFrom("libpjrt_cuda.so")`.
pub fn load() !*const pjrt.Api {
    // Both build-time gates are resolved by the compiler.
    if (comptime (!isEnabled() or builtin.os.tag != .linux)) {
        return error.Unavailable;
    }

    const device_present = hasNvidiaDevice();
    if (!device_present) return error.Unavailable;

    return try pjrt.Api.loadFrom("libpjrt_cuda.so");
}
|
||||||
@ -1,3 +1,5 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "zmlxrocm_srcs",
|
name = "zmlxrocm_srcs",
|
||||||
srcs = ["zmlxrocm.cc"],
|
srcs = ["zmlxrocm.cc"],
|
||||||
@ -14,8 +16,28 @@ alias(
|
|||||||
actual = "@libpjrt_rocm//:gfx",
|
actual = "@libpjrt_rocm//:gfx",
|
||||||
)
|
)
|
||||||
|
|
||||||
alias(
|
cc_library(
|
||||||
name = "rocm",
|
name = "empty",
|
||||||
actual = "@libpjrt_rocm",
|
)
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
|
cc_library(
|
||||||
|
name = "libpjrt_rocm",
|
||||||
|
defines = ["ZML_RUNTIME_ROCM"],
|
||||||
|
deps = ["@libpjrt_rocm"],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "rocm",
|
||||||
|
import_name = "runtimes/rocm",
|
||||||
|
main = "rocm.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//pjrt",
|
||||||
|
] + select({
|
||||||
|
"//runtimes:rocm.enabled": [
|
||||||
|
":libpjrt_rocm",
|
||||||
|
"//async",
|
||||||
|
],
|
||||||
|
"//conditions:default": [":empty"],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
29
runtimes/rocm/rocm.zig
Normal file
29
runtimes/rocm/rocm.zig
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
const builtin = @import("builtin");
|
||||||
|
const asynk = @import("async");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const c = @import("c");
|
||||||
|
|
||||||
|
/// Reports whether the ROCm runtime marker is visible to this module.
///
/// Returns true exactly when the `c` module declares `ZML_RUNTIME_ROCM`.
/// `load` evaluates this at comptime to gate plugin loading.
pub fn isEnabled() bool {
    const marker = "ZML_RUNTIME_ROCM";
    return @hasDecl(c, marker);
}
|
||||||
|
|
||||||
|
/// Returns true only when both `/dev/kfd` and `/dev/dri` pass a read-only
/// access check.
///
/// The first path whose check fails short-circuits the result to false.
fn hasRocmDevices() bool {
    const required_nodes = .{ "/dev/kfd", "/dev/dri" };
    inline for (required_nodes) |node| {
        asynk.File.access(node, .{ .mode = .read_only }) catch {
            return false;
        };
    }
    return true;
}
|
||||||
|
|
||||||
|
/// Loads the ROCm PJRT plugin.
///
/// Returns `error.Unavailable` when:
///   - `isEnabled()` is false (comptime),
///   - the compile target is not Linux (comptime), or
///   - `hasRocmDevices()` reports a missing device node at runtime.
/// Otherwise forwards the result of `pjrt.Api.loadFrom("libpjrt_rocm.so")`.
pub fn load() !*const pjrt.Api {
    // Both build-time gates are resolved by the compiler.
    if (comptime (!isEnabled() or builtin.os.tag != .linux)) {
        return error.Unavailable;
    }

    const devices_present = hasRocmDevices();
    if (!devices_present) return error.Unavailable;

    return try pjrt.Api.loadFrom("libpjrt_rocm.so");
}
|
||||||
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
#include "tools/cpp/runfiles/runfiles.h"
|
#include "tools/cpp/runfiles/runfiles.h"
|
||||||
|
|
||||||
__attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
|
static void setup_runfiles(int argc, char **argv) __attribute__((constructor))
|
||||||
{
|
{
|
||||||
using bazel::tools::cpp::runfiles::Runfiles;
|
using bazel::tools::cpp::runfiles::Runfiles;
|
||||||
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
|
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
|
||||||
@ -33,7 +33,7 @@ __attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
|
|||||||
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
|
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags)
|
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags) __attribute__((visibility("default")))
|
||||||
{
|
{
|
||||||
if (filename != NULL)
|
if (filename != NULL)
|
||||||
{
|
{
|
||||||
|
|||||||
30
runtimes/runtimes.zig
Normal file
30
runtimes/runtimes.zig
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const cpu = @import("runtimes/cpu");
|
||||||
|
const cuda = @import("runtimes/cuda");
|
||||||
|
const rocm = @import("runtimes/rocm");
|
||||||
|
const tpu = @import("runtimes/tpu");
|
||||||
|
|
||||||
|
/// Selects which runtime module `load` and `isEnabled` dispatch to.
pub const Platform = enum {
    /// Dispatches to the `runtimes/cpu` module.
    cpu,
    /// Dispatches to the `runtimes/cuda` module.
    cuda,
    /// Dispatches to the `runtimes/rocm` module.
    rocm,
    /// Dispatches to the `runtimes/tpu` module.
    tpu,
};
|
||||||
|
|
||||||
|
/// Loads the PJRT API for the requested platform.
///
/// Delegates to the `load` function of the runtime module matching `tag` and
/// propagates whatever error that module returns.
pub fn load(tag: Platform) !*const pjrt.Api {
    switch (tag) {
        .cpu => return try cpu.load(),
        .cuda => return try cuda.load(),
        .rocm => return try rocm.load(),
        .tpu => return try tpu.load(),
    }
}
|
||||||
|
|
||||||
|
/// Reports whether the runtime for `tag` is enabled.
///
/// Delegates to the `isEnabled` function of the runtime module matching `tag`.
pub fn isEnabled(tag: Platform) bool {
    switch (tag) {
        .cpu => return cpu.isEnabled(),
        .cuda => return cuda.isEnabled(),
        .rocm => return rocm.isEnabled(),
        .tpu => return tpu.isEnabled(),
    }
}
|
||||||
@ -1,5 +1,27 @@
|
|||||||
alias(
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
name = "tpu",
|
|
||||||
actual = "@libpjrt_tpu",
|
cc_library(
|
||||||
visibility = ["//visibility:public"],
|
name = "empty",
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "libpjrt_tpu",
|
||||||
|
defines = ["ZML_RUNTIME_TPU"],
|
||||||
|
deps = ["@libpjrt_tpu"],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "tpu",
|
||||||
|
import_name = "runtimes/tpu",
|
||||||
|
main = "tpu.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//pjrt",
|
||||||
|
] + select({
|
||||||
|
"//runtimes:tpu.enabled": [
|
||||||
|
":libpjrt_tpu",
|
||||||
|
"//async",
|
||||||
|
],
|
||||||
|
"//conditions:default": [":empty"],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
40
runtimes/tpu/tpu.zig
Normal file
40
runtimes/tpu/tpu.zig
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
const builtin = @import("builtin");
|
||||||
|
const asynk = @import("async");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const c = @import("c");
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
/// Reports whether the TPU runtime marker is visible to this module.
///
/// Returns true exactly when the `c` module declares `ZML_RUNTIME_TPU`.
/// `load` evaluates this at comptime to gate plugin loading.
pub fn isEnabled() bool {
    const marker = "ZML_RUNTIME_TPU";
    return @hasDecl(c, marker);
}
|
||||||
|
|
||||||
|
/// Detects whether the process is running on Google Compute Engine.
///
/// TPUs will poll the metadata server, hanging the process, so the TPU
/// plugin should only be loaded on GCP. Detection follows the official method:
/// https://cloud.google.com/compute/docs/instances/detect-compute-engine?hl=en#use_operating_system_tools_to_detect_if_a_vm_is_running_in
///
/// Reads the leading bytes of the DMI product name and compares them with
/// "Google Compute Engine". Errors from opening or reading the file are
/// returned to the caller.
fn isOnGCP() !bool {
    // TODO: abstract that in the client and fail init
    const expected_product_name = "Google Compute Engine";
    const dmi_product_path = "/sys/devices/virtual/dmi/id/product_name";

    var product_file = try asynk.File.open(dmi_product_path, .{ .mode = .read_only });
    defer product_file.close() catch {};

    // Zero-filled so a short read can never compare equal to the expected name.
    var prefix = [_]u8{0} ** expected_product_name.len;
    _ = try product_file.reader().readAll(&prefix);

    return std.mem.eql(u8, &prefix, expected_product_name);
}
|
||||||
|
|
||||||
|
/// Loads the TPU PJRT plugin.
///
/// Returns `error.Unavailable` when:
///   - `isEnabled()` is false (comptime),
///   - the compile target is not Linux (comptime), or
///   - `isOnGCP()` returns false or fails (a failure is treated as "not on GCP").
/// Otherwise forwards the result of `pjrt.Api.loadFrom("libpjrt_tpu.so")`.
pub fn load() !*const pjrt.Api {
    // Both build-time gates are resolved by the compiler.
    if (comptime (!isEnabled() or builtin.os.tag != .linux)) {
        return error.Unavailable;
    }

    const on_gcp = isOnGCP() catch false;
    if (!on_gcp) return error.Unavailable;

    return try pjrt.Api.loadFrom("libpjrt_tpu.so");
}
|
||||||
@ -6,6 +6,7 @@ const mlir = @import("mlir");
|
|||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
|
const runtimes = @import("runtimes");
|
||||||
|
|
||||||
const platform = @import("platform.zig");
|
const platform = @import("platform.zig");
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
@ -26,12 +27,10 @@ pub const Context = struct {
|
|||||||
var apis = PjrtApiMap.initFill(null);
|
var apis = PjrtApiMap.initFill(null);
|
||||||
var apis_once = std.once(struct {
|
var apis_once = std.once(struct {
|
||||||
fn call() void {
|
fn call() void {
|
||||||
inline for (platform.available_targets) |t| {
|
inline for (comptime std.enums.values(runtimes.Platform)) |t| {
|
||||||
if (canLoad(t)) {
|
if (runtimes.load(t)) |api| {
|
||||||
if (pjrt.Api.loadFrom(platformToLibrary(t))) |api| {
|
Context.apis.set(t, api);
|
||||||
Context.apis.set(t, api);
|
} else |_| {}
|
||||||
} else |_| {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}.call);
|
}.call);
|
||||||
@ -112,30 +111,6 @@ pub const Context = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn canLoad(t: Target) bool {
|
|
||||||
return switch (t) {
|
|
||||||
.tpu => isRunningOnGCP() catch false,
|
|
||||||
else => true,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if running on Google Compute Engine, because TPUs will poll the
|
|
||||||
/// metadata server, hanging the process. So only do it on GCP.
|
|
||||||
/// Do it using the official method at:
|
|
||||||
/// https://cloud.google.com/compute/docs/instances/detect-compute-engine?hl=en#use_operating_system_tools_to_detect_if_a_vm_is_running_in
|
|
||||||
fn isRunningOnGCP() !bool {
|
|
||||||
// TODO: abstract that in the client and fail init
|
|
||||||
const GoogleComputeEngine = "Google Compute Engine";
|
|
||||||
|
|
||||||
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
|
|
||||||
defer f.close() catch {};
|
|
||||||
|
|
||||||
var buf = [_]u8{0} ** GoogleComputeEngine.len;
|
|
||||||
_ = try f.reader().readAll(&buf);
|
|
||||||
|
|
||||||
return std.mem.eql(u8, &buf, GoogleComputeEngine);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn pjrtApi(target: Target) *const pjrt.Api {
|
pub fn pjrtApi(target: Target) *const pjrt.Api {
|
||||||
return Context.apis.get(target).?;
|
return Context.apis.get(target).?;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,17 +3,13 @@ const std = @import("std");
|
|||||||
|
|
||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const runtimes = @import("runtimes");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const module = @import("module.zig");
|
const module = @import("module.zig");
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
pub const Target = enum {
|
pub const Target = runtimes.Platform;
|
||||||
cpu,
|
|
||||||
cuda,
|
|
||||||
rocm,
|
|
||||||
tpu,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const available_targets = switch (builtin.os.tag) {
|
pub const available_targets = switch (builtin.os.tag) {
|
||||||
.macos => [_]Target{
|
.macos => [_]Target{
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user