Add struct‑based client creation flags to the Zig PJRT API and update context.autoPlatform to accept a flag struct.
This commit is contained in:
parent
cb6fcbbb1a
commit
57bf667c90
@ -823,7 +823,7 @@ pub const NamedValue = extern struct {
|
||||
[]i64, []const i64 => fromInt64List(name_, value),
|
||||
f32 => fromFloat(name_, value),
|
||||
bool => fromBool(name_, value),
|
||||
else => unreachable,
|
||||
else => fromString(name_, @tagName(value)),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -81,35 +81,16 @@ pub const Context = struct {
|
||||
Context.apis_once.call();
|
||||
Context.mlir_once.call();
|
||||
|
||||
var platforms = PlatformsMap.initFill(null);
|
||||
var num_platforms: u8 = 0;
|
||||
var it = Context.apis.iterator();
|
||||
while (it.next()) |entry| {
|
||||
if (entry.value.*) |api| {
|
||||
const target = entry.key;
|
||||
const p = Platform.init(target, api) catch |err| {
|
||||
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
||||
continue;
|
||||
};
|
||||
if (p.getDevices().len == 0) {
|
||||
log.err("No device found for platform {} !", .{target});
|
||||
continue;
|
||||
}
|
||||
if (target == .cuda) {
|
||||
try cuda.registerZmlCustomCalls(p);
|
||||
}
|
||||
platforms.set(target, p);
|
||||
num_platforms += 1;
|
||||
}
|
||||
for (Context.apis.values) |api| {
|
||||
if (api != null) num_platforms += 1;
|
||||
}
|
||||
if (num_platforms == 0) {
|
||||
log.err("No platform available", .{});
|
||||
return error.NoPlatformAvailable;
|
||||
}
|
||||
|
||||
return .{
|
||||
.platforms = platforms,
|
||||
};
|
||||
return .{ .platforms = PlatformsMap.initFill(null) };
|
||||
}
|
||||
|
||||
fn platformToLibrary(comptime target: Target) []const u8 {
|
||||
@ -137,20 +118,69 @@ pub const Context = struct {
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
|
||||
|
||||
/// Automatically selects the best Platform loaded in the current Context.
|
||||
///
|
||||
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
||||
pub fn autoPlatform(self: *Context) Platform {
|
||||
// the last platform is the one that with the high enum number, so considered
|
||||
// to be the "best" one
|
||||
var platform_: ?Platform = null;
|
||||
var iterator = self.platforms.iterator();
|
||||
while (iterator.next()) |entry| {
|
||||
if (entry.value.*) |p| {
|
||||
platform_ = p;
|
||||
pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform {
|
||||
stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{});
|
||||
|
||||
return self.platformByPreferences(opts, &prefered_targets);
|
||||
}
|
||||
|
||||
/// Given a list of preferred targets to select the best Platform
|
||||
///
|
||||
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
||||
pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform {
|
||||
// Try prefered targets.
|
||||
for (prefered) |target| {
|
||||
if (apis.get(target) == null) continue;
|
||||
return self.platform(target, opts) catch |err| {
|
||||
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
||||
continue;
|
||||
};
|
||||
}
|
||||
return platform_ orelse @panic("No platform found !");
|
||||
|
||||
// Try unlisted targets
|
||||
var it = Context.apis.iterator();
|
||||
while (it.next()) |entry| {
|
||||
const target = entry.key;
|
||||
// CPU should only be use as fallback.
|
||||
if (target == .cpu) continue;
|
||||
if (entry.value.* == null) continue;
|
||||
if (std.mem.indexOfScalar(Target, prefered, target) != null) continue;
|
||||
return self.platform(target, opts) catch |err| {
|
||||
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
||||
continue;
|
||||
};
|
||||
}
|
||||
|
||||
// Finally fallback to cpu.
|
||||
return self.platform(.cpu, opts) catch {
|
||||
log.err("No platform available", .{});
|
||||
@panic("No platform available !");
|
||||
};
|
||||
}
|
||||
|
||||
pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform {
|
||||
if (self.platforms.get(target)) |p| {
|
||||
return p;
|
||||
}
|
||||
const api = Context.apis.get(target);
|
||||
if (api == null) return error.PlatformNotCompiled;
|
||||
const p = try Platform.init(target, api.?, opts);
|
||||
if (p.getDevices().len == 0) {
|
||||
log.err("No device found for platform {} !", .{target});
|
||||
return error.NoDevicesFound;
|
||||
}
|
||||
// TODO: should this be moved to platform.zig ?
|
||||
if (target == .cuda) {
|
||||
try cuda.registerZmlCustomCalls(p);
|
||||
}
|
||||
|
||||
self.platforms.set(target, p);
|
||||
return p;
|
||||
}
|
||||
|
||||
pub fn printAvailablePlatforms(self: Context, selected: Platform) void {
|
||||
|
||||
@ -42,8 +42,8 @@ fn InnerMixin(comptime innerT: type) type {
|
||||
pub const Client = opaque {
|
||||
const inner = InnerMixin(pjrt.Client).inner;
|
||||
|
||||
pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client {
|
||||
return @ptrCast(try pjrt.Client.init(api, create_options));
|
||||
pub fn init(api: *const Api, options: []const NamedValue) ClientInitError!*Client {
|
||||
return @ptrCast(try pjrt.Client.init(api, options));
|
||||
}
|
||||
|
||||
pub fn deinit(self: *Client, api: *const Api) void {
|
||||
|
||||
@ -29,9 +29,11 @@ pub const Platform = struct {
|
||||
compilation_options: CompilationOptions = .{},
|
||||
|
||||
pub const MAX_NUM_DEVICES: u8 = 32;
|
||||
pub const CreateOptions = _CreateOptions;
|
||||
|
||||
pub fn init(target: Target, api: *const pjrt.Api) !Platform {
|
||||
const pjrt_client = try pjrt.Client.init(api, &.{});
|
||||
pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform {
|
||||
var named_values_buf: [16]pjrt.NamedValue = undefined;
|
||||
const pjrt_client = try pjrt.Client.init(api, options.toNamedValues(target, &named_values_buf));
|
||||
const true_num_devices = pjrt_client.getAddressableDevices(api).len;
|
||||
if (true_num_devices > MAX_NUM_DEVICES) {
|
||||
log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES });
|
||||
@ -84,3 +86,75 @@ pub const Platform = struct {
|
||||
return self.pjrt_client.getProfiler(self.pjrt_api, options);
|
||||
}
|
||||
};
|
||||
|
||||
const _CreateOptions = struct {
|
||||
// XLA CPU client doesn't read options
|
||||
// https://github.com/openxla/xla/blob/42496a28c374bd35f493cc5dbde74805407245dc/xla/pjrt/c/pjrt_c_api_cpu_internal.cc#L33-L46
|
||||
cpu: struct {} = .{},
|
||||
|
||||
// bump memory fraction from XLA defaults of 75% to 90%.
|
||||
// Even on a 8GB GPU it should leave enough space for the Cuda driver
|
||||
// https://github.com/openxla/xla/blob/3e87afa11a865cf91137522492918ad18bfe5b7c/xla/pjrt/plugin/xla_gpu/xla_gpu_allocator_config.h#L25-L60
|
||||
cuda: Cuda = .{ .allocator = .{ .bfc = .{ .preallocate = true, .memory_fraction = 0.90 } } },
|
||||
rocm: struct {} = .{},
|
||||
tpu: struct {} = .{},
|
||||
neuron: struct {} = .{},
|
||||
|
||||
pub const Cuda = struct {
|
||||
allocator: Allocator = .{ .bfc = .{} },
|
||||
// TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86
|
||||
// visible_devices: []const i64 = &.{},
|
||||
// node_id
|
||||
// num_nodes
|
||||
// enable_mock_nccl
|
||||
// mock_gpu_topology
|
||||
|
||||
pub const Allocator = union(enum) {
|
||||
/// "Best-Fit with Coalescing" algorithm
|
||||
bfc: Options,
|
||||
/// use cudaMallocAsync
|
||||
@"async": Options,
|
||||
/// use raw cuMalloc
|
||||
platform,
|
||||
|
||||
pub const Options = struct {
|
||||
preallocate: bool = true,
|
||||
memory_fraction: f32 = 0.90,
|
||||
collective_memory_size_mb: u32 = 0,
|
||||
};
|
||||
};
|
||||
|
||||
pub fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
|
||||
switch (self.allocator) {
|
||||
.platform => {
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
||||
},
|
||||
.bfc, .@"async" => |opt| {
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
|
||||
if (opt.memory_fraction > 0) {
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("memory_fraction", opt.memory_fraction));
|
||||
}
|
||||
if (opt.collective_memory_size_mb > 0) {
|
||||
const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024;
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("collective_memory_size", collective));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue {
|
||||
var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out);
|
||||
values.shrinkRetainingCapacity(0);
|
||||
switch (target) {
|
||||
.cuda => self.cuda.writeNamedValues(&values),
|
||||
inline else => |t| {
|
||||
stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)});
|
||||
const options = @field(self, @tagName(t));
|
||||
stdx.debug.assertComptime(@sizeOf(@TypeOf(options)) == 0, "zml.platform.CreateOptions.{s} is discarded", .{@tagName(t)});
|
||||
},
|
||||
}
|
||||
return values.items;
|
||||
}
|
||||
};
|
||||
|
||||
@ -8,11 +8,11 @@ const shapesOf = @import("tensor.zig").shapesOf;
|
||||
|
||||
const log = std.log.scoped(.@"zml/testing");
|
||||
|
||||
var _ctx: ?zml.Context = null;
|
||||
var _platform: ?zml.Platform = null;
|
||||
|
||||
pub fn env() zml.Platform {
|
||||
if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
|
||||
if (_ctx == null) {
|
||||
if (_platform == null) {
|
||||
_test_compile_opts = if (initCacheDir())
|
||||
.{
|
||||
.cache_location = "/tmp/zml/tests/cache",
|
||||
@ -22,10 +22,11 @@ pub fn env() zml.Platform {
|
||||
else
|
||||
.{};
|
||||
|
||||
_ctx = zml.Context.init() catch unreachable;
|
||||
var ctx = zml.Context.init() catch unreachable;
|
||||
_platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts);
|
||||
}
|
||||
|
||||
return _ctx.?.autoPlatform().withCompilationOptions(_test_compile_opts);
|
||||
return _platform.?;
|
||||
}
|
||||
|
||||
var _test_compile_opts: zml.CompilationOptions = .{};
|
||||
|
||||
Loading…
Reference in New Issue
Block a user