diff --git a/pjrt/BUILD.bazel b/pjrt/BUILD.bazel index 4882d87..cd47616 100644 --- a/pjrt/BUILD.bazel +++ b/pjrt/BUILD.bazel @@ -13,7 +13,6 @@ zig_library( visibility = ["//visibility:public"], deps = [ ":profiler_options_proto", - "//runtimes", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", diff --git a/runtimes/BUILD.bazel b/runtimes/BUILD.bazel index 291651c..7aa038d 100644 --- a/runtimes/BUILD.bazel +++ b/runtimes/BUILD.bazel @@ -1,4 +1,5 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@rules_zig//zig:defs.bzl", "zig_library") RUNTIMES = { "cpu": True, @@ -17,26 +18,21 @@ RUNTIMES = { [ config_setting( - name = "_{}".format(runtime), + name = "{}.enabled".format(runtime), flag_values = {":{}".format(runtime): "True"}, + visibility = ["//runtimes:__subpackages__"], ) for runtime in RUNTIMES.keys() ] -cc_library( +zig_library( name = "runtimes", + main = "runtimes.zig", visibility = ["//visibility:public"], - deps = select({ - ":_cpu": ["//runtimes/cpu"], - "//conditions:default": [], - }) + select({ - ":_cuda": ["//runtimes/cuda"], - "//conditions:default": [], - }) + select({ - ":_rocm": ["//runtimes/rocm"], - "//conditions:default": [], - }) + select({ - ":_tpu": ["//runtimes/tpu"], - "//conditions:default": [], - }), + deps = [ + "//pjrt", + ] + [ + "//runtimes/{}".format(runtime) + for runtime in RUNTIMES.keys() + ], ) diff --git a/runtimes/cpu/BUILD.bazel b/runtimes/cpu/BUILD.bazel index ac3ac6d..5877b0a 100644 --- a/runtimes/cpu/BUILD.bazel +++ b/runtimes/cpu/BUILD.bazel @@ -1,8 +1,27 @@ -alias( - name = "cpu", - actual = select({ - "@platforms//os:macos": "@libpjrt_cpu_darwin_arm64//:libpjrt_cpu", - "@platforms//os:linux": "@libpjrt_cpu_linux_amd64//:libpjrt_cpu", - }), - visibility = ["//visibility:public"], +load("@rules_zig//zig:defs.bzl", "zig_library") + +cc_library( + name = "empty", +) + +cc_library( + name = "libpjrt_cpu", + 
defines = ["ZML_RUNTIME_CPU"], + deps = select({ + "@platforms//os:macos": ["@libpjrt_cpu_darwin_arm64//:libpjrt_cpu"], + "@platforms//os:linux": ["@libpjrt_cpu_linux_amd64//:libpjrt_cpu"], + }), +) + +zig_library( + name = "cpu", + import_name = "runtimes/cpu", + main = "cpu.zig", + visibility = ["//visibility:public"], + deps = [ + "//pjrt", + ] + select({ + "//runtimes:cpu.enabled": [":libpjrt_cpu"], + "//conditions:default": [":empty"], + }), ) diff --git a/runtimes/cpu/cpu.zig b/runtimes/cpu/cpu.zig new file mode 100644 index 0000000..e8531ae --- /dev/null +++ b/runtimes/cpu/cpu.zig @@ -0,0 +1,20 @@ +const builtin = @import("builtin"); +const pjrt = @import("pjrt"); +const c = @import("c"); + +pub fn isEnabled() bool { + return @hasDecl(c, "ZML_RUNTIME_CPU"); +} + +pub fn load() !*const pjrt.Api { + if (comptime !isEnabled()) { + return error.Unavailable; + } + + const ext = switch (builtin.os.tag) { + .windows => ".dll", + .macos, .ios, .watchos => ".dylib", + else => ".so", + }; + return try pjrt.Api.loadFrom("libpjrt_cpu" ++ ext); +} diff --git a/runtimes/cuda/BUILD.bazel b/runtimes/cuda/BUILD.bazel index 530c10d..9b5df5a 100644 --- a/runtimes/cuda/BUILD.bazel +++ b/runtimes/cuda/BUILD.bazel @@ -1,5 +1,27 @@ -alias( - name = "cuda", - actual = "@libpjrt_cuda", - visibility = ["//visibility:public"], +load("@rules_zig//zig:defs.bzl", "zig_library") + +cc_library( + name = "empty", +) + +cc_library( + name = "libpjrt_cuda", + defines = ["ZML_RUNTIME_CUDA"], + deps = ["@libpjrt_cuda"], +) + +zig_library( + name = "cuda", + import_name = "runtimes/cuda", + main = "cuda.zig", + visibility = ["//visibility:public"], + deps = [ + "//pjrt", + ] + select({ + "//runtimes:cuda.enabled": [ + ":libpjrt_cuda", + "//async", + ], + "//conditions:default": [":empty"], + }), ) diff --git a/runtimes/cuda/cuda.zig b/runtimes/cuda/cuda.zig new file mode 100644 index 0000000..2863120 --- /dev/null +++ b/runtimes/cuda/cuda.zig @@ -0,0 +1,27 @@ +const builtin = @import("builtin"); 
+const asynk = @import("async"); +const pjrt = @import("pjrt"); +const c = @import("c"); + +pub fn isEnabled() bool { + return @hasDecl(c, "ZML_RUNTIME_CUDA"); +} + +fn hasNvidiaDevice() bool { + asynk.File.access("/dev/nvidia0", .{ .mode = .read_only }) catch return false; + return true; +} + +pub fn load() !*const pjrt.Api { + if (comptime !isEnabled()) { + return error.Unavailable; + } + if (comptime builtin.os.tag != .linux) { + return error.Unavailable; + } + if (!hasNvidiaDevice()) { + return error.Unavailable; + } + + return try pjrt.Api.loadFrom("libpjrt_cuda.so"); +} diff --git a/runtimes/rocm/BUILD.bazel b/runtimes/rocm/BUILD.bazel index 0294432..d7e7327 100644 --- a/runtimes/rocm/BUILD.bazel +++ b/runtimes/rocm/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_zig//zig:defs.bzl", "zig_library") + filegroup( name = "zmlxrocm_srcs", srcs = ["zmlxrocm.cc"], @@ -14,8 +16,28 @@ alias( actual = "@libpjrt_rocm//:gfx", ) -alias( - name = "rocm", - actual = "@libpjrt_rocm", - visibility = ["//visibility:public"], +cc_library( + name = "empty", +) + +cc_library( + name = "libpjrt_rocm", + defines = ["ZML_RUNTIME_ROCM"], + deps = ["@libpjrt_rocm"], +) + +zig_library( + name = "rocm", + import_name = "runtimes/rocm", + main = "rocm.zig", + visibility = ["//visibility:public"], + deps = [ + "//pjrt", + ] + select({ + "//runtimes:rocm.enabled": [ + ":libpjrt_rocm", + "//async", + ], + "//conditions:default": [":empty"], + }), ) diff --git a/runtimes/rocm/rocm.zig b/runtimes/rocm/rocm.zig new file mode 100644 index 0000000..399ed4b --- /dev/null +++ b/runtimes/rocm/rocm.zig @@ -0,0 +1,29 @@ +const builtin = @import("builtin"); +const asynk = @import("async"); +const pjrt = @import("pjrt"); +const c = @import("c"); + +pub fn isEnabled() bool { + return @hasDecl(c, "ZML_RUNTIME_ROCM"); +} + +fn hasRocmDevices() bool { + inline for (&.{ "/dev/kfd", "/dev/dri" }) |path| { + asynk.File.access(path, .{ .mode = .read_only }) catch return false; + } + return true; +} + +pub fn load() 
!*const pjrt.Api { + if (comptime !isEnabled()) { + return error.Unavailable; + } + if (comptime builtin.os.tag != .linux) { + return error.Unavailable; + } + if (!hasRocmDevices()) { + return error.Unavailable; + } + + return try pjrt.Api.loadFrom("libpjrt_rocm.so"); +} diff --git a/runtimes/rocm/zmlxrocm.cc b/runtimes/rocm/zmlxrocm.cc index 7a42fc4..e69f75b 100644 --- a/runtimes/rocm/zmlxrocm.cc +++ b/runtimes/rocm/zmlxrocm.cc @@ -8,7 +8,7 @@ #include "tools/cpp/runfiles/runfiles.h" -__attribute__((constructor)) static void setup_runfiles(int argc, char **argv) +static __attribute__((constructor)) void setup_runfiles(int argc, char **argv) { using bazel::tools::cpp::runfiles::Runfiles; auto runfiles = std::unique_ptr(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY)); @@ -33,7 +33,7 @@ __attribute__((constructor)) static void setup_runfiles(int argc, char **argv) setenv("ROCM_PATH", ROCM_PATH.c_str(), 1); } -extern "C" void *zmlxrocm_dlopen(const char *filename, int flags) +extern "C" __attribute__((visibility("default"))) void *zmlxrocm_dlopen(const char *filename, int flags) { if (filename != NULL) { diff --git a/runtimes/runtimes.zig b/runtimes/runtimes.zig new file mode 100644 index 0000000..2737439 --- /dev/null +++ b/runtimes/runtimes.zig @@ -0,0 +1,30 @@ +const pjrt = @import("pjrt"); +const cpu = @import("runtimes/cpu"); +const cuda = @import("runtimes/cuda"); +const rocm = @import("runtimes/rocm"); +const tpu = @import("runtimes/tpu"); + +pub const Platform = enum { + cpu, + cuda, + rocm, + tpu, +}; + +pub fn load(tag: Platform) !*const pjrt.Api { + return switch (tag) { + .cpu => try cpu.load(), + .cuda => try cuda.load(), + .rocm => try rocm.load(), + .tpu => try tpu.load(), + }; +} + +pub fn isEnabled(tag: Platform) bool { + return switch (tag) { + .cpu => cpu.isEnabled(), + .cuda => cuda.isEnabled(), + .rocm => rocm.isEnabled(), + .tpu => tpu.isEnabled(), + }; +} diff --git a/runtimes/tpu/BUILD.bazel b/runtimes/tpu/BUILD.bazel index 
6bda66e..2da9ac3 100644 --- a/runtimes/tpu/BUILD.bazel +++ b/runtimes/tpu/BUILD.bazel @@ -1,5 +1,27 @@ -alias( - name = "tpu", - actual = "@libpjrt_tpu", - visibility = ["//visibility:public"], +load("@rules_zig//zig:defs.bzl", "zig_library") + +cc_library( + name = "empty", +) + +cc_library( + name = "libpjrt_tpu", + defines = ["ZML_RUNTIME_TPU"], + deps = ["@libpjrt_tpu"], +) + +zig_library( + name = "tpu", + import_name = "runtimes/tpu", + main = "tpu.zig", + visibility = ["//visibility:public"], + deps = [ + "//pjrt", + ] + select({ + "//runtimes:tpu.enabled": [ + ":libpjrt_tpu", + "//async", + ], + "//conditions:default": [":empty"], + }), ) diff --git a/runtimes/tpu/tpu.zig b/runtimes/tpu/tpu.zig new file mode 100644 index 0000000..1bca0b2 --- /dev/null +++ b/runtimes/tpu/tpu.zig @@ -0,0 +1,40 @@ +const builtin = @import("builtin"); +const asynk = @import("async"); +const pjrt = @import("pjrt"); +const c = @import("c"); +const std = @import("std"); + +pub fn isEnabled() bool { + return @hasDecl(c, "ZML_RUNTIME_TPU"); +} + +/// Check if running on Google Compute Engine, because TPUs will poll the +/// metadata server, hanging the process. So only do it on GCP. 
+/// Do it using the official method at: +/// https://cloud.google.com/compute/docs/instances/detect-compute-engine?hl=en#use_operating_system_tools_to_detect_if_a_vm_is_running_in +fn isOnGCP() !bool { + // TODO: abstract that in the client and fail init + const GoogleComputeEngine = "Google Compute Engine"; + + var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only }); + defer f.close() catch {}; + + var buf = [_]u8{0} ** GoogleComputeEngine.len; + _ = try f.reader().readAll(&buf); + + return std.mem.eql(u8, &buf, GoogleComputeEngine); +} + +pub fn load() !*const pjrt.Api { + if (comptime !isEnabled()) { + return error.Unavailable; + } + if (comptime builtin.os.tag != .linux) { + return error.Unavailable; + } + if (!(isOnGCP() catch false)) { + return error.Unavailable; + } + + return try pjrt.Api.loadFrom("libpjrt_tpu.so"); +} diff --git a/zml/context.zig b/zml/context.zig index 13781ab..38cb76e 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -6,6 +6,7 @@ const mlir = @import("mlir"); const pjrt = @import("pjrt"); const c = @import("c"); const runfiles = @import("runfiles"); +const runtimes = @import("runtimes"); const platform = @import("platform.zig"); const Target = @import("platform.zig").Target; @@ -26,12 +27,10 @@ pub const Context = struct { var apis = PjrtApiMap.initFill(null); var apis_once = std.once(struct { fn call() void { - inline for (platform.available_targets) |t| { - if (canLoad(t)) { - if (pjrt.Api.loadFrom(platformToLibrary(t))) |api| { - Context.apis.set(t, api); - } else |_| {} - } + inline for (comptime std.enums.values(runtimes.Platform)) |t| { + if (runtimes.load(t)) |api| { + Context.apis.set(t, api); + } else |_| {} } } }.call); @@ -112,30 +111,6 @@ pub const Context = struct { }; } - fn canLoad(t: Target) bool { - return switch (t) { - .tpu => isRunningOnGCP() catch false, - else => true, - }; - } - - /// Check if running on Google Compute Engine, because TPUs will poll the - /// metadata 
server, hanging the process. So only do it on GCP. - /// Do it using the official method at: - /// https://cloud.google.com/compute/docs/instances/detect-compute-engine?hl=en#use_operating_system_tools_to_detect_if_a_vm_is_running_in - fn isRunningOnGCP() !bool { - // TODO: abstract that in the client and fail init - const GoogleComputeEngine = "Google Compute Engine"; - - var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only }); - defer f.close() catch {}; - - var buf = [_]u8{0} ** GoogleComputeEngine.len; - _ = try f.reader().readAll(&buf); - - return std.mem.eql(u8, &buf, GoogleComputeEngine); - } - pub fn pjrtApi(target: Target) *const pjrt.Api { return Context.apis.get(target).?; } diff --git a/zml/platform.zig b/zml/platform.zig index a8ec82f..8ba22ea 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -3,17 +3,13 @@ const std = @import("std"); const pjrt = @import("pjrt"); const asynk = @import("async"); +const runtimes = @import("runtimes"); const meta = @import("meta.zig"); const module = @import("module.zig"); const log = std.log.scoped(.zml); -pub const Target = enum { - cpu, - cuda, - rocm, - tpu, -}; +pub const Target = runtimes.Platform; pub const available_targets = switch (builtin.os.tag) { .macos => [_]Target{