209 lines
8.2 KiB
Zig
209 lines
8.2 KiB
Zig
const std = @import("std");
|
|
|
|
const runtimes = @import("runtimes");
|
|
pub const Target = runtimes.Platform;
|
|
const stdx = @import("stdx");
|
|
|
|
const pjrt = @import("pjrtx.zig");
|
|
|
|
const log = std.log.scoped(.zml);
|
|
|
|
pub const available_targets = std.enums.values(Target);
|
|
|
|
test {
|
|
std.testing.refAllDecls(@This());
|
|
}
|
|
|
|
pub const CompilationOptions = struct {
|
|
xla_dump_to: ?[]const u8 = null,
|
|
xla_dump_fusion_visualization: bool = false,
|
|
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
|
sharding_enabled: bool = false,
|
|
sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{},
|
|
device_memory_size: u64 = 0,
|
|
};
|
|
|
|
pub const Platform = struct {
|
|
target: Target,
|
|
pjrt_api: *const pjrt.Api,
|
|
pjrt_client: *pjrt.Client,
|
|
|
|
// This make the pjrt struct quite fat, but is only used during compilation.
|
|
// TODO: Reconsider having it here, and maybe pass explicitly to compile,
|
|
// or create an intermediary struct:
|
|
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
|
|
compilation_options: CompilationOptions = .{},
|
|
|
|
pub const MAX_NUM_DEVICES: u8 = if (runtimes.isEnabled(.tpu)) 32 else 8;
|
|
pub const CreateOptions = _CreateOptions;
|
|
|
|
pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform {
|
|
var named_values_buf: [16]pjrt.NamedValue = undefined;
|
|
const pjrt_client = try pjrt.Client.init(api, options.toNamedValues(target, &named_values_buf));
|
|
const true_num_devices = pjrt_client.getAddressableDevices(api).len;
|
|
if (true_num_devices > MAX_NUM_DEVICES) {
|
|
log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES });
|
|
}
|
|
return .{
|
|
.target = target,
|
|
.pjrt_api = api,
|
|
.pjrt_client = pjrt_client,
|
|
.compilation_options = .{},
|
|
};
|
|
}
|
|
|
|
pub fn getDevices(self: Platform) []const *const pjrt.Device {
|
|
const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api);
|
|
if (all_devices.len > MAX_NUM_DEVICES) {
|
|
return all_devices[0..MAX_NUM_DEVICES];
|
|
}
|
|
return all_devices;
|
|
}
|
|
|
|
pub const Sharding = struct { num_replicas: u8, num_partitions: u8 };
|
|
|
|
pub fn sharding(self: Platform) Sharding {
|
|
// replicas run the same function but with different inputs,
|
|
// while partitions contribute to one evaluation over a shared input.
|
|
// Inside an inference process, we generally don't want replicas,
|
|
// as it's best to fully isolate replicas on different processes.
|
|
// For now we hardcode num_replicas = 1.
|
|
const num_devices: u8 = @intCast(self.getDevices().len);
|
|
return if (self.compilation_options.sharding_enabled)
|
|
.{ .num_replicas = 1, .num_partitions = num_devices }
|
|
else
|
|
.{ .num_replicas = 1, .num_partitions = 1 };
|
|
}
|
|
|
|
pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform {
|
|
var res = self;
|
|
res.compilation_options = opts;
|
|
return res;
|
|
}
|
|
|
|
pub fn deinit(self: *Platform) void {
|
|
self.pjrt_client.deinit(self.pjrt_api);
|
|
}
|
|
|
|
pub fn memoryForDevice(platform: Platform, memory: pjrt.Memory.Kind, device: *const pjrt.Device) *const pjrt.Memory {
|
|
const memory_target: pjrt.Memory.Kind = switch (memory) {
|
|
.host_unpinned => switch (platform.target) {
|
|
// Cuda doesn't have host_unpinned.
|
|
.cuda => .host_pinned,
|
|
else => .host_unpinned,
|
|
},
|
|
inline else => |t| t,
|
|
};
|
|
// TODO measure the cost of this and consider caching.
|
|
const device_memories = device.addressableMemories(platform.pjrt_api);
|
|
for (device_memories) |m| {
|
|
if (memory_target == m.kind(platform.pjrt_api)) {
|
|
return m;
|
|
}
|
|
}
|
|
log.err("Platform {t} doesn't have memory {t}", .{ platform.target, memory });
|
|
@panic("Memory kind not found");
|
|
}
|
|
|
|
test memoryForDevice {
|
|
const zml = @import("zml.zig");
|
|
const platform = zml.testing.env();
|
|
const memory_fields = @typeInfo(pjrt.Memory.Kind).@"enum".fields;
|
|
inline for (memory_fields) |field| {
|
|
for (platform.getDevices()) |dev| {
|
|
_ = platform.memoryForDevice(@field(pjrt.Memory.Kind, field.name), dev);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn memoryStats(platform: Platform, device_id: usize) pjrt.MemoryStats {
|
|
if (platform.target == .cpu) return .zeroes;
|
|
|
|
const device = platform.getDevices()[device_id];
|
|
return device.memoryStats(platform.pjrt_api) catch .zeroes;
|
|
}
|
|
};
|
|
|
|
const _CreateOptions = struct {
|
|
cpu: Cpu = .{ .device_count = 4 },
|
|
|
|
// bump memory fraction from XLA defaults of 75% to 90%.
|
|
// Even on a 8GB GPU it should leave enough space for the Cuda driver
|
|
// https://github.com/openxla/xla/blob/3e87afa11a865cf91137522492918ad18bfe5b7c/xla/pjrt/plugin/xla_gpu/xla_gpu_allocator_config.h#L25-L60
|
|
cuda: Cuda = .{ .allocator = .{ .bfc = .{ .preallocate = true, .memory_fraction = 0.90 } } },
|
|
rocm: struct {} = .{},
|
|
tpu: struct {} = .{},
|
|
neuron: struct {} = .{},
|
|
|
|
pub const Cpu = struct {
|
|
device_count: u32,
|
|
|
|
fn writeNamedValues(self: Cpu, values: *std.ArrayList(pjrt.NamedValue)) void {
|
|
values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count)));
|
|
}
|
|
};
|
|
|
|
pub const Cuda = struct {
|
|
allocator: Allocator = .{ .bfc = .{} },
|
|
// TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86
|
|
// visible_devices: []const i64 = &.{},
|
|
// node_id
|
|
// num_nodes
|
|
// enable_mock_nccl
|
|
// mock_gpu_topology
|
|
|
|
pub const Allocator = union(enum) {
|
|
/// "Best-Fit with Coalescing" algorithm
|
|
bfc: Options,
|
|
/// use cudaMallocAsync
|
|
async: Options,
|
|
/// use raw cuMalloc
|
|
platform,
|
|
|
|
pub const Options = struct {
|
|
preallocate: bool = true,
|
|
memory_fraction: f32 = 0.90,
|
|
collective_memory_size_mb: u32 = 0,
|
|
};
|
|
};
|
|
|
|
fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void {
|
|
switch (self.allocator) {
|
|
.platform => {
|
|
values.appendAssumeCapacity(.fromString("allocator", "platform"));
|
|
},
|
|
.bfc, .async => |opt| {
|
|
values.appendAssumeCapacity(.fromString("allocator", switch (self.allocator) {
|
|
.bfc => "bfc",
|
|
.async => "cuda_async",
|
|
.platform => unreachable,
|
|
}));
|
|
values.appendAssumeCapacity(.from("preallocate", opt.preallocate));
|
|
if (opt.memory_fraction > 0) {
|
|
values.appendAssumeCapacity(.from("memory_fraction", opt.memory_fraction));
|
|
}
|
|
if (opt.collective_memory_size_mb > 0) {
|
|
const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024;
|
|
values.appendAssumeCapacity(.from("collective_memory_size", collective));
|
|
}
|
|
},
|
|
}
|
|
}
|
|
};
|
|
|
|
pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue {
|
|
var values = std.ArrayList(pjrt.NamedValue).fromOwnedSlice(out);
|
|
values.shrinkRetainingCapacity(0);
|
|
switch (target) {
|
|
.cpu => self.cpu.writeNamedValues(&values),
|
|
.cuda => self.cuda.writeNamedValues(&values),
|
|
inline else => |t| {
|
|
stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)});
|
|
const options = @field(self, @tagName(t));
|
|
stdx.debug.assertComptime(@sizeOf(@TypeOf(options)) == 0, "zml.platform.CreateOptions.{s} is discarded", .{@tagName(t)});
|
|
},
|
|
}
|
|
return values.items;
|
|
}
|
|
};
|