diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index cee62af..69bd22a 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -823,7 +823,7 @@ pub const NamedValue = extern struct { []i64, []const i64 => fromInt64List(name_, value), f32 => fromFloat(name_, value), bool => fromBool(name_, value), - else => unreachable, + else => fromString(name_, @tagName(value)), }; } diff --git a/zml/context.zig b/zml/context.zig index 94308ce..7e36b6b 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -81,35 +81,16 @@ pub const Context = struct { Context.apis_once.call(); Context.mlir_once.call(); - var platforms = PlatformsMap.initFill(null); var num_platforms: u8 = 0; - var it = Context.apis.iterator(); - while (it.next()) |entry| { - if (entry.value.*) |api| { - const target = entry.key; - const p = Platform.init(target, api) catch |err| { - log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err }); - continue; - }; - if (p.getDevices().len == 0) { - log.err("No device found for platform {} !", .{target}); - continue; - } - if (target == .cuda) { - try cuda.registerZmlCustomCalls(p); - } - platforms.set(target, p); - num_platforms += 1; - } + for (Context.apis.values) |api| { + if (api != null) num_platforms += 1; } if (num_platforms == 0) { log.err("No platform available", .{}); return error.NoPlatformAvailable; } - return .{ - .platforms = platforms, - }; + return .{ .platforms = PlatformsMap.initFill(null) }; } fn platformToLibrary(comptime target: Target) []const u8 { @@ -137,20 +118,69 @@ pub const Context = struct { self.* = undefined; } + const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu }; + /// Automatically selects the best Platform loaded in the current Context. /// /// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...). 
- pub fn autoPlatform(self: *Context) Platform { - // the last platform is the one that with the high enum number, so considered - // to be the "best" one - var platform_: ?Platform = null; - var iterator = self.platforms.iterator(); - while (iterator.next()) |entry| { - if (entry.value.*) |p| { - platform_ = p; - } + pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform { + stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{}); + + return self.platformByPreferences(opts, &prefered_targets); + } + + /// Given a list of preferred targets, selects the best Platform. + /// + /// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...). + pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform { + // Try preferred targets. + for (prefered) |target| { + if (apis.get(target) == null) continue; + return self.platform(target, opts) catch |err| { + log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err }); + continue; + }; + } - return platform_ orelse @panic("No platform found !"); + + // Try unlisted targets + var it = Context.apis.iterator(); + while (it.next()) |entry| { + const target = entry.key; + // CPU should only be used as a fallback. + if (target == .cpu) continue; + if (entry.value.* == null) continue; + if (std.mem.indexOfScalar(Target, prefered, target) != null) continue; + return self.platform(target, opts) catch |err| { + log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err }); + continue; + }; + } + + // Finally, fall back to CPU. 
+ return self.platform(.cpu, opts) catch { + log.err("No platform available", .{}); + @panic("No platform available !"); + }; + } + + pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform { + if (self.platforms.get(target)) |p| { + return p; + } + const api = Context.apis.get(target); + if (api == null) return error.PlatformNotCompiled; + const p = try Platform.init(target, api.?, opts); + if (p.getDevices().len == 0) { + log.err("No device found for platform {} !", .{target}); + return error.NoDevicesFound; + } + // TODO: should this be moved to platform.zig ? + if (target == .cuda) { + try cuda.registerZmlCustomCalls(p); + } + + self.platforms.set(target, p); + return p; } pub fn printAvailablePlatforms(self: Context, selected: Platform) void { diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 9d43b9b..7408e07 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -42,8 +42,8 @@ fn InnerMixin(comptime innerT: type) type { pub const Client = opaque { const inner = InnerMixin(pjrt.Client).inner; - pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client { - return @ptrCast(try pjrt.Client.init(api, create_options)); + pub fn init(api: *const Api, options: []const NamedValue) ClientInitError!*Client { + return @ptrCast(try pjrt.Client.init(api, options)); } pub fn deinit(self: *Client, api: *const Api) void { diff --git a/zml/platform.zig b/zml/platform.zig index c594ba6..ae515d6 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -29,9 +29,11 @@ pub const Platform = struct { compilation_options: CompilationOptions = .{}, pub const MAX_NUM_DEVICES: u8 = 32; + pub const CreateOptions = _CreateOptions; - pub fn init(target: Target, api: *const pjrt.Api) !Platform { - const pjrt_client = try pjrt.Client.init(api, &.{}); + pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform { + var named_values_buf: [16]pjrt.NamedValue = undefined; + const pjrt_client = try 
pjrt.Client.init(api, options.toNamedValues(target, &named_values_buf)); const true_num_devices = pjrt_client.getAddressableDevices(api).len; if (true_num_devices > MAX_NUM_DEVICES) { log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES }); @@ -84,3 +86,75 @@ pub const Platform = struct { return self.pjrt_client.getProfiler(self.pjrt_api, options); } }; + +const _CreateOptions = struct { + // XLA CPU client doesn't read options + // https://github.com/openxla/xla/blob/42496a28c374bd35f493cc5dbde74805407245dc/xla/pjrt/c/pjrt_c_api_cpu_internal.cc#L33-L46 + cpu: struct {} = .{}, + + // bump memory fraction from XLA defaults of 75% to 90%. + // Even on a 8GB GPU it should leave enough space for the Cuda driver + // https://github.com/openxla/xla/blob/3e87afa11a865cf91137522492918ad18bfe5b7c/xla/pjrt/plugin/xla_gpu/xla_gpu_allocator_config.h#L25-L60 + cuda: Cuda = .{ .allocator = .{ .bfc = .{ .preallocate = true, .memory_fraction = 0.90 } } }, + rocm: struct {} = .{}, + tpu: struct {} = .{}, + neuron: struct {} = .{}, + + pub const Cuda = struct { + allocator: Allocator = .{ .bfc = .{} }, + // TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86 + // visible_devices: []const i64 = &.{}, + // node_id + // num_nodes + // enable_mock_nccl + // mock_gpu_topology + + pub const Allocator = union(enum) { + /// "Best-Fit with Coalescing" algorithm + bfc: Options, + /// use cudaMallocAsync + @"async": Options, + /// use raw cuMalloc + platform, + + pub const Options = struct { + preallocate: bool = true, + memory_fraction: f32 = 0.90, + collective_memory_size_mb: u32 = 0, + }; + }; + + pub fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { + switch (self.allocator) { + .platform => { + values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", 
"platform")); + }, + .bfc, .@"async" => |opt| { + values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator)); + values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate)); + if (opt.memory_fraction > 0) { + values.appendAssumeCapacity(pjrt.NamedValue.from("memory_fraction", opt.memory_fraction)); + } + if (opt.collective_memory_size_mb > 0) { + const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024; + values.appendAssumeCapacity(pjrt.NamedValue.from("collective_memory_size", collective)); + } + }, + } + } + }; + + pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue { + var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out); + values.shrinkRetainingCapacity(0); + switch (target) { + .cuda => self.cuda.writeNamedValues(&values), + inline else => |t| { + stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)}); + const options = @field(self, @tagName(t)); + stdx.debug.assertComptime(@sizeOf(@TypeOf(options)) == 0, "zml.platform.CreateOptions.{s} is discarded", .{@tagName(t)}); + }, + } + return values.items; + } +}; diff --git a/zml/testing.zig b/zml/testing.zig index e5626d2..ae09026 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -8,11 +8,11 @@ const shapesOf = @import("tensor.zig").shapesOf; const log = std.log.scoped(.@"zml/testing"); -var _ctx: ?zml.Context = null; +var _platform: ?zml.Platform = null; pub fn env() zml.Platform { if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block"); - if (_ctx == null) { + if (_platform == null) { _test_compile_opts = if (initCacheDir()) .{ .cache_location = "/tmp/zml/tests/cache", @@ -22,10 +22,11 @@ pub fn env() zml.Platform { else .{}; - _ctx = zml.Context.init() catch unreachable; + var ctx = zml.Context.init() catch unreachable; + _platform = 
ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts); } - return _ctx.?.autoPlatform().withCompilationOptions(_test_compile_opts); + return _platform.?; } var _test_compile_opts: zml.CompilationOptions = .{};