Add struct-based client creation flags to the Zig PJRT API and update context.autoPlatform to accept a flag struct.

This commit is contained in:
Tarry Singh 2023-11-13 12:45:17 +00:00
parent cb6fcbbb1a
commit 57bf667c90
5 changed files with 146 additions and 41 deletions

View File

@ -823,7 +823,7 @@ pub const NamedValue = extern struct {
[]i64, []const i64 => fromInt64List(name_, value),
f32 => fromFloat(name_, value),
bool => fromBool(name_, value),
else => unreachable,
else => fromString(name_, @tagName(value)),
};
}

View File

@ -81,35 +81,16 @@ pub const Context = struct {
Context.apis_once.call();
Context.mlir_once.call();
var platforms = PlatformsMap.initFill(null);
var num_platforms: u8 = 0;
var it = Context.apis.iterator();
while (it.next()) |entry| {
if (entry.value.*) |api| {
const target = entry.key;
const p = Platform.init(target, api) catch |err| {
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
continue;
};
if (p.getDevices().len == 0) {
log.err("No device found for platform {} !", .{target});
continue;
}
if (target == .cuda) {
try cuda.registerZmlCustomCalls(p);
}
platforms.set(target, p);
num_platforms += 1;
}
for (Context.apis.values) |api| {
if (api != null) num_platforms += 1;
}
if (num_platforms == 0) {
log.err("No platform available", .{});
return error.NoPlatformAvailable;
}
return .{
.platforms = platforms,
};
return .{ .platforms = PlatformsMap.initFill(null) };
}
fn platformToLibrary(comptime target: Target) []const u8 {
@ -137,20 +118,69 @@ pub const Context = struct {
self.* = undefined;
}
const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
/// Automatically selects the best Platform loaded in the current Context.
///
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
pub fn autoPlatform(self: *Context) Platform {
// the last platform is the one that with the high enum number, so considered
// to be the "best" one
var platform_: ?Platform = null;
var iterator = self.platforms.iterator();
while (iterator.next()) |entry| {
if (entry.value.*) |p| {
platform_ = p;
}
/// Automatically selects the best Platform loaded in the current Context.
///
/// Targets are tried in the priority order of `prefered_targets` (accelerators
/// first, CPU last). `opts` is forwarded to the PJRT client creation of the
/// selected platform (see `Platform.CreateOptions`).
pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform {
    // Guard: every Target enum member must appear in `prefered_targets`,
    // otherwise a newly added target would silently never be auto-selected.
    stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{});
    return self.platformByPreferences(opts, &prefered_targets);
}
/// Selects the best available Platform, trying targets in the order given by `prefered`.
///
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
/// Targets not listed in `prefered` (except CPU) are tried next, and CPU is used
/// as the last-resort fallback. Panics if even the CPU platform cannot be loaded.
pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform {
    // Try prefered targets first, in caller-specified priority order.
    for (prefered) |target| {
        // Skip targets whose PJRT plugin wasn't compiled in / loaded.
        if (apis.get(target) == null) continue;
        return self.platform(target, opts) catch |err| {
            log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
            continue;
        };
    }
    // NOTE(review): the next line appears to be diff residue from the removed old
    // implementation (`platform_` is not defined in this function, and it makes the
    // loops below unreachable) — confirm it was deleted in the actual commit.
    return platform_ orelse @panic("No platform found !");
    // Try unlisted targets: anything compiled in but not mentioned in `prefered`.
    var it = Context.apis.iterator();
    while (it.next()) |entry| {
        const target = entry.key;
        // CPU should only be used as fallback.
        if (target == .cpu) continue;
        // No plugin loaded for this target.
        if (entry.value.* == null) continue;
        // Already tried in the preferred loop above.
        if (std.mem.indexOfScalar(Target, prefered, target) != null) continue;
        return self.platform(target, opts) catch |err| {
            log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
            continue;
        };
    }
    // Finally fallback to cpu.
    return self.platform(.cpu, opts) catch {
        log.err("No platform available", .{});
        @panic("No platform available !");
    };
}
/// Returns the Platform for `target`, initializing it on first use.
///
/// Subsequent calls for the same target return the cached Platform; note that
/// `opts` is only consulted on the first successful call for a given target.
/// Errors: `PlatformNotCompiled` when no PJRT plugin was loaded for `target`,
/// `NoDevicesFound` when the plugin reports zero devices.
pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform {
    // Fast path: this target was already initialized.
    if (self.platforms.get(target)) |cached| return cached;

    const api = Context.apis.get(target) orelse return error.PlatformNotCompiled;
    const new_platform = try Platform.init(target, api, opts);
    if (new_platform.getDevices().len == 0) {
        log.err("No device found for platform {} !", .{target});
        return error.NoDevicesFound;
    }
    // TODO: should this be moved to platform.zig ?
    if (target == .cuda) try cuda.registerZmlCustomCalls(new_platform);
    self.platforms.set(target, new_platform);
    return new_platform;
}
pub fn printAvailablePlatforms(self: Context, selected: Platform) void {

View File

@ -42,8 +42,8 @@ fn InnerMixin(comptime innerT: type) type {
pub const Client = opaque {
const inner = InnerMixin(pjrt.Client).inner;
pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client {
return @ptrCast(try pjrt.Client.init(api, create_options));
pub fn init(api: *const Api, options: []const NamedValue) ClientInitError!*Client {
return @ptrCast(try pjrt.Client.init(api, options));
}
pub fn deinit(self: *Client, api: *const Api) void {

View File

@ -29,9 +29,11 @@ pub const Platform = struct {
compilation_options: CompilationOptions = .{},
pub const MAX_NUM_DEVICES: u8 = 32;
pub const CreateOptions = _CreateOptions;
pub fn init(target: Target, api: *const pjrt.Api) !Platform {
const pjrt_client = try pjrt.Client.init(api, &.{});
pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform {
var named_values_buf: [16]pjrt.NamedValue = undefined;
const pjrt_client = try pjrt.Client.init(api, options.toNamedValues(target, &named_values_buf));
const true_num_devices = pjrt_client.getAddressableDevices(api).len;
if (true_num_devices > MAX_NUM_DEVICES) {
log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES });
@ -84,3 +86,75 @@ pub const Platform = struct {
return self.pjrt_client.getProfiler(self.pjrt_api, options);
}
};
/// Per-target options used when creating the underlying PJRT client.
/// Only the field matching the selected target is consulted (see `toNamedValues`).
const _CreateOptions = struct {
    // XLA CPU client doesn't read options
    // https://github.com/openxla/xla/blob/42496a28c374bd35f493cc5dbde74805407245dc/xla/pjrt/c/pjrt_c_api_cpu_internal.cc#L33-L46
    cpu: struct {} = .{},

    // bump memory fraction from XLA defaults of 75% to 90%.
    // Even on a 8GB GPU it should leave enough space for the Cuda driver
    // https://github.com/openxla/xla/blob/3e87afa11a865cf91137522492918ad18bfe5b7c/xla/pjrt/plugin/xla_gpu/xla_gpu_allocator_config.h#L25-L60
    cuda: Cuda = .{ .allocator = .{ .bfc = .{ .preallocate = true, .memory_fraction = 0.90 } } },

    // These targets currently take no options; the empty structs keep the field
    // list in sync with the Target enum (enforced at comptime in `toNamedValues`).
    rocm: struct {} = .{},
    tpu: struct {} = .{},
    neuron: struct {} = .{},

    /// CUDA-specific client creation options.
    pub const Cuda = struct {
        allocator: Allocator = .{ .bfc = .{} },
        // TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86
        // visible_devices: []const i64 = &.{},
        // node_id
        // num_nodes
        // enable_mock_nccl
        // mock_gpu_topology

        /// GPU memory allocator strategy, serialized as the "allocator" NamedValue.
        pub const Allocator = union(enum) {
            /// "Best-Fit with Coalescing" algorithm
            bfc: Options,
            /// use cudaMallocAsync
            @"async": Options,
            /// use raw cuMalloc
            platform,

            pub const Options = struct {
                // Reserve `memory_fraction` up front instead of growing on demand.
                preallocate: bool = true,
                // Fraction of total GPU memory the allocator may use.
                memory_fraction: f32 = 0.90,
                // Extra pool for collectives, in MiB; 0 means "not set, let XLA decide".
                collective_memory_size_mb: u32 = 0,
            };
        };

        /// Appends this config to `values` as PJRT NamedValues.
        /// `values` must already have capacity for the (at most 4) entries appended
        /// here — appendAssumeCapacity never grows the list.
        pub fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
            switch (self.allocator) {
                .platform => {
                    values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
                },
                .bfc, .@"async" => |opt| {
                    // NamedValue.from on the union presumably serializes the active
                    // tag name ("bfc" / "async") — TODO confirm against NamedValue.from.
                    values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
                    values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
                    if (opt.memory_fraction > 0) {
                        values.appendAssumeCapacity(pjrt.NamedValue.from("memory_fraction", opt.memory_fraction));
                    }
                    if (opt.collective_memory_size_mb > 0) {
                        // Convert MiB -> bytes; widen to i64 first to avoid u32 overflow.
                        const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024;
                        values.appendAssumeCapacity(pjrt.NamedValue.from("collective_memory_size", collective));
                    }
                },
            }
        }
    };

    /// Serializes the options for `target` into `out` and returns the filled slice.
    /// `out` is caller-owned scratch (a 16-entry buffer in `Platform.init`); no
    /// allocation happens here.
    pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue {
        // Borrow `out` as a fixed-capacity list: fromOwnedSlice sets len to the
        // full slice length, so reset the length to 0 before appending.
        var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out);
        values.shrinkRetainingCapacity(0);
        switch (target) {
            .cuda => self.cuda.writeNamedValues(&values),
            inline else => |t| {
                // Comptime checks: every Target must have a matching field here, and
                // non-cuda option structs must be empty (a non-empty one would be
                // silently discarded since only cuda is serialized above).
                stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)});
                const options = @field(self, @tagName(t));
                stdx.debug.assertComptime(@sizeOf(@TypeOf(options)) == 0, "zml.platform.CreateOptions.{s} is discarded", .{@tagName(t)});
            },
        }
        return values.items;
    }
};

View File

@ -8,11 +8,11 @@ const shapesOf = @import("tensor.zig").shapesOf;
const log = std.log.scoped(.@"zml/testing");
var _ctx: ?zml.Context = null;
var _platform: ?zml.Platform = null;
pub fn env() zml.Platform {
if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
if (_ctx == null) {
if (_platform == null) {
_test_compile_opts = if (initCacheDir())
.{
.cache_location = "/tmp/zml/tests/cache",
@ -22,10 +22,11 @@ pub fn env() zml.Platform {
else
.{};
_ctx = zml.Context.init() catch unreachable;
var ctx = zml.Context.init() catch unreachable;
_platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts);
}
return _ctx.?.autoPlatform().withCompilationOptions(_test_compile_opts);
return _platform.?;
}
var _test_compile_opts: zml.CompilationOptions = .{};