Add in-process sharding support across core ZML components (platform, shape, tensor, MLIR generation, buffers, and PJRT integration)

This commit is contained in:
Tarry Singh 2023-02-24 17:33:14 +00:00
parent cad1a688da
commit 2f129f76c9
11 changed files with 567 additions and 280 deletions

View File

@ -47,7 +47,7 @@ pub fn MlirTypeMethods(comptime InnerT: type) type {
/// Alternative to MlirWrapperType
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
fn MlirHelpersMethods(comptime OuterT: type) type {
fn MlirHelpersMethods(OuterT: type) type {
switch (@typeInfo(OuterT)) {
.Struct => |info| {
if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT));
@ -382,6 +382,10 @@ pub const StringAttribute = struct {
pub fn value(self: Self) []const u8 {
return fromStringRef(c.mlirStringAttrGetValue(self.inner()));
}
/// Reinterprets this StringAttribute as a generic `Attribute`.
/// Both wrappers hold the same underlying MLIR handle, so no copy is made.
pub fn asAttr(self: StringAttribute) Attribute {
return .{ ._inner = self._inner };
}
};
pub const UnitAttribute = struct {
@ -493,6 +497,10 @@ pub fn IntegerAttribute(comptime it: IntegerTypes) type {
pub fn get(value: IntAttr) ZigType {
return @intCast(getter(value.inner()));
}
/// Reinterprets this integer attribute as a generic `Attribute`.
/// Both wrappers hold the same underlying MLIR handle, so no copy is made.
pub fn asAttr(self: IntAttr) Attribute {
return .{ ._inner = self._inner };
}
};
}
@ -731,23 +739,29 @@ pub const DictionaryAttribute = struct {
.equal_fn = c.mlirAttributeEqual,
});
const Self = DictionaryAttribute;
pub fn init(ctx: Context, attributes: []const NamedAttribute) Self {
return Self.wrap(c.mlirDictionaryAttrGet(ctx.inner(), @intCast(attributes.len), @ptrCast(attributes.ptr)));
pub fn init(ctx: Context, attributes: []const NamedAttribute) DictionaryAttribute {
return DictionaryAttribute.wrap(c.mlirDictionaryAttrGet(
ctx.inner(),
@intCast(attributes.len),
@ptrCast(attributes.ptr),
));
}
pub fn size(self: Self) usize {
pub fn size(self: DictionaryAttribute) usize {
return @intCast(c.mlirDictionaryAttrGetNumElements(self.inner()));
}
pub fn get(self: Self, pos: usize) NamedAttribute {
pub fn get(self: DictionaryAttribute, pos: usize) NamedAttribute {
return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos)));
}
pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute {
pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?NamedAttribute {
return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
}
/// Reinterprets this DictionaryAttribute as a generic `Attribute`.
/// Both wrappers hold the same underlying MLIR handle, so no copy is made.
pub fn asAttr(self: DictionaryAttribute) Attribute {
return .{ ._inner = self._inner };
}
};
pub const Operation = struct {

View File

@ -727,11 +727,11 @@ pub const Event = opaque {
return ret.is_ready;
}
pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error {
const ret = try api.call(.PJRT_Event_Error, .{
.event = self.inner(),
});
return @ptrCast(ret);
pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
var args: Api.CallFnArgType(.PJRT_Event_Error) = .{ .event = self.inner() };
args = pjrtStruct(args);
const result: ?*c.PJRT_Error = api.inner.PJRT_Event_Error.?(&args);
return @ptrCast(result);
}
pub fn await_(self: *const Event, api: *const Api) ApiError!void {

View File

@ -398,7 +398,12 @@ pub fn loadModelBuffers(
) !zml.Bufferized(Model) {
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
}
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
pub fn loadModelBuffersWithPrefix(
comptime Model: type,
model: Model,
@ -408,12 +413,12 @@ pub fn loadModelBuffersWithPrefix(
prefix: []const u8,
) !zml.Bufferized(Model) {
// Allocate the bufferized version.
// We set fields to undefined, cause visitStructAndLoadBuffer is responsible
// We copy the shape, and let visitStructAndLoadBuffer write the other fields.
// to write them just afterward.
var res: zml.Bufferized(Model) = undefined;
try zml.meta.mapAlloc(struct {
pub fn initBuffer(_: void, _: zml.Tensor) zml.Buffer {
return undefined;
pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined };
}
}.initBuffer, allocator, {}, model, &res);
@ -425,32 +430,6 @@ pub fn loadModelBuffersWithPrefix(
return res;
}
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
pub fn loadBuffersFromModelWithPrefix(comptime Model: type, model: Model, buffer_store: BufferStore, allocator: std.mem.Allocator, prefix: []const u8, platform: zml.Platform) !zml.Bufferized(Model) {
// Allocate the bufferized version.
// We set fields to undefined, cause visitStructAndLoadBuffer is responsible
// to write them just afterward.
var res: zml.Bufferized(Model) = undefined;
try zml.meta.mapAlloc(struct {
pub fn initBuffer(_: void, _: zml.Tensor) zml.Buffer {
return undefined;
}
}.initBuffer, allocator, {}, model, &res);
var prefix_builder: PrefixBuilder = .{};
defer prefix_builder.deinit(allocator);
try prefix_builder.push(allocator, prefix);
try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform);
return res;
}
/// Takes a bufferized version of a `model`, ie a mirror struct of the `model`, and deinit all the
/// Buffer found.
pub fn unloadBuffers(model: anytype) void {
@ -474,7 +453,12 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
const prefix = prefix_builder.data.items;
if (T == zml.Buffer) {
return if (buffer_store.get(prefix)) |host_buffer| {
obj.* = try zml.Buffer.from(platform, host_buffer);
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
var buf_with_metadata = host_buffer;
log.warn("loading {s} ({})", .{ prefix, obj._shape });
zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
buf_with_metadata._shape = obj._shape;
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
} else {
return error.BufferNotFound;
};
@ -484,10 +468,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
.Pointer => |ptr_info| {
if (ptr_info.size == .Slice) {
for (obj.*, 0..) |*value, i| {
var buffer: [100]u8 = undefined;
const new_prefix = std.fmt.bufPrint(&buffer, "{d}", .{i}) catch unreachable;
try prefix_builder.push(allocator, new_prefix);
try prefix_builder.pushDigit(allocator, i);
defer prefix_builder.pop();
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
@ -502,13 +483,9 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
}
},
.Optional => |opt_info| {
var child = @as(opt_info.child, undefined);
if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) {
obj.* = child;
} else |err| switch (err) {
error.BufferNotFound => {},
else => return err,
.Optional => {
if (obj.*) |*obj_val| {
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform);
}
},
else => {},

View File

@ -3,7 +3,6 @@ const testing = std.testing;
const meta = @import("meta.zig");
const pjrt = @import("pjrt");
const pjrtx = @import("pjrtx.zig");
const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data;
@ -13,9 +12,12 @@ const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(Buffer);
}
const log = std.log.scoped(.zml);
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
///
/// * contains a handle that the ZML runtime can use to convert into a physical address, but there is no guarantee this address is visible from the CPU.
@ -23,33 +25,70 @@ test {
/// * can be created by calling `HostBuffer.toDevice(platform)`.
pub const Buffer = struct {
_shape: Shape,
_shards: Shape = undefined,
_platform: Platform,
_data: *pjrtx.Buffer,
_api: *const pjrt.Api,
_shards: Shards,
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
/// Copies the content of the given buffer from host memory to the accelerator memory.
pub fn from(platform: Platform, buf: HostBuffer) !Buffer {
pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer {
var res: Buffer = .{
._api = platform.pjrt_api,
._shape = host_buffer.shape(),
._shards = .{},
};
// We shard only on the first axis so that the chunks are still contiguous.
// TODO: support more advanced sharding specs
meta.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info);
const n_partitions = platform.sharding().num_partitions;
const chunk_size = if (sharding_ax) |ax| cs: {
// This kind of sharding error should be detected earlier on.
meta.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
break :cs @divExact(host_buffer.dim(ax), n_partitions);
} else 0;
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
const devices = platform.getDevices();
for (0..n_partitions) |i| {
// If no sharding if found, the given buffer is replicated on all devices.
const buf = if (sharding_ax) |ax| buf: {
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
} else host_buffer;
const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
.data = buf.data,
.buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
.buffer_type = buffer_type,
.dims = buf.shape().dims(),
.byte_strides = buf.strides(),
.device = platform.getDevices()[0],
.byte_strides = byte_strides,
.device = devices[i],
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
});
return .{
._platform = platform,
._shape = buf.shape(),
._data = pjrt_buffer,
};
res._shards.appendAssumeCapacity(pjrt_buffer);
}
return res;
}
/// Wraps a pre-exisiting `pjrt.Buffer` into a `zml.Buffer`.
pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer {
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
pub fn fromPjrtBuffers(platform: Platform, pjrt_buffers: []const *pjrt.Buffer) Buffer {
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
var shards: Shards = .{};
shards.appendSliceAssumeCapacity(pjrt_buffers);
return .{
._platform = platform,
._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer),
._data = pjrt_buffer,
._api = platform.pjrt_api,
._shape = Shape.init(
// This isn't with sharded axes.
pjrt_buffers[0].getDimensions(platform.pjrt_api),
dtypeFromBufferType(pjrt_buffers[0].getElementType(platform.pjrt_api)),
),
._shards = shards,
};
}
@ -112,8 +151,9 @@ pub const Buffer = struct {
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
.data = buf.data,
.element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
.element_type = bufferTypeFromDtype(buf.shape().dtype()),
.dims = buf.shape().dims(),
// TODO: split in shards
.device = platform.getDevices()[0],
.layout = .{
.Tiled = .{
@ -124,10 +164,12 @@ pub const Buffer = struct {
},
});
var shards: Shards = .{};
shards.appendAssumeCapacity(pjrt_buffer);
return .{
._platform = platform,
._api = platform.pjrt_api,
._shape = buf.shape(),
._data = pjrt_buffer,
._shards = shards,
};
}
@ -135,7 +177,9 @@ pub const Buffer = struct {
pub fn getValue(self: Buffer, T: type) !T {
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
var res: T = undefined;
try self._data.toHostBuffer(self._platform.pjrt_api, std.mem.asBytes(&res));
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
try event.await_(self._api);
return res;
}
@ -143,7 +187,9 @@ pub const Buffer = struct {
/// and return a new `HostBuffer` object with the same shape.
/// The returned `HostBuffer` doesn't own the memory.
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
try self._data.toHostBuffer(self._platform.pjrt_api, output);
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, output);
try event.await_(self._api);
return HostBuffer.fromBytes(self.shape(), output);
}
@ -151,7 +197,9 @@ pub const Buffer = struct {
/// The returned `HostBuffer` does own the memory.
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
const output = try HostBuffer.empty(allocator, self.shape());
try self._data.toHostBuffer(self._platform.pjrt_api, @constCast(output.data));
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
try event.await_(self._api);
return output;
}
@ -159,7 +207,9 @@ pub const Buffer = struct {
/// Depending on the platform, the memory is typically not released to the OS
/// but just marked as available in the memory pool.
pub fn deinit(self: *const Buffer) void {
self._data.deinit(self._platform.pjrt_api);
for (self._shards.constSlice()) |buffer| {
buffer.deinit(self._api);
}
}
/// This Buffer shape.
@ -202,9 +252,41 @@ pub const Buffer = struct {
try writer.print("Buffer({_})", .{self._shape});
}
fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape {
const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) {
// Please keep the list exhaustive and in the same order than in DataType.
/// True when this buffer's data is actually split across several device shards.
/// A single-shard buffer is never considered sharded, even if its shape
/// carries sharding annotations.
fn hasShardedAxis(self: Buffer) bool {
    return self._shards.len != 1 and @reduce(.Or, self._shape._sharding_info);
}
};
/// Maps a zml `DataType` onto the equivalent PJRT buffer element type.
/// The switch is exhaustive over `DataType`; `dtypeFromBufferType` is the
/// inverse mapping for every non-INVALID PJRT type (pinned by the roundtrip
/// test further down in this file).
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
return switch (dt) {
.bool => .PRED,
.f8e4m3b11fnuz => .F8E4M3B11FNUZ,
.f8e4m3fn => .F8E4M3FN,
.f8e4m3fnuz => .F8E4M3FNUZ,
.f8e5m2 => .F8E5M2,
.f8e5m2fnuz => .F8E5M2FNUZ,
.bf16 => .BF16,
.f16 => .F16,
.f32 => .F32,
.f64 => .F64,
.i8 => .S8,
.i4 => .S4,
.i16 => .S16,
.i32 => .S32,
.i64 => .S64,
.u4 => .U4,
.u8 => .U8,
.u16 => .U16,
.u32 => .U32,
.u64 => .U64,
.c64 => .C64,
.c128 => .C128,
};
}
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
return switch (pjrt_type) {
.PRED => .bool,
.F8E4M3B11FNUZ => .f8e4m3b11fnuz,
.F8E4M3FN => .f8e4m3fn,
@ -215,8 +297,8 @@ pub const Buffer = struct {
.F16 => .f16,
.F32 => .f32,
.F64 => .f64,
.S4 => .i4,
.S8 => .i8,
.S4 => .i4,
.S16 => .i16,
.S32 => .i32,
.S64 => .i64,
@ -227,9 +309,19 @@ pub const Buffer = struct {
.U64 => .u64,
.C64 => .c64,
.C128 => .c128,
.INVALID => @panic("Can't handle INVALID Pjrt buffers."),
.INVALID => @panic("Found an invalid pjrt buffer"),
};
}
return Shape.init(buf.getDimensions(platform.pjrt_api), dt);
test bufferTypeFromDtype {
inline for (@typeInfo(DataType).Enum.fields) |field| {
const dt: DataType = @enumFromInt(field.value);
try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
}
};
inline for (@typeInfo(pjrt.BufferType).Enum.fields) |field| {
const dt: pjrt.BufferType = @enumFromInt(field.value);
if (dt == .INVALID) continue;
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
}
}

View File

@ -147,9 +147,15 @@ pub const HostBuffer = struct {
return try Buffer.from(platform_, self);
}
/// Interpret the underlying data as a contiguous slice.
/// WARNING: It's only valid if the buffer is contiguous.
/// Strided buffers can't use this method.
pub fn items(self: HostBuffer, comptime T: type) []const T {
if (DataType.fromZigType(T) != self.dtype()) {
std.debug.panic("Can't reinterpret HostBuffer({_}) as {s}", .{ self.shape(), @typeName(T) });
std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) });
}
if (!self.isContiguous()) {
std.debug.panic("{} isn't contiguous", .{self});
}
const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr));
return ptr[0..self._shape.count()];
@ -180,16 +186,65 @@ pub const HostBuffer = struct {
return self._shape.count();
}
pub fn dim(self: HostBuffer, axis: anytype) i64 {
return self._shape.dim(axis);
pub fn dim(self: HostBuffer, axis_: anytype) i64 {
return self._shape.dim(axis_);
}
pub fn axis(self: HostBuffer, axis_: anytype) u3 {
return self._shape.axis(axis_);
}
/// Whether the underlying bytes are laid out contiguously (row-major).
/// A buffer without explicit strides is contiguous by construction;
/// otherwise the stored strides are compared against the strides implied
/// by the shape.
pub fn isContiguous(self: HostBuffer) bool {
    const explicit_strides = self._strides orelse return true;
    const row_major = self._shape.computeStrides();
    return std.mem.eql(i64, explicit_strides[0..self.rank()], row_major.constSlice());
}
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
meta.assert(self._strides == null, "reshape expects a contiguous tensor, got: {}", .{self});
meta.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
var res = self;
res._shape = self._shape.reshape(shape_);
return res;
}
pub const Slice = struct {
start: i64 = 0,
end: ?i64 = null,
};
/// Slices this HostBuffer along the given axis.
/// Negative `start`/`end` values are interpreted relative to the end of the
/// axis (Python-style). `end == null` means "up to the end of the axis".
/// The returned HostBuffer is a view into `self`'s memory (it does not own
/// it, hence `.unmanaged`) and is generally no longer contiguous.
pub fn slice1d(self: HostBuffer, axis_: anytype, s: Slice) HostBuffer {
const ax = self._shape.axis(axis_);
const d = self.dim(ax);
// Normalize negative indices into [0, d).
const start: i64 = if (s.start < 0) s.start + d else s.start;
var end = s.end orelse d;
if (end < 0) end += d;
meta.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
meta.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
meta.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
// If strides weren't set it means original buffer is contiguous.
// But it won't be anymore after slicing. The strides don't change though.
const _strides = self._strides orelse self._shape.computeStrides().buffer;
// Byte offset of the first kept element along `ax`.
const offset: usize = @intCast(start * _strides[ax]);
return .{
._shape = self.shape().set(ax, end - start),
// NOTE(review): only the start of the view is adjusted; the slice keeps
// the tail of `data`, reads are bounded by the shape/strides.
.data = self.data[offset..],
._strides = _strides,
._memory = .unmanaged,
};
}
/// std.fmt integration: renders this buffer as "HostBuffer(.<shape>)".
/// `fmt` and `options` are required by the `format` contract but are
/// intentionally ignored; `{_}` is presumably a custom specifier handled by
/// Shape's own formatter — confirm against shape.zig.
pub fn format(
self: HostBuffer,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
try writer.print("HostBuffer(.{_})", .{self._shape});
}
};
fn parseArrayInfo(T: type) Shape {

View File

@ -663,3 +663,36 @@ test zip {
const a_sum: A = try zip(Sum.call, testing.allocator, &[_]A{ a0, a1 }, .{});
try testing.expectEqual(A{ .a = 5, .b = .{ 7, 9 } }, a_sum);
}
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
/// finds all X in the given object, and write the result of func(X) into an arraylist.
/// Traversal is delegated to `visit`; since the visit callback returns `void`,
/// allocation failures are latched in a flag and reported once at the end.
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(FnResult(func)), obj: anytype) error{OutOfMemory}!void {
assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
const LocalContext = struct {
func_ctx: _CollectCtx(func),
out: *std.ArrayList(FnResult(func)),
// Set when `out.append` fails; later callbacks become no-ops.
oom: bool = false,
};
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
visit((struct {
fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void {
if (ctx.oom) return;
// One-arg funcs are called without the context value.
const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
ctx.out.append(res) catch {
ctx.oom = true;
};
}
}).cb, &context, obj);
if (context.oom) return error.OutOfMemory;
}
/// Context type expected by `collect` for the given callback:
/// `void` for a one-argument func, otherwise the first parameter's type.
fn _CollectCtx(func: anytype) type {
    const params = @typeInfo(@TypeOf(func)).Fn.params;
    if (params.len == 1) return void;
    if (params[0].type) |Ctx| return Ctx;
    @compileError("anytype not supported in collect");
}
/// Argument type visited by `collect` for the given callback:
/// the type of the callback's last parameter.
fn _CollectArg(func: anytype) type {
    const params = @typeInfo(@TypeOf(func)).Fn.params;
    if (params[params.len - 1].type) |Arg| return Arg;
    @compileError("anytype not supported in collect");
}

View File

@ -49,9 +49,11 @@ pub const CompilationContext = struct {
_unique_id: u64 = 10000,
_tracer: Tracer,
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
threadlocal var _current: ?*CompilationContext = null;
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
pub fn init(allocator: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
const mlir_registry = mlir.Registry.init() catch unreachable;
inline for (.{ "func", "stablehlo" }) |d| {
@ -181,7 +183,7 @@ pub const CompilationContext = struct {
comptime func: anytype,
model: *const ModuleSignature(func).ModelT,
args: *const ModuleSignature(func).ArgsT,
opts: struct { add_donations_attributes: bool = false },
opts: struct { add_donations_attributes: bool = false, sharding: bool = true },
) error{OutOfMemory}!MlirFn {
const frame = self._tracer.frameStart("generateBytecode.emit");
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
@ -197,17 +199,25 @@ pub const CompilationContext = struct {
const tensor_count = model_tensor_count + args_tensor_count;
const loc = self.mlirCtx().location(@src());
const mlir_ctx = self.mlirCtx();
const loc = mlir_ctx.location(@src());
const locations = try arena.alloc(mlir.Location, tensor_count);
for (locations) |*l| l.* = mlir.Location.unknown(self.mlirCtx());
var input_types = try arena.alloc(mlir.Type, tensor_count);
fillMlirTypes(model, self.mlirCtx(), input_types[0..model_tensor_count]);
fillMlirTypes(args, self.mlirCtx(), input_types[model_tensor_count..]);
@memset(locations, mlir.Location.unknown(mlir_ctx));
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable;
meta.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{});
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
meta.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
const input_types = try arena.alloc(mlir.Type, tensor_count);
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
// Note: this isn't stricly necessary. We call `countTensor` on `fn_res`.
// But it forces user to have simpler function.
const out_tensor_count = comptime ops.staticCountTensors(ModuleSignature(func).ReturnT) orelse @compileError("Can't use " ++ @typeName(ModuleSignature(func).ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
// Those are returned to caller so we don't put them in the arena.
const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count);
const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count);
const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count);
@ -234,20 +244,25 @@ pub const CompilationContext = struct {
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
const fn_ret = dialect.func.return_(self.mlirCtx(), &fn_res_values, loc);
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
fn_body.addOperationsRecursive(fn_ret);
}
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
@memset(arg_attrs, .{});
// Donations attributes only make sense on the main function.
const attrs: []const mlir.Attribute = if (opts.add_donations_attributes)
try self.addDonationsAttribute(arena, fn_res_donations, tensor_count)
else
&.{};
if (opts.add_donations_attributes) {
self.addDonationsAttributes(arg_attrs, fn_res_donations);
}
if (opts.sharding) {
self.addShardingAttributes(arg_attrs, input_shapes.items);
}
const mlir_fn = dialect.func.func(self.mlirCtx(), .{
.sym_name = fn_name,
.args = input_types[0..],
.arg_attrs = attrs,
.args = input_types,
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
.results = fn_res_types,
.block = fn_body,
.location = loc,
@ -277,12 +292,7 @@ pub const CompilationContext = struct {
/// Given a list of donations mapping output buffers to input buffers,
/// generate donation attribute for each `n_args` input argument.
fn addDonationsAttribute(self: *const CompilationContext, allocator: std.mem.Allocator, donations: []const Tensor._Donation, n_args: usize) ![]mlir.Attribute {
const empty = mlir.DictionaryAttribute.init(self.mlirCtx(), &.{}).as(mlir.Attribute).?;
const arg_attrs = try allocator.alloc(mlir.Attribute, n_args);
@memset(arg_attrs, empty);
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
var n_donations: usize = 0;
for (donations, 0..) |donation, index| {
switch (donation) {
@ -293,23 +303,23 @@ pub const CompilationContext = struct {
.input_buffer => {},
.arg => |a| {
n_donations += 1;
meta.assert(arg_attrs[a].eql(empty), "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, arg_attrs[a] });
arg_attrs[a] = mlir.DictionaryAttribute.init(self.mlirCtx(), &.{
// This will break the day we writer another attribute before donation.
// When the time come, do a more fancy lookup here to check if an argument
// is donated twice.
meta.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
attributes[a].appendAssumeCapacity(
mlir.NamedAttribute.init(
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?,
),
}).as(mlir.Attribute).?;
// log.debug("attribute: {}", .{arg_attrs[a]});
);
// log.debug("attribute: {}", .{attributes[a].constSlice()});
},
}
}
if (n_donations == 0) return &.{};
return arg_attrs;
}
test addDonationsAttribute {
test addDonationsAttributes {
const zml = @import("zml.zig");
const platform = zml.testing.env();
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
@ -343,12 +353,81 @@ pub const CompilationContext = struct {
// `%arg0` is the bias of the model, `%arg1` is `x`.
try std.testing.expectEqual(1, f.n_model);
try std.testing.expectEqual(1, f.n_args);
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "%arg1: tensor<8xf16> {tf.aliasing_output = 0 : i32}") != null) catch |err| {
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| {
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
return err;
};
}
/// Appends per-argument sharding metadata ("mhlo.layout_mode" and
/// "mhlo.sharding") to each argument's attribute list, one shape per argument.
/// No-op when sharding is disabled on the platform.
fn addShardingAttributes(self: CompilationContext, attributes: []AttributeList, shapes: []const Shape) void {
const mlir_ctx = self.mlirCtx();
if (!self._platform.compilation_options.sharding_enabled) return;
const num_partitions = self._platform.sharding().num_partitions;
// Scratch buffer for the textual sharding spec, reused across arguments.
var sharding_str: std.BoundedArray(u8, 128) = .{};
// Same layout attribute for every argument, so build it once.
const mhlo_default_layout = mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
);
for (attributes, shapes) |*attr, shape| {
// NOTE(review): assumes each AttributeList still has capacity for 2 more
// entries here (capacity is 3, donation may already use 1) — confirm.
attr.appendAssumeCapacity(mhlo_default_layout);
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
// Reset the scratch buffer at the end of each iteration.
defer sharding_str.len = 0;
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
// Presumably StringAttribute.init copies the bytes into the MLIR
// context, making the scratch reset above safe — verify.
mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()).asAttr(),
));
}
}
/// Serializes the sharding spec of `shape` for `num_partitions` devices.
/// Emits either "{replicated}" (no sharded axis, or a single partition) or
/// "{devices=[d0,d1,...]<=[num_partitions]}" with one entry per axis.
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
    const sharded_axes: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
    if (sharded_axes == 0 or num_partitions == 1) {
        return writer.writeAll("{replicated}");
    }
    try writer.writeAll("{devices=[");
    var ax: usize = 0;
    while (ax < shape.rank()) : (ax += 1) {
        if (ax > 0) try writer.writeByte(',');
        const devices_on_axis: u8 = if (shape._sharding_info[ax]) num_partitions else 1;
        try writer.print("{d}", .{devices_on_axis});
    }
    try writer.print("]<=[{d}]}}", .{num_partitions});
}
// Pins the exact spec strings produced for a rank-2 f32 shape.
test writeShardingRepresentation {
var rule: [64]u8 = undefined;
const x = Shape.init(.{ 16, 8 }, .f32);
// By default tensors are replicated.
{
var fbs = std.io.fixedBufferStream(&rule);
try writeShardingRepresentation(x, 4, fbs.writer());
try std.testing.expectEqualStrings("{replicated}", fbs.getWritten());
}
// Shard along first axis.
{
var fbs = std.io.fixedBufferStream(&rule);
try writeShardingRepresentation(x.withSharding(.{0}), 4, fbs.writer());
try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", fbs.getWritten());
}
// Also shard along second axis.
{
var fbs = std.io.fixedBufferStream(&rule);
try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, fbs.writer());
try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", fbs.getWritten());
}
}
/// Packs each argument's accumulated named attributes into an MLIR
/// DictionaryAttribute, yielding one generic attribute per argument.
/// The returned slice is owned by `allocator` (typically an arena).
fn finalizeAttributeList(allocator: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute {
    const out = try allocator.alloc(mlir.Attribute, attributes.len);
    for (out, attributes) |*dict, named_attrs| {
        dict.* = mlir.DictionaryAttribute.init(mlir_ctx, named_attrs.constSlice()).asAttr();
    }
    return out;
}
/// Generates an MLIR `func.call` of the given function.
/// If the function has not been seen yet, we generate MLIR for it,
/// in a independent function.
@ -565,47 +644,57 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
}
/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
fn fillBuffers(v: anytype, buffers: []*pjrt.Buffer) void {
fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u32) void {
const LocalContext = struct {
index: usize,
buffers: []*pjrt.Buffer,
index: u32,
buffers: []const [*]*pjrt.Buffer,
};
var ctx: LocalContext = .{
.index = 0,
var context: LocalContext = .{
.index = start,
.buffers = buffers,
};
meta.visit((struct {
fn cb(inner_context: *LocalContext, buffer: *const Buffer) void {
// meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, inner_context.index });
inner_context.buffers[inner_context.index] = buffer._data;
inner_context.index += 1;
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
// meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
const model_sharding = ctx.buffers.len;
meta.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
for (buffer._shards.constSlice(), 0..) |shard, d| {
ctx.buffers[d][ctx.index] = shard;
}
}).cb, &ctx, v);
assert(ctx.index == buffers.len);
ctx.index += 1;
}
}).cb, &context, v);
assert(context.index == start + len);
}
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []*pjrt.Buffer) void {
pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32) void {
const LocalContext = struct {
index: usize,
index: u32,
platform: Platform,
buffers: []*pjrt.Buffer,
buffers: []const [*]*pjrt.Buffer,
expected_count: u32,
};
var ctx: LocalContext = .{
var local_ctx: LocalContext = .{
.index = 0,
.platform = platform,
.buffers = buffers,
.expected_count = expected_count,
};
meta.visit((struct {
fn cb(inner_context: *LocalContext, buffer: *Buffer) void {
const i = inner_context.index;
if (i < inner_context.buffers.len) {
buffer.* = Buffer.fromPjrtBuffer(inner_context.platform, inner_context.buffers[i]);
fn cb(ctx: *LocalContext, buffer: *Buffer) void {
const i = ctx.index;
ctx.index += 1;
if (i >= ctx.expected_count) return;
var shards: Buffer.Shards = .{};
for (ctx.buffers) |buff| {
shards.appendAssumeCapacity(buff[i]);
}
inner_context.index += 1;
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, shards.constSlice());
}
}).cb, &ctx, v);
meta.assert(ctx.index == buffers.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), ctx.index });
}).cb, &local_ctx, v);
meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
}
/// Visit the given struct and assign op results to each tensor found.
@ -637,11 +726,13 @@ const BaseExe = struct {
/// The PJRT executable representing the compiled module.
exe: *pjrt.LoadedExecutable,
/// Number of buffers in the model.
model_buffer_count: usize,
model_buffer_count: u32,
/// Number of buffers in the arguments.
args_buffer_count: usize,
args_buffer_count: u32,
/// Number of buffers in result.
result_buffer_count: usize,
result_buffer_count: u32,
/// Num devices used (>1 for sharded executable)
num_devices: u8,
};
/// Represents a ZML model, compiled into a PJRT executable.
@ -674,34 +765,56 @@ pub fn ExeWithWeights(comptime func: anytype) type {
/// The raw untyped compiled module.
inner: BaseExe,
/// The allocator used for bookkeeping.
allocator: std.mem.Allocator,
/// Pre-allocated slice of buffers to use as inputs when the module is called.
input_buffers: []*pjrt.Buffer,
input_per_device: []const [*]*pjrt.Buffer,
/// Pre-allocated slice of buffers to use as outputs when the module is called.
output_buffers: []*pjrt.Buffer,
output_per_device: []const [*]*pjrt.Buffer,
/// Internal memory slice used.
_all_buffers: []*pjrt.Buffer,
_all_per_device: [][*]*pjrt.Buffer,
/// And the allocator backing _data_buffer.
_allocator: std.mem.Allocator,
pub fn initFromModel(allocator: std.mem.Allocator, inner: BaseExe, model: Bufferized(Signature.ModelT)) !Self {
const input_buffers = try allocator.alloc(*pjrt.Buffer, inner.model_buffer_count + inner.args_buffer_count);
errdefer allocator.free(input_buffers);
fillBuffers(&model, input_buffers[0..inner.model_buffer_count]);
const n_input_buffers = inner.model_buffer_count + inner.args_buffer_count;
const n_output_buffers = inner.result_buffer_count;
const n_devices = inner.num_devices;
const output_buffers = try allocator.alloc(*pjrt.Buffer, inner.result_buffer_count);
errdefer allocator.free(output_buffers);
// Allocate once for all the *pjrt.Buffer we need to store ...
const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_input_buffers + n_output_buffers) * n_devices);
errdefer allocator.free(all_buffers);
const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_input_buffers * n_devices, n_output_buffers * n_devices });
// ... and once for all the [*]*pjrt.Buffer.
const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
errdefer allocator.free(all_per_device);
const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });
for (0..n_devices) |i| {
input_per_device[i] = all_input_buffers[i * n_input_buffers ..].ptr;
output_per_device[i] = all_output_buffers[i * n_output_buffers ..].ptr;
}
fillBuffers(&model, input_per_device, 0, inner.model_buffer_count);
// Note: all_output_buffers is left undefined, it will be written to in `call`.
return .{
.inner = inner,
.allocator = allocator,
.input_buffers = input_buffers,
.output_buffers = output_buffers,
.input_per_device = input_per_device,
.output_per_device = output_per_device,
._all_buffers = all_buffers,
._all_per_device = all_per_device,
._allocator = allocator,
};
}
pub fn deinit(self: Self) void {
self.allocator.free(self.input_buffers);
self.allocator.free(self.output_buffers);
// Free in reverse order of allocation.
self._allocator.free(self._all_per_device);
self._allocator.free(self._all_buffers);
}
pub fn platform(self: Self) Platform {
@ -709,13 +822,13 @@ pub fn ExeWithWeights(comptime func: anytype) type {
}
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
fillBuffers(&args, self.input_buffers[self.inner.model_buffer_count..][0..self.inner.args_buffer_count]);
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
var event: [1]*pjrt.Event = undefined;
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
.arguments = &.{self.input_buffers.ptr},
.num_args = self.input_buffers.len,
.results = &.{self.output_buffers.ptr},
.arguments = self.input_per_device,
.num_args = self.inner.args_buffer_count + self.inner.model_buffer_count,
.results = self.output_per_device,
.events = &event,
// TODO: this allows to tell a specific buffer shouldn't be donated,
// even if it has been marked as "can be donated" during compilation.
@ -723,7 +836,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
}) catch unreachable;
var result: Bufferized(Signature.ReturnT) = undefined;
assignRawBuffers(&result, self.inner.platform, self.output_buffers);
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_count);
return result;
}
};
@ -750,6 +863,11 @@ fn compileInternal(
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
context._module.getBody().appendOperation(f.mlir_fn);
const sharding = context._platform.sharding();
const mlir_ctx = context._mlir_ctx;
context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| {
log.err(
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
@ -771,7 +889,8 @@ fn compileInternal(
.exe = loaded_executable,
.model_buffer_count = f.n_model,
.args_buffer_count = f.n_args,
.result_buffer_count = f.res_types.len,
.result_buffer_count = @intCast(f.res_types.len),
.num_devices = sharding.num_replicas * sharding.num_partitions,
};
}
@ -932,10 +1051,12 @@ fn loadOrCompilePjrtExecutable(
}
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, module_hash: u64) !*pjrt.LoadedExecutable {
const sharding = platform.sharding();
var options: xla_pb.CompileOptionsProto = .{
.executable_build_options = .{
.num_replicas = 1,
.num_partitions = 1,
.num_replicas = sharding.num_replicas,
.num_partitions = sharding.num_partitions,
.use_spmd_partitioning = sharding.num_partitions > 1 or sharding.num_replicas > 1,
},
};
// Let the arena deinit, zig-protobuf deinit is very slow.
@ -1372,3 +1493,14 @@ fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void {
hash(hasher, element, strat);
}
}
fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
var res: [lengths.len][]T = undefined;
var i: usize = 0;
inline for (&res, lengths) |*r, len| {
r.* = buffer[i .. i + len];
i += len;
}
std.debug.assert(i == buffer.len);
return res;
}

View File

@ -15,6 +15,7 @@ const Target = @import("platform.zig").Target;
const log = std.log.scoped(.zml);
pub const Buffer = pjrt.Buffer;
pub const Device = pjrt.Device;
pub const DeviceDescription = pjrt.DeviceDescription;
pub const Api = pjrt.Api;
@ -27,6 +28,12 @@ pub const SerializeResult = pjrt.SerializeResult;
pub const Executable = pjrt.Executable;
pub const ExecuteError = ApiError;
test {
std.testing.refAllDecls(Client);
std.testing.refAllDecls(Event);
std.testing.refAllDecls(LoadedExecutable);
}
fn InnerMixin(comptime innerT: type) type {
return struct {
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT {
@ -63,7 +70,7 @@ pub const Client = opaque {
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
const event: *Event = @ptrCast(event_);
try event.await_(api);
return @ptrCast(buffer);
return buffer;
}
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
@ -169,75 +176,6 @@ pub const Client = opaque {
}
};
pub const Buffer = opaque {
const inner = InnerMixin(pjrt.Buffer).inner;
pub const BufferType = pjrt.BufferType;
pub fn BufferTypeFromDType(dt: dtype.DataType) BufferType {
return switch (dt) {
.bool => .PRED,
.f8e4m3b11fnuz => .F8E4M3B11FNUZ,
.f8e4m3fn => .F8E4M3FN,
.f8e4m3fnuz => .F8E4M3FNUZ,
.f8e5m2 => .F8E5M2,
.f8e5m2fnuz => .F8E5M2FNUZ,
.bf16 => .BF16,
.f16 => .F16,
.f32 => .F32,
.f64 => .F64,
.i8 => .S8,
.i4 => .S4,
.i16 => .S16,
.i32 => .S32,
.i64 => .S64,
.u4 => .U4,
.u8 => .U8,
.u16 => .U16,
.u32 => .U32,
.u64 => .U64,
.c64 => .C64,
.c128 => .C128,
};
}
pub const HostBufferSemantics = pjrt.HostBufferSemantics;
pub const MemoryLayoutType = pjrt.MemoryLayoutType;
pub fn deinit(self: *Buffer, api: *const Api) void {
self.inner().deinit(api);
}
pub fn delete(self: *Buffer, api: *const Api) void {
self.inner().delete(api);
}
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!void {
var event = try self.inner().toHostBuffer(api, dst);
try event.await_(api);
}
pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
return self.inner().getDimensions(api);
}
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
return self.inner().getElementType(api);
}
pub fn isDeleted(self: *const Buffer, api: *const Api) bool {
return self.inner().isDeleted(api);
}
pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device {
return @ptrCast(try self.inner().getDevice(api));
}
pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque {
return self.inner().getOpaqueDeviceMemoryDataPointer(api);
}
};
pub const Event = opaque {
pub const inner = InnerMixin(pjrt.Event).inner;
@ -249,7 +187,7 @@ pub const Event = opaque {
return self.inner().isReady(api);
}
pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error {
pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
return self.inner().getEventError(api);
}
@ -288,9 +226,9 @@ pub const Event = opaque {
pub const LoadedExecutable = opaque {
const inner = InnerMixin(pjrt.LoadedExecutable).inner;
pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
self.inner().deinit(api);
}
// pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
// self.inner().deinit(api);
// }
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
self.inner().delete(api);
@ -300,9 +238,10 @@ pub const LoadedExecutable = opaque {
return self.inner().isDeleted(api);
}
pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
return self.inner().getAddressableDevices(api);
}
// TODO fix me
// pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
// return self.inner().getAddressableDevices(api);
// }
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
arguments: []const [*]const *const Buffer,

View File

@ -6,6 +6,7 @@ const meta = @import("meta.zig");
const module = @import("module.zig");
const pjrt = @import("pjrtx.zig");
const pjrt_core = @import("pjrt");
const log = std.log.scoped(.zml);
pub const Target = enum {
cpu,
@ -31,6 +32,8 @@ pub const CompilationOptions = struct {
xla_dump_to: ?[]const u8 = null,
xla_dump_fusion_visualization: bool = false,
cache_location: ?[]const u8 = null,
sharding_enabled: bool = false,
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
};
pub const Platform = struct {
@ -39,8 +42,14 @@ pub const Platform = struct {
pjrt_client: *pjrt.Client,
compilation_options: CompilationOptions = .{},
pub const MAX_NUM_DEVICES: u8 = 8;
pub fn init(target: Target, api: *const pjrt.Api) !Platform {
const pjrt_client = try pjrt.Client.init(api, &.{});
const true_num_devices = pjrt_client.getAddressableDevices(api).len;
if (true_num_devices > MAX_NUM_DEVICES) {
log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES });
}
return .{
.target = target,
.pjrt_api = api,
@ -50,7 +59,26 @@ pub const Platform = struct {
}
pub fn getDevices(self: Platform) []const *const pjrt_core.Device {
return self.pjrt_client.getAddressableDevices(self.pjrt_api);
const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api);
if (all_devices.len > MAX_NUM_DEVICES) {
return all_devices[0..MAX_NUM_DEVICES];
}
return all_devices;
}
pub const Sharding = struct { num_replicas: u8, num_partitions: u8 };
pub fn sharding(self: Platform) Sharding {
// replicas run the same function but with different inputs,
// while partitions contribute to one evaluation over a shared input.
// Inside an inference process, we generally don't want replicas,
// as it's best to fully isolate replicas on different processes.
// For now we hardcode num_replicas = 1.
const num_devices: u8 = @intCast(self.getDevices().len);
return if (self.compilation_options.sharding_enabled)
.{ .num_replicas = 1, .num_partitions = num_devices }
else
.{ .num_replicas = 1, .num_partitions = 1 };
}
pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform {

View File

@ -23,12 +23,14 @@ pub const Shape = struct {
pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
pub const ShardingInfo = @Vector(MAX_RANK, bool);
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
_dtype: DataType,
_dims: DimsArray = .{},
_tags: TagsArray = UnknownTags,
_sharding_info: ShardingInfo = @splat(false),
pub fn parseDimensions(v: anytype) struct { DimsArray, TagsArray } {
const T = @TypeOf(v);
@ -69,7 +71,7 @@ pub const Shape = struct {
return .{ dims_, tags_ };
}
meta.compileError("Wrong type, got {}", .{T});
meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
}
test parseDimensions {
@ -286,7 +288,7 @@ pub const Shape = struct {
return res;
}
meta.compileError("Wrong type, got {}", .{T});
meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
}
fn axisFromInt(self: Shape, d: isize) u3 {
@ -384,6 +386,9 @@ pub const Shape = struct {
} else {
try writer.print("{s}{d}", .{ prefix, d });
}
if (self._sharding_info[i]) {
try writer.writeByte('!');
}
}
_ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())});
if (!bare_fmt) _ = try writer.write(")");
@ -664,6 +669,16 @@ pub const Shape = struct {
return res;
}
pub fn withSharding(self: Shape, axes_: anytype) Shape {
var res = self;
// Reset sharding.
res._sharding_info = @splat(false);
for (self.axes(axes_).constSlice()) |ax| {
res._sharding_info[ax] = true;
}
return res;
}
/// Renames some of the tags in this shape.
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
pub fn rename(self: Shape, renames: anytype) Shape {
@ -690,7 +705,8 @@ pub const Shape = struct {
}
}
pub fn computeStrides(self: Shape, base_stride: u32) std.BoundedArray(i64, MAX_RANK) {
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
const base_stride = self.dtype().sizeOf();
const rk = self.rank();
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) };
if (rk == 0) return strides;

View File

@ -159,6 +159,12 @@ pub const Tensor = struct {
return res;
}
pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
var res = self;
res._shape = self._shape.withSharding(axes_);
return res;
}
/// Returns a Tensor with new tag names.
pub fn rename(self: Tensor, renames: anytype) Tensor {
var res = self;
@ -196,11 +202,6 @@ pub const Tensor = struct {
return res;
}
/// Returns a slice containing the strides for a Tensor.
pub inline fn computeStrides(self: Tensor) []const i64 {
return self._shape.computeStrides(self.dtype().sizeOf()).constSlice();
}
var _global_tensor_counter: u64 = 0;
/// Internal use