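//! Radix/zml/platform.zig (Zig, 73 lines, 1.9 KiB)
//!
//! Compilation targets, compilation options, and the `Platform` wrapper
//! around a PJRT API / client pair.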

const builtin = @import("builtin");
const std = @import("std");
const aio = @import("aio.zig");
const meta = @import("meta.zig");
const module = @import("module.zig");
const pjrt = @import("pjrtx.zig");
const pjrt_core = @import("pjrt");
/// Backend family a `Platform` is initialized for.
/// Declaration order fixes the integer tag values (cpu = 0 ... tpu = 3).
pub const Target = enum { cpu, cuda, rocm, tpu };
/// Targets offered for the OS this binary is compiled for, chosen at comptime
/// from `builtin.os.tag`: all four on Linux, CPU only on macOS, none elsewhere.
pub const available_targets = switch (builtin.os.tag) {
    .linux => [_]Target{ .cpu, .cuda, .rocm, .tpu },
    .macos => [_]Target{.cpu},
    else => [_]Target{},
};
/// Options a `Platform` carries for compilation; attach them with
/// `Platform.withCompilationOptions`. Every field defaults to null/false.
pub const CompilationOptions = struct {
    /// Destination for XLA dumps, or `null` to leave it unset.
    /// NOTE(review): presumably a directory path — confirm with the consumer.
    xla_dump_to: ?[]const u8 = null,
    /// Whether XLA fusion visualization dumps are requested; off by default.
    xla_dump_fusion_visualization: bool = false,
    /// Cache location, or `null` to leave it unset.
    /// NOTE(review): presumably a compilation-cache path — confirm.
    cache_location: ?[]const u8 = null,
};
/// A PJRT-backed execution platform: the selected `Target`, the PJRT API it
/// was created from, and the client opened through that API.
pub const Platform = struct {
    /// Backend this platform was initialized for.
    target: Target,
    /// PJRT API passed to every client call made by this platform.
    pjrt_api: *const pjrt.Api,
    /// Client opened in `init` and released in `deinit`.
    pjrt_client: *pjrt.Client,
    /// Options attached via `withCompilationOptions`; empty (`.{}`) by default.
    compilation_options: CompilationOptions = .{},

    /// Opens a PJRT client through `api` (passing an empty list as the second
    /// argument) and wraps it for `target`.
    /// Propagates any error returned by `pjrt.Client.init`.
    /// Call `deinit` to release the client.
    pub fn init(target: Target, api: *const pjrt.Api) !Platform {
        return Platform{
            .target = target,
            .pjrt_api = api,
            .pjrt_client = try pjrt.Client.init(api, &.{}),
            .compilation_options = .{},
        };
    }

    /// Returns the devices that the client reports as addressable.
    /// NOTE(review): the slice is presumably owned by the client and should
    /// not be used after `deinit` — confirm against pjrt.
    pub fn getDevices(self: Platform) []const *const pjrt_core.Device {
        const client = self.pjrt_client;
        return client.getAddressableDevices(self.pjrt_api);
    }

    /// Returns a copy of `self` with `compilation_options` replaced by `opts`.
    /// The copy shares `pjrt_client` with `self`, so calling `deinit` on both
    /// would deinit the same client twice.
    pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform {
        var updated = self;
        updated.compilation_options = opts;
        return updated;
    }

    /// Releases the PJRT client opened by `init`.
    pub fn deinit(self: *Platform) void {
        const api = self.pjrt_api;
        self.pjrt_client.deinit(api);
    }

    /// Returns the profiler for this platform, created with `options`.
    /// Not all platforms have a profiling API; for those, the returned
    /// profiler does nothing.
    /// Platforms with known profiler extensions: cuda, xpu.
    /// NOTE(review): `xpu` is not a `Target` variant in this file.
    pub fn getProfiler(self: Platform, options: pjrt_core.Profiler.Options) pjrt_core.Profiler {
        const client = self.pjrt_client;
        return client.getProfiler(self.pjrt_api, options);
    }
};