Introduce a thin abstraction layer between ZML and PJRT to manage plugin loading decisions, enable compile‑time detection of linked runtimes, and handle cases such as libtpu blocking metadata access.

This commit is contained in:
Tarry Singh 2023-05-15 09:36:41 +00:00
parent 74e90855ca
commit 54e7eb30b4
14 changed files with 270 additions and 73 deletions

View File

@ -13,7 +13,6 @@ zig_library(
visibility = ["//visibility:public"],
deps = [
":profiler_options_proto",
"//runtimes",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",

View File

@ -1,4 +1,5 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@rules_zig//zig:defs.bzl", "zig_library")
RUNTIMES = {
"cpu": True,
@ -17,26 +18,21 @@ RUNTIMES = {
[
config_setting(
name = "_{}".format(runtime),
name = "{}.enabled".format(runtime),
flag_values = {":{}".format(runtime): "True"},
visibility = ["//runtimes:__subpackages__"],
)
for runtime in RUNTIMES.keys()
]
cc_library(
zig_library(
name = "runtimes",
main = "runtimes.zig",
visibility = ["//visibility:public"],
deps = select({
":_cpu": ["//runtimes/cpu"],
"//conditions:default": [],
}) + select({
":_cuda": ["//runtimes/cuda"],
"//conditions:default": [],
}) + select({
":_rocm": ["//runtimes/rocm"],
"//conditions:default": [],
}) + select({
":_tpu": ["//runtimes/tpu"],
"//conditions:default": [],
}),
deps = [
"//pjrt",
] + [
"//runtimes/{}".format(runtime)
for runtime in RUNTIMES.keys()
],
)

View File

@ -1,8 +1,27 @@
alias(
name = "cpu",
actual = select({
"@platforms//os:macos": "@libpjrt_cpu_darwin_arm64//:libpjrt_cpu",
"@platforms//os:linux": "@libpjrt_cpu_linux_amd64//:libpjrt_cpu",
}),
visibility = ["//visibility:public"],
load("@rules_zig//zig:defs.bzl", "zig_library")
# Placeholder target with no sources or deps; used as the default branch of
# the select() in the "cpu" zig_library below.
cc_library(
name = "empty",
)
# Wraps the per-OS @libpjrt_cpu_* target and adds the ZML_RUNTIME_CPU define.
# cpu.zig tests for that macro via @hasDecl(c, "ZML_RUNTIME_CPU").
# NOTE(review): the select() has no default branch, so only macOS and Linux
# resolve here.
cc_library(
name = "libpjrt_cpu",
defines = ["ZML_RUNTIME_CPU"],
deps = select({
"@platforms//os:macos": ["@libpjrt_cpu_darwin_arm64//:libpjrt_cpu"],
"@platforms//os:linux": ["@libpjrt_cpu_linux_amd64//:libpjrt_cpu"],
}),
)
# Zig module importable as "runtimes/cpu". It always depends on //pjrt; the
# CPU plugin wrapper is added only when //runtimes:cpu.enabled matches,
# otherwise the ":empty" placeholder is used.
zig_library(
name = "cpu",
import_name = "runtimes/cpu",
main = "cpu.zig",
visibility = ["//visibility:public"],
deps = [
"//pjrt",
] + select({
"//runtimes:cpu.enabled": [":libpjrt_cpu"],
"//conditions:default": [":empty"],
}),
)

20
runtimes/cpu/cpu.zig Normal file
View File

@ -0,0 +1,20 @@
const builtin = @import("builtin");
const pjrt = @import("pjrt");
const c = @import("c");
/// Reports whether `ZML_RUNTIME_CPU` is declared in the `c` import.
/// The result comes from `@hasDecl`, so callers may evaluate it at comptime.
pub fn isEnabled() bool {
    const cpu_runtime_declared = @hasDecl(c, "ZML_RUNTIME_CPU");
    return cpu_runtime_declared;
}
/// Loads the PJRT CPU plugin through `pjrt.Api.loadFrom`.
///
/// Returns `error.Unavailable` when `isEnabled()` is false at comptime;
/// otherwise propagates whatever `pjrt.Api.loadFrom` returns.
pub fn load() !*const pjrt.Api {
    if (comptime !isEnabled()) {
        return error.Unavailable;
    }

    // Ask std for the target's shared-library suffix instead of keeping a
    // hand-written OS table: the previous table only listed macOS, iOS and
    // watchOS for ".dylib", so other Darwin tags silently fell back to ".so".
    const suffix = comptime builtin.target.dynamicLibSuffix();
    return try pjrt.Api.loadFrom("libpjrt_cpu" ++ suffix);
}

View File

@ -1,5 +1,27 @@
alias(
name = "cuda",
actual = "@libpjrt_cuda",
visibility = ["//visibility:public"],
load("@rules_zig//zig:defs.bzl", "zig_library")
# Placeholder target with no sources or deps; used as the default branch of
# the select() in the "cuda" zig_library below.
cc_library(
name = "empty",
)
# Wraps @libpjrt_cuda and adds the ZML_RUNTIME_CUDA define.
# cuda.zig tests for that macro via @hasDecl(c, "ZML_RUNTIME_CUDA").
cc_library(
name = "libpjrt_cuda",
defines = ["ZML_RUNTIME_CUDA"],
deps = ["@libpjrt_cuda"],
)
# Zig module importable as "runtimes/cuda". It always depends on //pjrt; the
# CUDA plugin wrapper and //async are added only when //runtimes:cuda.enabled
# matches, otherwise the ":empty" placeholder is used.
zig_library(
name = "cuda",
import_name = "runtimes/cuda",
main = "cuda.zig",
visibility = ["//visibility:public"],
deps = [
"//pjrt",
] + select({
"//runtimes:cuda.enabled": [
":libpjrt_cuda",
"//async",
],
"//conditions:default": [":empty"],
}),
)

27
runtimes/cuda/cuda.zig Normal file
View File

@ -0,0 +1,27 @@
const builtin = @import("builtin");
const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
/// Reports whether `ZML_RUNTIME_CUDA` is declared in the `c` import.
/// The result comes from `@hasDecl`, so callers may evaluate it at comptime.
pub fn isEnabled() bool {
    const cuda_runtime_declared = @hasDecl(c, "ZML_RUNTIME_CUDA");
    return cuda_runtime_declared;
}
/// Returns true when `/dev/nvidia0` passes a read-only access check.
/// Any error from the access call is treated as "no device".
fn hasNvidiaDevice() bool {
    if (asynk.File.access("/dev/nvidia0", .{ .mode = .read_only })) |_| {
        return true;
    } else |_| {
        return false;
    }
}
/// Loads the PJRT CUDA plugin (`libpjrt_cuda.so`) through `pjrt.Api.loadFrom`.
///
/// Returns `error.Unavailable` when the runtime is not enabled, when the
/// target OS is not Linux, or when `hasNvidiaDevice()` reports no device.
/// Errors from `pjrt.Api.loadFrom` are propagated.
pub fn load() !*const pjrt.Api {
    // Both build-time conditions are resolved at comptime so that the code
    // below is only reached on an enabled Linux build.
    const supported_at_build = comptime (isEnabled() and builtin.os.tag == .linux);
    if (comptime !supported_at_build) return error.Unavailable;

    // Runtime probe: only attempt to load the plugin when a device node exists.
    const device_present = hasNvidiaDevice();
    if (!device_present) return error.Unavailable;

    return try pjrt.Api.loadFrom("libpjrt_cuda.so");
}

View File

@ -1,3 +1,5 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
filegroup(
name = "zmlxrocm_srcs",
srcs = ["zmlxrocm.cc"],
@ -14,8 +16,28 @@ alias(
actual = "@libpjrt_rocm//:gfx",
)
alias(
name = "rocm",
actual = "@libpjrt_rocm",
visibility = ["//visibility:public"],
# Placeholder target with no sources or deps; used as the default branch of
# the select() in the "rocm" zig_library below.
cc_library(
name = "empty",
)
# Wraps @libpjrt_rocm and adds the ZML_RUNTIME_ROCM define.
# rocm.zig tests for that macro via @hasDecl(c, "ZML_RUNTIME_ROCM").
cc_library(
name = "libpjrt_rocm",
defines = ["ZML_RUNTIME_ROCM"],
deps = ["@libpjrt_rocm"],
)
# Zig module importable as "runtimes/rocm". It always depends on //pjrt; the
# ROCm plugin wrapper and //async are added only when //runtimes:rocm.enabled
# matches, otherwise the ":empty" placeholder is used.
zig_library(
name = "rocm",
import_name = "runtimes/rocm",
main = "rocm.zig",
visibility = ["//visibility:public"],
deps = [
"//pjrt",
] + select({
"//runtimes:rocm.enabled": [
":libpjrt_rocm",
"//async",
],
"//conditions:default": [":empty"],
}),
)

29
runtimes/rocm/rocm.zig Normal file
View File

@ -0,0 +1,29 @@
const builtin = @import("builtin");
const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
/// Reports whether `ZML_RUNTIME_ROCM` is declared in the `c` import.
/// The result comes from `@hasDecl`, so callers may evaluate it at comptime.
pub fn isEnabled() bool {
    const rocm_runtime_declared = @hasDecl(c, "ZML_RUNTIME_ROCM");
    return rocm_runtime_declared;
}
/// Returns true only when both `/dev/kfd` and `/dev/dri` pass a read-only
/// access check, probed in that order. The first failing path ends the probe
/// with `false`.
fn hasRocmDevices() bool {
    asynk.File.access("/dev/kfd", .{ .mode = .read_only }) catch return false;
    asynk.File.access("/dev/dri", .{ .mode = .read_only }) catch return false;
    return true;
}
/// Loads the PJRT ROCm plugin (`libpjrt_rocm.so`) through `pjrt.Api.loadFrom`.
///
/// Returns `error.Unavailable` when the runtime is not enabled, when the
/// target OS is not Linux, or when `hasRocmDevices()` reports missing device
/// nodes. Errors from `pjrt.Api.loadFrom` are propagated.
pub fn load() !*const pjrt.Api {
    // Both build-time conditions are resolved at comptime so that the code
    // below is only reached on an enabled Linux build.
    const supported_at_build = comptime (isEnabled() and builtin.os.tag == .linux);
    if (comptime !supported_at_build) return error.Unavailable;

    // Runtime probe: only attempt to load the plugin when the device nodes exist.
    const devices_present = hasRocmDevices();
    if (!devices_present) return error.Unavailable;

    return try pjrt.Api.loadFrom("libpjrt_rocm.so");
}

View File

@ -8,7 +8,7 @@
#include "tools/cpp/runfiles/runfiles.h"
__attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
static void setup_runfiles(int argc, char **argv) __attribute__((constructor))
{
using bazel::tools::cpp::runfiles::Runfiles;
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
@ -33,7 +33,7 @@ __attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags)
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags) __attribute__((visibility("default")))
{
if (filename != NULL)
{

30
runtimes/runtimes.zig Normal file
View File

@ -0,0 +1,30 @@
const pjrt = @import("pjrt");
const cpu = @import("runtimes/cpu");
const cuda = @import("runtimes/cuda");
const rocm = @import("runtimes/rocm");
const tpu = @import("runtimes/tpu");
/// Runtime backends known to this module. `load` and `isEnabled` switch on
/// this tag to reach the matching `runtimes/*` module.
pub const Platform = enum {
cpu,
cuda,
rocm,
tpu,
};
/// Loads the PJRT API for `tag` by delegating to the `load` function of the
/// corresponding runtime module. Any error from that module is propagated.
pub fn load(tag: Platform) !*const pjrt.Api {
    switch (tag) {
        .cpu => return cpu.load(),
        .cuda => return cuda.load(),
        .rocm => return rocm.load(),
        .tpu => return tpu.load(),
    }
}
/// Reports whether the runtime for `tag` is enabled, as answered by the
/// `isEnabled` function of the corresponding runtime module.
pub fn isEnabled(tag: Platform) bool {
    switch (tag) {
        .cpu => return cpu.isEnabled(),
        .cuda => return cuda.isEnabled(),
        .rocm => return rocm.isEnabled(),
        .tpu => return tpu.isEnabled(),
    }
}

View File

@ -1,5 +1,27 @@
alias(
name = "tpu",
actual = "@libpjrt_tpu",
visibility = ["//visibility:public"],
load("@rules_zig//zig:defs.bzl", "zig_library")
# Placeholder target with no sources or deps; used as the default branch of
# the select() in the "tpu" zig_library below.
cc_library(
name = "empty",
)
# Wraps @libpjrt_tpu and adds the ZML_RUNTIME_TPU define.
# tpu.zig tests for that macro via @hasDecl(c, "ZML_RUNTIME_TPU").
cc_library(
name = "libpjrt_tpu",
defines = ["ZML_RUNTIME_TPU"],
deps = ["@libpjrt_tpu"],
)
# Zig module importable as "runtimes/tpu". It always depends on //pjrt; the
# TPU plugin wrapper and //async are added only when //runtimes:tpu.enabled
# matches, otherwise the ":empty" placeholder is used.
zig_library(
name = "tpu",
import_name = "runtimes/tpu",
main = "tpu.zig",
visibility = ["//visibility:public"],
deps = [
"//pjrt",
] + select({
"//runtimes:tpu.enabled": [
":libpjrt_tpu",
"//async",
],
"//conditions:default": [":empty"],
}),
)

40
runtimes/tpu/tpu.zig Normal file
View File

@ -0,0 +1,40 @@
const builtin = @import("builtin");
const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
const std = @import("std");
/// Reports whether `ZML_RUNTIME_TPU` is declared in the `c` import.
/// The result comes from `@hasDecl`, so callers may evaluate it at comptime.
pub fn isEnabled() bool {
    const tpu_runtime_declared = @hasDecl(c, "ZML_RUNTIME_TPU");
    return tpu_runtime_declared;
}
/// Detects whether the process runs on Google Compute Engine.
///
/// The TPU plugin polls the metadata server, which hangs the process when
/// not on GCP, so the plugin must only be loaded after this check succeeds.
/// Detection follows the official method of reading the DMI product name:
/// https://cloud.google.com/compute/docs/instances/detect-compute-engine?hl=en#use_operating_system_tools_to_detect_if_a_vm_is_running_in
///
/// Returns an error if the DMI file cannot be opened or read.
fn isOnGCP() !bool {
    // TODO: abstract that in the client and fail init
    const expected_product = "Google Compute Engine";

    var dmi_file = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
    defer dmi_file.close() catch {};

    // Zero-filled so that a short read can never compare equal to the
    // expected product name.
    var product = std.mem.zeroes([expected_product.len]u8);
    _ = try dmi_file.reader().readAll(&product);

    return std.mem.eql(u8, &product, expected_product);
}
/// Loads the PJRT TPU plugin (`libpjrt_tpu.so`) through `pjrt.Api.loadFrom`.
///
/// Returns `error.Unavailable` when the runtime is not enabled, when the
/// target OS is not Linux, or when the host is not detected as Google
/// Compute Engine. Errors from `pjrt.Api.loadFrom` are propagated.
pub fn load() !*const pjrt.Api {
    // Both build-time conditions are resolved at comptime so that the code
    // below is only reached on an enabled Linux build.
    const supported_at_build = comptime (isEnabled() and builtin.os.tag == .linux);
    if (comptime !supported_at_build) return error.Unavailable;

    // A failed detection (e.g. unreadable DMI file) counts as "not on GCP".
    const on_gcp = isOnGCP() catch false;
    if (!on_gcp) return error.Unavailable;

    return try pjrt.Api.loadFrom("libpjrt_tpu.so");
}

View File

@ -6,6 +6,7 @@ const mlir = @import("mlir");
const pjrt = @import("pjrt");
const c = @import("c");
const runfiles = @import("runfiles");
const runtimes = @import("runtimes");
const platform = @import("platform.zig");
const Target = @import("platform.zig").Target;
@ -26,14 +27,12 @@ pub const Context = struct {
var apis = PjrtApiMap.initFill(null);
var apis_once = std.once(struct {
fn call() void {
inline for (platform.available_targets) |t| {
if (canLoad(t)) {
if (pjrt.Api.loadFrom(platformToLibrary(t))) |api| {
inline for (comptime std.enums.values(runtimes.Platform)) |t| {
if (runtimes.load(t)) |api| {
Context.apis.set(t, api);
} else |_| {}
}
}
}
}.call);
var mlir_once = std.once(struct {
@ -112,30 +111,6 @@ pub const Context = struct {
};
}
fn canLoad(t: Target) bool {
return switch (t) {
.tpu => isRunningOnGCP() catch false,
else => true,
};
}
/// Check if running on Google Compute Engine, because TPUs will poll the
/// metadata server, hanging the process. So only do it on GCP.
/// Do it using the official method at:
/// https://cloud.google.com/compute/docs/instances/detect-compute-engine?hl=en#use_operating_system_tools_to_detect_if_a_vm_is_running_in
fn isRunningOnGCP() !bool {
// TODO: abstract that in the client and fail init
const GoogleComputeEngine = "Google Compute Engine";
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
defer f.close() catch {};
var buf = [_]u8{0} ** GoogleComputeEngine.len;
_ = try f.reader().readAll(&buf);
return std.mem.eql(u8, &buf, GoogleComputeEngine);
}
pub fn pjrtApi(target: Target) *const pjrt.Api {
return Context.apis.get(target).?;
}

View File

@ -3,17 +3,13 @@ const std = @import("std");
const pjrt = @import("pjrt");
const asynk = @import("async");
const runtimes = @import("runtimes");
const meta = @import("meta.zig");
const module = @import("module.zig");
const log = std.log.scoped(.zml);
pub const Target = enum {
cpu,
cuda,
rocm,
tpu,
};
pub const Target = runtimes.Platform;
pub const available_targets = switch (builtin.os.tag) {
.macos => [_]Target{