//! Platform definitions for ZML: supported targets, compilation options,
//! and the `Platform` wrapper around a PJRT client.
const builtin = @import("builtin");
const std = @import("std");

const pjrt = @import("pjrt");
const asynk = @import("async");
const runtimes = @import("runtimes");

const meta = @import("meta.zig");
const module = @import("module.zig");

const log = std.log.scoped(.zml);
/// A hardware/runtime target ZML can run on (re-exported from `runtimes`).
pub const Target = runtimes.Platform;

/// Targets that may be available on the OS we are compiling for.
/// `builtin.os.tag` is comptime-known, so the switch resolves at comptime
/// and each branch may legally have a different array length.
pub const available_targets = switch (builtin.os.tag) {
    .macos => [_]Target{
        .cpu,
    },
    .linux => [_]Target{
        .cpu,
        .cuda,
        .rocm,
        .tpu,
    },
    // Other OSes: no supported targets.
    else => [_]Target{},
};
/// Options controlling how a module is compiled for a `Platform`.
pub const CompilationOptions = struct {
    // Directory for compiler debug dumps; null disables dumping.
    // NOTE(review): presumably forwarded to XLA's `--xla_dump_to` — confirm.
    xla_dump_to: ?[]const u8 = null,
    // Also emit fusion visualization artifacts when dumping.
    xla_dump_fusion_visualization: bool = false,
    // Optional location of a compilation cache; null disables caching.
    cache_location: ?[]const u8 = null,
    // When true, `Platform.sharding` partitions across all usable devices;
    // when false it reports a single partition.
    sharding_enabled: bool = false,
    // Names of the sharding axes, as null-terminated strings (at most 8).
    sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
};
|
|
|
|
/// A `Platform` bundles a target, the PJRT API vtable for that target,
/// a live PJRT client, and the compilation options to use with it.
/// Create with `init`, release with `deinit`.
pub const Platform = struct {
    target: Target,
    pjrt_api: *const pjrt.Api,
    pjrt_client: *pjrt.Client,
    compilation_options: CompilationOptions = .{},

    /// Maximum number of addressable devices ZML will use on one platform.
    /// Extra devices are ignored (see `init` / `getDevices`).
    pub const MAX_NUM_DEVICES: u8 = 8;

    /// Creates a PJRT client for `target` through `api`.
    /// Warns (but does not fail) when more devices are addressable than
    /// `MAX_NUM_DEVICES`; the surplus devices simply won't be used.
    pub fn init(target: Target, api: *const pjrt.Api) !Platform {
        const pjrt_client = try pjrt.Client.init(api, &.{});
        const true_num_devices = pjrt_client.getAddressableDevices(api).len;
        if (true_num_devices > MAX_NUM_DEVICES) {
            log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES });
        }
        return .{
            .target = target,
            .pjrt_api = api,
            .pjrt_client = pjrt_client,
            .compilation_options = .{},
        };
    }

    /// Returns the addressable devices of this platform,
    /// truncated to at most `MAX_NUM_DEVICES` entries.
    pub fn getDevices(self: Platform) []const *const pjrt.Device {
        const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api);
        if (all_devices.len > MAX_NUM_DEVICES) {
            return all_devices[0..MAX_NUM_DEVICES];
        }
        return all_devices;
    }

    /// Replica/partition counts used when compiling for this platform.
    pub const Sharding = struct { num_replicas: u8, num_partitions: u8 };

    /// Computes the sharding for this platform: one partition per usable
    /// device when sharding is enabled, otherwise a single partition.
    pub fn sharding(self: Platform) Sharding {
        // replicas run the same function but with different inputs,
        // while partitions contribute to one evaluation over a shared input.
        // Inside an inference process, we generally don't want replicas,
        // as it's best to fully isolate replicas on different processes.
        // For now we hardcode num_replicas = 1.
        // @intCast is safe: getDevices() caps the count at MAX_NUM_DEVICES (8).
        const num_devices: u8 = @intCast(self.getDevices().len);
        return if (self.compilation_options.sharding_enabled)
            .{ .num_replicas = 1, .num_partitions = num_devices }
        else
            .{ .num_replicas = 1, .num_partitions = 1 };
    }

    /// Returns a copy of this platform with `opts` as its compilation
    /// options. `self` is left untouched.
    pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform {
        var res = self;
        res.compilation_options = opts;
        return res;
    }

    /// Releases the underlying PJRT client. The Platform must not be used
    /// afterwards.
    pub fn deinit(self: *Platform) void {
        self.pjrt_client.deinit(self.pjrt_api);
    }

    /// Returns the Profiler for this API.
    /// Not all platform have a profiling api, for those the profiler object will do nothing.
    /// Platforms with known profiler extensions: cuda, xpu
    pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler {
        return self.pjrt_client.getProfiler(self.pjrt_api, options);
    }

    /// Suspend the current co-routine while awaiting for a pjrt event to be over.
    /// The event is consumed: it is deinitialized on every exit path.
    /// Returns the event's error (converted to an API error) if it failed.
    pub fn awaitEvent(self: Platform, event: *pjrt.Event) !void {
        defer event.deinit(self.pjrt_api);
        // If we aren't in a coroutine just use the normal blocking api.
        if (!asynk.inCoro()) {
            return try event.await_(self.pjrt_api);
        }

        // State shared between this coroutine and the PJRT completion
        // callback; lives on this frame, which stays alive across the wait.
        var ctx = struct {
            err: ?*pjrt.Error = null,
            notif: asynk.Notification,
            // NOTE(review): `ready` is stored by the callback but never read
            // here — the wakeup relies solely on `notif`. Confirm whether
            // this flag is still needed.
            ready: bool = false,
        }{
            .notif = try asynk.Notification.init(),
        };
        defer ctx.notif.deinit();

        // Register a C-ABI callback that records the error (if any) and
        // wakes this coroutine through the notification.
        try event.onReady(self.pjrt_api, &(struct {
            fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
                const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
                ctx_.err = err;
                @atomicStore(bool, &ctx_.ready, true, .seq_cst);
                ctx_.notif.notify() catch @panic("Unable to notify");
            }
        }.call), &ctx);
        // Suspend
        try ctx.notif.wait();
        // Surface the callback's error, releasing it after conversion.
        if (ctx.err) |e| {
            defer e.deinit(self.pjrt_api);
            return e.getCode(self.pjrt_api).toApiError();
        }
    }
};