From 2f129f76c97007a80a57b1123817d4074eb8cd9a Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 24 Feb 2023 17:33:14 +0000 Subject: [PATCH] Add in-process sharding support across core ZML components (platform, shape, tensor, MLIR generation, buffers, and PJRT integration) --- mlir/mlir.zig | 30 +++-- pjrt/pjrt.zig | 10 +- zml/aio.zig | 61 +++------- zml/buffer.zig | 208 ++++++++++++++++++++++++--------- zml/hostbuffer.zig | 63 +++++++++- zml/meta.zig | 33 ++++++ zml/module.zig | 286 +++++++++++++++++++++++++++++++++------------ zml/pjrtx.zig | 93 +++------------ zml/platform.zig | 30 ++++- zml/shape.zig | 22 +++- zml/tensor.zig | 11 +- 11 files changed, 567 insertions(+), 280 deletions(-) diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 6d48e79..0ec47ea 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -47,7 +47,7 @@ pub fn MlirTypeMethods(comptime InnerT: type) type { /// Alternative to MlirWrapperType pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void; -fn MlirHelpersMethods(comptime OuterT: type) type { +fn MlirHelpersMethods(OuterT: type) type { switch (@typeInfo(OuterT)) { .Struct => |info| { if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT)); @@ -382,6 +382,10 @@ pub const StringAttribute = struct { pub fn value(self: Self) []const u8 { return fromStringRef(c.mlirStringAttrGetValue(self.inner())); } + + pub fn asAttr(self: StringAttribute) Attribute { + return .{ ._inner = self._inner }; + } }; pub const UnitAttribute = struct { @@ -493,6 +497,10 @@ pub fn IntegerAttribute(comptime it: IntegerTypes) type { pub fn get(value: IntAttr) ZigType { return @intCast(getter(value.inner())); } + + pub fn asAttr(self: IntAttr) Attribute { + return .{ ._inner = self._inner }; + } }; } @@ -731,23 +739,29 @@ pub const DictionaryAttribute = struct { .equal_fn = c.mlirAttributeEqual, }); - const Self = DictionaryAttribute; - - pub fn init(ctx: Context, attributes: []const NamedAttribute) Self { - return Self.wrap(c.mlirDictionaryAttrGet(ctx.inner(), @intCast(attributes.len), @ptrCast(attributes.ptr))); + pub fn init(ctx: Context, attributes: []const NamedAttribute) DictionaryAttribute { + return DictionaryAttribute.wrap(c.mlirDictionaryAttrGet( + ctx.inner(), + @intCast(attributes.len), + @ptrCast(attributes.ptr), + )); } - pub fn size(self: Self) usize { + pub fn size(self: DictionaryAttribute) usize { return @intCast(c.mlirDictionaryAttrGetNumElements(self.inner())); } - pub fn get(self: Self, pos: usize) NamedAttribute { + pub fn get(self: DictionaryAttribute, pos: usize) NamedAttribute { return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos))); } - pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute { + pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?NamedAttribute { return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name)); } + + pub fn asAttr(self: DictionaryAttribute) Attribute { + return .{ ._inner = self._inner }; + } }; pub const Operation = struct { diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index e80fdf8..5a4aaec 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -727,11 +727,11 @@ pub const Event = opaque { return ret.is_ready; } - pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error { - const ret = try api.call(.PJRT_Event_Error, .{ - .event = self.inner(), - }); - return @ptrCast(ret); + pub fn getEventError(self: *const Event, api: *const Api) ?*Error { + var args: Api.CallFnArgType(.PJRT_Event_Error) = .{ .event = self.inner() }; + args = pjrtStruct(args); + const result: ?*c.PJRT_Error = api.inner.PJRT_Event_Error.?(&args); + return @ptrCast(result); } pub fn await_(self: *const Event, api: *const Api) ApiError!void { diff --git a/zml/aio.zig b/zml/aio.zig index 6997db2..c7318f5 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -398,7 +398,12 @@ pub fn loadModelBuffers( ) !zml.Bufferized(Model) { return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); } - +/// Creates a bufferized version of a Model from the given BufferStore and the given prefix. +/// For details about bufferization, see the documentation of Bufferized(T). +/// +/// This will represent the weights of the model, loaded on a specific platform. +/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a +/// `module.ExeWithWeights` ready to be called. pub fn loadModelBuffersWithPrefix( comptime Model: type, model: Model, @@ -408,12 +413,12 @@ pub fn loadModelBuffersWithPrefix( prefix: []const u8, ) !zml.Bufferized(Model) { // Allocate the bufferized version. - // We set fields to undefined, cause visitStructAndLoadBuffer is responsible + // We copy the shape, and let visitStructAndLoadBuffer write the other fields. // to write them just afterward. var res: zml.Bufferized(Model) = undefined; try zml.meta.mapAlloc(struct { - pub fn initBuffer(_: void, _: zml.Tensor) zml.Buffer { - return undefined; + pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer { + return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined }; } }.initBuffer, allocator, {}, model, &res); @@ -425,32 +430,6 @@ pub fn loadModelBuffersWithPrefix( return res; } -/// Creates a bufferized version of a Model from the given BufferStore and the given prefix. -/// For details about bufferization, see the documentation of Bufferized(T). -/// -/// This will represent the weights of the model, loaded on a specific platform. -/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a -/// `module.ExeWithWeights` ready to be called. -pub fn loadBuffersFromModelWithPrefix(comptime Model: type, model: Model, buffer_store: BufferStore, allocator: std.mem.Allocator, prefix: []const u8, platform: zml.Platform) !zml.Bufferized(Model) { - - // Allocate the bufferized version. - // We set fields to undefined, cause visitStructAndLoadBuffer is responsible - // to write them just afterward. - var res: zml.Bufferized(Model) = undefined; - try zml.meta.mapAlloc(struct { - pub fn initBuffer(_: void, _: zml.Tensor) zml.Buffer { - return undefined; - } - }.initBuffer, allocator, {}, model, &res); - - var prefix_builder: PrefixBuilder = .{}; - defer prefix_builder.deinit(allocator); - try prefix_builder.push(allocator, prefix); - - try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform); - return res; -} - /// Takes a bufferized version of a `model`, ie a mirror struct of the `model`, and deinit all the /// Buffer found. pub fn unloadBuffers(model: anytype) void { @@ -474,7 +453,12 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi const prefix = prefix_builder.data.items; if (T == zml.Buffer) { return if (buffer_store.get(prefix)) |host_buffer| { - obj.* = try zml.Buffer.from(platform, host_buffer); + // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. + var buf_with_metadata = host_buffer; + log.warn("loading {s} ({})", .{ prefix, obj._shape }); + zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); + buf_with_metadata._shape = obj._shape; + obj.* = try zml.Buffer.from(platform, buf_with_metadata); } else { return error.BufferNotFound; }; @@ -484,10 +468,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi .Pointer => |ptr_info| { if (ptr_info.size == .Slice) { for (obj.*, 0..) |*value, i| { - var buffer: [100]u8 = undefined; - const new_prefix = std.fmt.bufPrint(&buffer, "{d}", .{i}) catch unreachable; - - try prefix_builder.push(allocator, new_prefix); + try prefix_builder.pushDigit(allocator, i); defer prefix_builder.pop(); try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform); @@ -502,13 +483,9 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform); } }, - .Optional => |opt_info| { - var child = @as(opt_info.child, undefined); - if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) { - obj.* = child; - } else |err| switch (err) { - error.BufferNotFound => {}, - else => return err, + .Optional => { + if (obj.*) |*obj_val| { + try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform); } }, else => {}, diff --git a/zml/buffer.zig b/zml/buffer.zig index eddf976..ad1798a 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -3,7 +3,6 @@ const testing = std.testing; const meta = @import("meta.zig"); const pjrt = @import("pjrt"); -const pjrtx = @import("pjrtx.zig"); const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; @@ -13,9 +12,12 @@ const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; test { + std.testing.refAllDecls(@This()); std.testing.refAllDecls(Buffer); } +const log = std.log.scoped(.zml); + /// Buffer is a multi-dimension array, whose memory is allocated on an accelerator. /// /// * contains a handle that the ZML runtime can use to convert into a physical address, but there is no guarantee this address is visible from the CPU. @@ -23,33 +25,70 @@ test { /// * can be created by calling `HostBuffer.toDevice(platform)`. pub const Buffer = struct { _shape: Shape, - _shards: Shape = undefined, - _platform: Platform, - _data: *pjrtx.Buffer, + _api: *const pjrt.Api, + _shards: Shards, + + pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES; + pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS); /// Copies the content of the given buffer from host memory to the accelerator memory. - pub fn from(platform: Platform, buf: HostBuffer) !Buffer { - const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ - .data = buf.data, - .buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()), - .dims = buf.shape().dims(), - .byte_strides = buf.strides(), - .device = platform.getDevices()[0], - .host_buffer_semantics = .ImmutableUntilTransferCompletes, - }); - return .{ - ._platform = platform, - ._shape = buf.shape(), - ._data = pjrt_buffer, + pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer { + var res: Buffer = .{ + ._api = platform.pjrt_api, + ._shape = host_buffer.shape(), + ._shards = .{}, }; + + // We shard only on the first axis so that the chunks are still contiguous. + // TODO: support more advanced sharding specs + meta.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()}); + const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info); + const n_partitions = platform.sharding().num_partitions; + const chunk_size = if (sharding_ax) |ax| cs: { + // This kind of sharding error should be detected earlier on. + meta.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); + break :cs @divExact(host_buffer.dim(ax), n_partitions); + } else 0; + + const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); + const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); + + const devices = platform.getDevices(); + for (0..n_partitions) |i| { + // If no sharding if found, the given buffer is replicated on all devices. + const buf = if (sharding_ax) |ax| buf: { + const start: i64 = @as(i64, @intCast(i)) * chunk_size; + break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); + } else host_buffer; + + const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ + .data = buf.data, + .buffer_type = buffer_type, + .dims = buf.shape().dims(), + .byte_strides = byte_strides, + .device = devices[i], + .host_buffer_semantics = .ImmutableUntilTransferCompletes, + }); + + res._shards.appendAssumeCapacity(pjrt_buffer); + } + return res; } - /// Wraps a pre-exisiting `pjrt.Buffer` into a `zml.Buffer`. - pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer { + /// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`. + pub fn fromPjrtBuffers(platform: Platform, pjrt_buffers: []const *pjrt.Buffer) Buffer { + meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); + meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{}); + var shards: Shards = .{}; + shards.appendSliceAssumeCapacity(pjrt_buffers); return .{ - ._platform = platform, - ._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer), - ._data = pjrt_buffer, + ._api = platform.pjrt_api, + ._shape = Shape.init( + // This isn't with sharded axes. + pjrt_buffers[0].getDimensions(platform.pjrt_api), + dtypeFromBufferType(pjrt_buffers[0].getElementType(platform.pjrt_api)), + ), + ._shards = shards, }; } @@ -112,8 +151,9 @@ pub const Buffer = struct { const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ .data = buf.data, - .element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()), + .element_type = bufferTypeFromDtype(buf.shape().dtype()), .dims = buf.shape().dims(), + // TODO: split in shards .device = platform.getDevices()[0], .layout = .{ .Tiled = .{ @@ -124,10 +164,12 @@ pub const Buffer = struct { }, }); + var shards: Shards = .{}; + shards.appendAssumeCapacity(pjrt_buffer); return .{ - ._platform = platform, + ._api = platform.pjrt_api, ._shape = buf.shape(), - ._data = pjrt_buffer, + ._shards = shards, }; } @@ -135,7 +177,9 @@ pub const Buffer = struct { pub fn getValue(self: Buffer, T: type) !T { meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); var res: T = undefined; - try self._data.toHostBuffer(self._platform.pjrt_api, std.mem.asBytes(&res)); + meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); + const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); + try event.await_(self._api); return res; } @@ -143,7 +187,9 @@ pub const Buffer = struct { /// and return a new `HostBuffer` object with the same shape. /// The returned `HostBuffer` doesn't own the memory. pub fn toHost(self: Buffer, output: []u8) !HostBuffer { - try self._data.toHostBuffer(self._platform.pjrt_api, output); + meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); + const event = try self._shards.get(0).toHostBuffer(self._api, output); + try event.await_(self._api); return HostBuffer.fromBytes(self.shape(), output); } @@ -151,7 +197,9 @@ pub const Buffer = struct { /// The returned `HostBuffer` does own the memory. pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { const output = try HostBuffer.empty(allocator, self.shape()); - try self._data.toHostBuffer(self._platform.pjrt_api, @constCast(output.data)); + meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); + const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); + try event.await_(self._api); return output; } @@ -159,7 +207,9 @@ pub const Buffer = struct { /// Depending on the platform, the memory is typically not released to the OS /// but just marked as available in the memory pool. pub fn deinit(self: *const Buffer) void { - self._data.deinit(self._platform.pjrt_api); + for (self._shards.constSlice()) |buffer| { + buffer.deinit(self._api); + } } /// This Buffer shape. @@ -202,34 +252,76 @@ pub const Buffer = struct { try writer.print("Buffer({_})", .{self._shape}); } - fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape { - const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) { - // Please keep the list exhaustive and in the same order than in DataType. - .PRED => .bool, - .F8E4M3B11FNUZ => .f8e4m3b11fnuz, - .F8E4M3FN => .f8e4m3fn, - .F8E4M3FNUZ => .f8e4m3fnuz, - .F8E5M2 => .f8e5m2, - .F8E5M2FNUZ => .f8e5m2fnuz, - .BF16 => .bf16, - .F16 => .f16, - .F32 => .f32, - .F64 => .f64, - .S4 => .i4, - .S8 => .i8, - .S16 => .i16, - .S32 => .i32, - .S64 => .i64, - .U4 => .u4, - .U8 => .u8, - .U16 => .u16, - .U32 => .u32, - .U64 => .u64, - .C64 => .c64, - .C128 => .c128, - .INVALID => @panic("Can't handle INVALID Pjrt buffers."), - }; - - return Shape.init(buf.getDimensions(platform.pjrt_api), dt); + fn hasShardedAxis(self: Buffer) bool { + if (self._shards.len == 1) return false; + return @reduce(.Or, self._shape._sharding_info); } }; + +pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { + return switch (dt) { + .bool => .PRED, + .f8e4m3b11fnuz => .F8E4M3B11FNUZ, + .f8e4m3fn => .F8E4M3FN, + .f8e4m3fnuz => .F8E4M3FNUZ, + .f8e5m2 => .F8E5M2, + .f8e5m2fnuz => .F8E5M2FNUZ, + .bf16 => .BF16, + .f16 => .F16, + .f32 => .F32, + .f64 => .F64, + .i8 => .S8, + .i4 => .S4, + .i16 => .S16, + .i32 => .S32, + .i64 => .S64, + .u4 => .U4, + .u8 => .U8, + .u16 => .U16, + .u32 => .U32, + .u64 => .U64, + .c64 => .C64, + .c128 => .C128, + }; +} + +pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType { + return switch (pjrt_type) { + .PRED => .bool, + .F8E4M3B11FNUZ => .f8e4m3b11fnuz, + .F8E4M3FN => .f8e4m3fn, + .F8E4M3FNUZ => .f8e4m3fnuz, + .F8E5M2 => .f8e5m2, + .F8E5M2FNUZ => .f8e5m2fnuz, + .BF16 => .bf16, + .F16 => .f16, + .F32 => .f32, + .F64 => .f64, + .S8 => .i8, + .S4 => .i4, + .S16 => .i16, + .S32 => .i32, + .S64 => .i64, + .U4 => .u4, + .U8 => .u8, + .U16 => .u16, + .U32 => .u32, + .U64 => .u64, + .C64 => .c64, + .C128 => .c128, + .INVALID => @panic("Found an invalid pjrt buffer"), + }; +} + +test bufferTypeFromDtype { + inline for (@typeInfo(DataType).Enum.fields) |field| { + const dt: DataType = @enumFromInt(field.value); + try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt))); + } + + inline for (@typeInfo(pjrt.BufferType).Enum.fields) |field| { + const dt: pjrt.BufferType = @enumFromInt(field.value); + if (dt == .INVALID) continue; + try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt))); + } +} diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index f4b92ec..863a7c6 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -147,9 +147,15 @@ pub const HostBuffer = struct { return try Buffer.from(platform_, self); } + /// Interpret the underlying data as a contiguous slice. + /// WARNING: It's only valid if the buffer is contiguous. + /// Strided buffers can't use this method. pub fn items(self: HostBuffer, comptime T: type) []const T { if (DataType.fromZigType(T) != self.dtype()) { - std.debug.panic("Can't reinterpret HostBuffer({_}) as {s}", .{ self.shape(), @typeName(T) }); + std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) }); + } + if (!self.isContiguous()) { + std.debug.panic("{} isn't contiguous", .{self}); } const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr)); return ptr[0..self._shape.count()]; @@ -180,16 +186,65 @@ pub const HostBuffer = struct { return self._shape.count(); } - pub fn dim(self: HostBuffer, axis: anytype) i64 { - return self._shape.dim(axis); + pub fn dim(self: HostBuffer, axis_: anytype) i64 { + return self._shape.dim(axis_); + } + + pub fn axis(self: HostBuffer, axis_: anytype) u3 { + return self._shape.axis(axis_); + } + + pub fn isContiguous(self: HostBuffer) bool { + const strd = self._strides orelse return true; + const cont_strides = self._shape.computeStrides(); + return std.mem.eql(i64, strd[0..self.rank()], cont_strides.constSlice()); } pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer { - meta.assert(self._strides == null, "reshape expects a contiguous tensor, got: {}", .{self}); + meta.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); var res = self; res._shape = self._shape.reshape(shape_); return res; } + + pub const Slice = struct { + start: i64 = 0, + end: ?i64 = null, + }; + + /// Slices the input Tensor over the given axis_ using the given parameters. + pub fn slice1d(self: HostBuffer, axis_: anytype, s: Slice) HostBuffer { + const ax = self._shape.axis(axis_); + const d = self.dim(ax); + const start: i64 = if (s.start < 0) s.start + d else s.start; + var end = s.end orelse d; + if (end < 0) end += d; + meta.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start }); + meta.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end }); + meta.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end }); + + // If strides weren't set it means original buffer is contiguous. + // But it won't be anymore after slicing. The strides don't change though. + const _strides = self._strides orelse self._shape.computeStrides().buffer; + const offset: usize = @intCast(start * _strides[ax]); + return .{ + ._shape = self.shape().set(ax, end - start), + .data = self.data[offset..], + ._strides = _strides, + ._memory = .unmanaged, + }; + } + + pub fn format( + self: HostBuffer, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + _ = fmt; + _ = options; + try writer.print("HostBuffer(.{_})", .{self._shape}); + } }; fn parseArrayInfo(T: type) Shape { diff --git a/zml/meta.zig b/zml/meta.zig index 4f35c1a..532afc7 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -663,3 +663,36 @@ test zip { const a_sum: A = try zip(Sum.call, testing.allocator, &[_]A{ a0, a1 }, .{}); try testing.expectEqual(A{ .a = 5, .b = .{ 7, 9 } }, a_sum); } + +/// Given a func(X) -> Y or a func(Ctx, X) -> Y, +/// finds all X in the given object, and write the result of func(X) into an arraylist. +pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(FnResult(func)), obj: anytype) error{OutOfMemory}!void { + assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); + const LocalContext = struct { + func_ctx: _CollectCtx(func), + out: *std.ArrayList(FnResult(func)), + oom: bool = false, + }; + var context = LocalContext{ .func_ctx = func_ctx, .out = out }; + visit((struct { + fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void { + if (ctx.oom) return; + const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*); + ctx.out.append(res) catch { + ctx.oom = true; + }; + } + }).cb, &context, obj); + if (context.oom) return error.OutOfMemory; +} + +fn _CollectCtx(func: anytype) type { + const params = @typeInfo(@TypeOf(func)).Fn.params; + if (params.len == 1) return void; + return params[0].type orelse @compileError("anytype not supported in collect"); +} + +fn _CollectArg(func: anytype) type { + const params = @typeInfo(@TypeOf(func)).Fn.params; + return params[params.len - 1].type orelse @compileError("anytype not supported in collect"); +} diff --git a/zml/module.zig b/zml/module.zig index 00504fc..e8927e3 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -49,9 +49,11 @@ pub const CompilationContext = struct { _unique_id: u64 = 10000, _tracer: Tracer, - const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation }); threadlocal var _current: ?*CompilationContext = null; + const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation }); + const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3); + pub fn init(allocator: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext { const mlir_registry = mlir.Registry.init() catch unreachable; inline for (.{ "func", "stablehlo" }) |d| { @@ -181,7 +183,7 @@ pub const CompilationContext = struct { comptime func: anytype, model: *const ModuleSignature(func).ModelT, args: *const ModuleSignature(func).ArgsT, - opts: struct { add_donations_attributes: bool = false }, + opts: struct { add_donations_attributes: bool = false, sharding: bool = true }, ) error{OutOfMemory}!MlirFn { const frame = self._tracer.frameStart("generateBytecode.emit"); errdefer self._tracer.frameEnd(frame, "generateBytecode.emit"); @@ -197,17 +199,25 @@ pub const CompilationContext = struct { const tensor_count = model_tensor_count + args_tensor_count; - const loc = self.mlirCtx().location(@src()); + const mlir_ctx = self.mlirCtx(); + const loc = mlir_ctx.location(@src()); const locations = try arena.alloc(mlir.Location, tensor_count); - for (locations) |*l| l.* = mlir.Location.unknown(self.mlirCtx()); - var input_types = try arena.alloc(mlir.Type, tensor_count); - fillMlirTypes(model, self.mlirCtx(), input_types[0..model_tensor_count]); - fillMlirTypes(args, self.mlirCtx(), input_types[model_tensor_count..]); + @memset(locations, mlir.Location.unknown(mlir_ctx)); + + var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count); + meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable; + meta.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{}); + meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; + meta.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); + + const input_types = try arena.alloc(mlir.Type, tensor_count); + for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh); // Note: this isn't stricly necessary. We call `countTensor` on `fn_res`. // But it forces user to have simpler function. const out_tensor_count = comptime ops.staticCountTensors(ModuleSignature(func).ReturnT) orelse @compileError("Can't use " ++ @typeName(ModuleSignature(func).ReturnT) ++ " in an MLIR function, because it has a variable number of tensors"); + // Those are returned to caller so we don't put them in the arena. const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count); const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count); const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count); @@ -234,20 +244,25 @@ pub const CompilationContext = struct { var fn_res_values: [out_tensor_count]mlir.Value = undefined; self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); - const fn_ret = dialect.func.return_(self.mlirCtx(), &fn_res_values, loc); + const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); fn_body.addOperationsRecursive(fn_ret); } + const arg_attrs = try arena.alloc(AttributeList, tensor_count); + @memset(arg_attrs, .{}); + // Donations attributes only make sense on the main function. - const attrs: []const mlir.Attribute = if (opts.add_donations_attributes) - try self.addDonationsAttribute(arena, fn_res_donations, tensor_count) - else - &.{}; + if (opts.add_donations_attributes) { + self.addDonationsAttributes(arg_attrs, fn_res_donations); + } + if (opts.sharding) { + self.addShardingAttributes(arg_attrs, input_shapes.items); + } const mlir_fn = dialect.func.func(self.mlirCtx(), .{ .sym_name = fn_name, - .args = input_types[0..], - .arg_attrs = attrs, + .args = input_types, + .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .results = fn_res_types, .block = fn_body, .location = loc, @@ -277,12 +292,7 @@ pub const CompilationContext = struct { /// Given a list of donations mapping output buffers to input buffers, /// generate donation attribute for each `n_args` input argument. - fn addDonationsAttribute(self: *const CompilationContext, allocator: std.mem.Allocator, donations: []const Tensor._Donation, n_args: usize) ![]mlir.Attribute { - const empty = mlir.DictionaryAttribute.init(self.mlirCtx(), &.{}).as(mlir.Attribute).?; - - const arg_attrs = try allocator.alloc(mlir.Attribute, n_args); - @memset(arg_attrs, empty); - + fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void { var n_donations: usize = 0; for (donations, 0..) |donation, index| { switch (donation) { @@ -293,23 +303,23 @@ pub const CompilationContext = struct { .input_buffer => {}, .arg => |a| { n_donations += 1; - meta.assert(arg_attrs[a].eql(empty), "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, arg_attrs[a] }); - arg_attrs[a] = mlir.DictionaryAttribute.init(self.mlirCtx(), &.{ + // This will break the day we writer another attribute before donation. + // When the time come, do a more fancy lookup here to check if an argument + // is donated twice. + meta.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] }); + attributes[a].appendAssumeCapacity( mlir.NamedAttribute.init( mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"), mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?, ), - }).as(mlir.Attribute).?; - // log.debug("attribute: {}", .{arg_attrs[a]}); + ); + // log.debug("attribute: {}", .{attributes[a].constSlice()}); }, } } - - if (n_donations == 0) return &.{}; - return arg_attrs; } - test addDonationsAttribute { + test addDonationsAttributes { const zml = @import("zml.zig"); const platform = zml.testing.env(); var arena = std.heap.ArenaAllocator.init(std.testing.allocator); @@ -343,12 +353,81 @@ pub const CompilationContext = struct { // `%arg0` is the bias of the model, `%arg1` is `x`. try std.testing.expectEqual(1, f.n_model); try std.testing.expectEqual(1, f.n_args); - std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "%arg1: tensor<8xf16> {tf.aliasing_output = 0 : i32}") != null) catch |err| { + std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| { log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items}); return err; }; } + fn addShardingAttributes(self: CompilationContext, attributes: []AttributeList, shapes: []const Shape) void { + const mlir_ctx = self.mlirCtx(); + if (!self._platform.compilation_options.sharding_enabled) return; + + const num_partitions = self._platform.sharding().num_partitions; + var sharding_str: std.BoundedArray(u8, 128) = .{}; + + const mhlo_default_layout = mlir.NamedAttribute.init( + mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"), + mlir.StringAttribute.init(mlir_ctx, "default").asAttr(), + ); + for (attributes, shapes) |*attr, shape| { + attr.appendAssumeCapacity(mhlo_default_layout); + + writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; + defer sharding_str.len = 0; + attr.appendAssumeCapacity(mlir.NamedAttribute.init( + mlir.Identifier.get(mlir_ctx, "mhlo.sharding"), + mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()).asAttr(), + )); + } + } + + fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void { + const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info))); + if (n_sharded == 0 or num_partitions == 1) { + try writer.writeAll("{replicated}"); + return; + } + try writer.writeAll("{devices=["); + for (0..shape.rank()) |i| { + try writer.print("{d}", .{if (shape._sharding_info[i]) num_partitions else 1}); + if (i < shape.rank() - 1) try writer.writeByte(','); + } + try writer.print("]<=[{d}]}}", .{num_partitions}); + } + + test writeShardingRepresentation { + var rule: [64]u8 = undefined; + const x = Shape.init(.{ 16, 8 }, .f32); + + // By default tensors are replicated. + { + var fbs = std.io.fixedBufferStream(&rule); + try writeShardingRepresentation(x, 4, fbs.writer()); + try std.testing.expectEqualStrings("{replicated}", fbs.getWritten()); + } + // Shard along first axis. + { + var fbs = std.io.fixedBufferStream(&rule); + try writeShardingRepresentation(x.withSharding(.{0}), 4, fbs.writer()); + try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", fbs.getWritten()); + } + // Also shard along second axis. + { + var fbs = std.io.fixedBufferStream(&rule); + try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, fbs.writer()); + try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", fbs.getWritten()); + } + } + + fn finalizeAttributeList(allocator: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute { + const res = try allocator.alloc(mlir.Attribute, attributes.len); + for (res, attributes) |*r, attr| { + r.* = mlir.DictionaryAttribute.init(mlir_ctx, attr.constSlice()).asAttr(); + } + return res; + } + /// Generates an MLIR `func.call` of the given function. /// If the function has not been seen yet, we generate MLIR for it, /// in a independent function. @@ -565,47 +644,57 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize { } /// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor. -fn fillBuffers(v: anytype, buffers: []*pjrt.Buffer) void { +fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u32) void { const LocalContext = struct { - index: usize, - buffers: []*pjrt.Buffer, + index: u32, + buffers: []const [*]*pjrt.Buffer, }; - var ctx: LocalContext = .{ - .index = 0, + var context: LocalContext = .{ + .index = start, .buffers = buffers, }; meta.visit((struct { - fn cb(inner_context: *LocalContext, buffer: *const Buffer) void { - // meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, inner_context.index }); - inner_context.buffers[inner_context.index] = buffer._data; - inner_context.index += 1; + fn cb(ctx: *LocalContext, buffer: *const Buffer) void { + // meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); + const model_sharding = ctx.buffers.len; + meta.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); + for (buffer._shards.constSlice(), 0..) |shard, d| { + ctx.buffers[d][ctx.index] = shard; + } + ctx.index += 1; } - }).cb, &ctx, v); - assert(ctx.index == buffers.len); + }).cb, &context, v); + assert(context.index == start + len); } /// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers. -pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []*pjrt.Buffer) void { +pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32) void { const LocalContext = struct { - index: usize, + index: u32, platform: Platform, - buffers: []*pjrt.Buffer, + buffers: []const [*]*pjrt.Buffer, + expected_count: u32, }; - var ctx: LocalContext = .{ + var local_ctx: LocalContext = .{ .index = 0, .platform = platform, .buffers = buffers, + .expected_count = expected_count, }; meta.visit((struct { - fn cb(inner_context: *LocalContext, buffer: *Buffer) void { - const i = inner_context.index; - if (i < inner_context.buffers.len) { - buffer.* = Buffer.fromPjrtBuffer(inner_context.platform, inner_context.buffers[i]); + fn cb(ctx: *LocalContext, buffer: *Buffer) void { + const i = ctx.index; + ctx.index += 1; + if (i >= ctx.expected_count) return; + + var shards: Buffer.Shards = .{}; + for (ctx.buffers) |buff| { + shards.appendAssumeCapacity(buff[i]); } - inner_context.index += 1; + buffer.* = Buffer.fromPjrtBuffers(ctx.platform, shards.constSlice()); } - }).cb, &ctx, v); - meta.assert(ctx.index == buffers.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), ctx.index }); + }).cb, &local_ctx, v); + meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); } /// Visit the given struct and assign op results to each tensor found. @@ -637,11 +726,13 @@ const BaseExe = struct { /// The PJRT executable representing the compiled module. exe: *pjrt.LoadedExecutable, /// Number of buffers in the model. - model_buffer_count: usize, + model_buffer_count: u32, /// Number of buffers in the arguments. - args_buffer_count: usize, + args_buffer_count: u32, /// Number of buffers in result. - result_buffer_count: usize, + result_buffer_count: u32, + /// Num devices used (>1 for sharded executable) + num_devices: u8, }; /// Represents a ZML model, compiled into a PJRT executable. @@ -674,34 +765,56 @@ pub fn ExeWithWeights(comptime func: anytype) type { /// The raw untyped compiled module. inner: BaseExe, - /// The allocator used for bookkeeping. - allocator: std.mem.Allocator, - /// Pre-allocated slice of buffers to use as inputs when the module is called. - input_buffers: []*pjrt.Buffer, + input_per_device: []const [*]*pjrt.Buffer, /// Pre-allocated slice of buffers to use as outputs when the module is called. - output_buffers: []*pjrt.Buffer, + output_per_device: []const [*]*pjrt.Buffer, + + /// Internal memory slice used. + _all_buffers: []*pjrt.Buffer, + _all_per_device: [][*]*pjrt.Buffer, + + /// And the allocator backing _data_buffer. + _allocator: std.mem.Allocator, pub fn initFromModel(allocator: std.mem.Allocator, inner: BaseExe, model: Bufferized(Signature.ModelT)) !Self { - const input_buffers = try allocator.alloc(*pjrt.Buffer, inner.model_buffer_count + inner.args_buffer_count); - errdefer allocator.free(input_buffers); - fillBuffers(&model, input_buffers[0..inner.model_buffer_count]); + const n_input_buffers = inner.model_buffer_count + inner.args_buffer_count; + const n_output_buffers = inner.result_buffer_count; + const n_devices = inner.num_devices; - const output_buffers = try allocator.alloc(*pjrt.Buffer, inner.result_buffer_count); - errdefer allocator.free(output_buffers); + // Allocate once for all the *pjrt.Buffer we need to store ... + const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_input_buffers + n_output_buffers) * n_devices); + errdefer allocator.free(all_buffers); + const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_input_buffers * n_devices, n_output_buffers * n_devices }); + + // ... and once for all the [*]*pjrt.Buffer. + const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices); + errdefer allocator.free(all_per_device); + const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices }); + + for (0..n_devices) |i| { + input_per_device[i] = all_input_buffers[i * n_input_buffers ..].ptr; + output_per_device[i] = all_output_buffers[i * n_output_buffers ..].ptr; + } + + fillBuffers(&model, input_per_device, 0, inner.model_buffer_count); + // Note: all_output_buffers is left undefined, it will be written to in `call`. return .{ .inner = inner, - .allocator = allocator, - .input_buffers = input_buffers, - .output_buffers = output_buffers, + .input_per_device = input_per_device, + .output_per_device = output_per_device, + ._all_buffers = all_buffers, + ._all_per_device = all_per_device, + ._allocator = allocator, }; } pub fn deinit(self: Self) void { - self.allocator.free(self.input_buffers); - self.allocator.free(self.output_buffers); + // Free in reverse order of allocation. + self._allocator.free(self._all_per_device); + self._allocator.free(self._all_buffers); } pub fn platform(self: Self) Platform { @@ -709,13 +822,13 @@ pub fn ExeWithWeights(comptime func: anytype) type { } pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) { - fillBuffers(&args, self.input_buffers[self.inner.model_buffer_count..][0..self.inner.args_buffer_count]); + fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count); var event: [1]*pjrt.Event = undefined; self.inner.exe.execute(self.inner.platform.pjrt_api, .{ - .arguments = &.{self.input_buffers.ptr}, - .num_args = self.input_buffers.len, - .results = &.{self.output_buffers.ptr}, + .arguments = self.input_per_device, + .num_args = self.inner.args_buffer_count + self.inner.model_buffer_count, + .results = self.output_per_device, .events = &event, // TODO: this allows to tell a specific buffer shouldn't be donated, // even if it has been marked as "can be donated" during compilation. @@ -723,7 +836,7 @@ pub fn ExeWithWeights(comptime func: anytype) type { }) catch unreachable; var result: Bufferized(Signature.ReturnT) = undefined; - assignRawBuffers(&result, self.inner.platform, self.output_buffers); + assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_count); return result; } }; @@ -750,6 +863,11 @@ fn compileInternal( const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } }); context._module.getBody().appendOperation(f.mlir_fn); + const sharding = context._platform.sharding(); + const mlir_ctx = context._mlir_ctx; + context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); + context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); + const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| { log.err( "pjrt-{s} failed to compile following valid MLIR:\n{}\n{}", @@ -771,7 +889,8 @@ fn compileInternal( .exe = loaded_executable, .model_buffer_count = f.n_model, .args_buffer_count = f.n_args, - .result_buffer_count = f.res_types.len, + .result_buffer_count = @intCast(f.res_types.len), + .num_devices = sharding.num_replicas * sharding.num_partitions, }; } @@ -932,10 +1051,12 @@ fn loadOrCompilePjrtExecutable( } fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, module_hash: u64) !*pjrt.LoadedExecutable { + const sharding = platform.sharding(); var options: xla_pb.CompileOptionsProto = .{ .executable_build_options = .{ - .num_replicas = 1, - .num_partitions = 1, + .num_replicas = sharding.num_replicas, + .num_partitions = sharding.num_partitions, + .use_spmd_partitioning = sharding.num_partitions > 1 or sharding.num_replicas > 1, }, }; // Let the arena deinit, zig-protobuf deinit is very slow. @@ -1372,3 +1493,14 @@ fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void { hash(hasher, element, strat); } } + +fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T { + var res: [lengths.len][]T = undefined; + var i: usize = 0; + inline for (&res, lengths) |*r, len| { + r.* = buffer[i .. i + len]; + i += len; + } + std.debug.assert(i == buffer.len); + return res; +} diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 96457e7..2f04c22 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -15,6 +15,7 @@ const Target = @import("platform.zig").Target; const log = std.log.scoped(.zml); +pub const Buffer = pjrt.Buffer; pub const Device = pjrt.Device; pub const DeviceDescription = pjrt.DeviceDescription; pub const Api = pjrt.Api; @@ -27,6 +28,12 @@ pub const SerializeResult = pjrt.SerializeResult; pub const Executable = pjrt.Executable; pub const ExecuteError = ApiError; +test { + std.testing.refAllDecls(Client); + std.testing.refAllDecls(Event); + std.testing.refAllDecls(LoadedExecutable); +} + fn InnerMixin(comptime innerT: type) type { return struct { inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT { @@ -63,7 +70,7 @@ pub const Client = opaque { const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args); const event: *Event = @ptrCast(event_); try event.await_(api); - return @ptrCast(buffer); + return buffer; } pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { @@ -169,75 +176,6 @@ pub const Client = opaque { } }; -pub const Buffer = opaque { - const inner = InnerMixin(pjrt.Buffer).inner; - - pub const BufferType = pjrt.BufferType; - - pub fn BufferTypeFromDType(dt: dtype.DataType) BufferType { - return switch (dt) { - .bool => .PRED, - .f8e4m3b11fnuz => .F8E4M3B11FNUZ, - .f8e4m3fn => .F8E4M3FN, - .f8e4m3fnuz => .F8E4M3FNUZ, - .f8e5m2 => .F8E5M2, - .f8e5m2fnuz => .F8E5M2FNUZ, - .bf16 => .BF16, - .f16 => .F16, - .f32 => .F32, - .f64 => .F64, - .i8 => .S8, - .i4 => .S4, - .i16 => .S16, - .i32 => .S32, - .i64 => .S64, - .u4 => .U4, - .u8 => .U8, - .u16 => .U16, - .u32 => .U32, - .u64 => .U64, - .c64 => .C64, - .c128 => .C128, - }; - } - - pub const HostBufferSemantics = pjrt.HostBufferSemantics; - pub const MemoryLayoutType = pjrt.MemoryLayoutType; - - pub fn deinit(self: *Buffer, api: *const Api) void { - self.inner().deinit(api); - } - - pub fn delete(self: *Buffer, api: *const Api) void { - self.inner().delete(api); - } - - pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!void { - var event = try self.inner().toHostBuffer(api, dst); - try event.await_(api); - } - - pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 { - return self.inner().getDimensions(api); - } - - pub fn getElementType(self: *const Buffer, api: *const Api) BufferType { - return self.inner().getElementType(api); - } - - pub fn isDeleted(self: *const Buffer, api: *const Api) bool { - return self.inner().isDeleted(api); - } - - pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device { - return @ptrCast(try self.inner().getDevice(api)); - } - - pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque { - return self.inner().getOpaqueDeviceMemoryDataPointer(api); - } -}; - pub const Event = opaque { pub const inner = InnerMixin(pjrt.Event).inner; @@ -249,7 +187,7 @@ pub const Event = opaque { return self.inner().isReady(api); } - pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error { + pub fn getEventError(self: *const Event, api: *const Api) ?*Error { return self.inner().getEventError(api); } @@ -288,9 +226,9 @@ pub const Event = opaque { pub const LoadedExecutable = opaque { const inner = InnerMixin(pjrt.LoadedExecutable).inner; - pub fn deinit(self: *LoadedExecutable, api: *const Api) void { - self.inner().deinit(api); - } + // pub fn deinit(self: *LoadedExecutable, api: *const Api) void { + // self.inner().deinit(api); + // } pub fn delete(self: *LoadedExecutable, api: *const Api) void { self.inner().delete(api); @@ -300,9 +238,10 @@ pub const LoadedExecutable = opaque { return self.inner().isDeleted(api); } - pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device { - return self.inner().getAddressableDevices(api); - } + // TODO fix me + // pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device { + // return self.inner().getAddressableDevices(api); + // } pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { arguments: []const [*]const *const Buffer, diff --git a/zml/platform.zig b/zml/platform.zig index 60f1ef5..586a2fa 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -6,6 +6,7 @@ const meta = @import("meta.zig"); const module = @import("module.zig"); const pjrt = @import("pjrtx.zig"); const pjrt_core = @import("pjrt"); +const log = std.log.scoped(.zml); pub const Target = enum { cpu, @@ -31,6 +32,8 @@ pub const CompilationOptions = struct { xla_dump_to: ?[]const u8 = null, xla_dump_fusion_visualization: bool = false, cache_location: ?[]const u8 = null, + sharding_enabled: bool = false, + sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{}, }; pub const Platform = struct { @@ -39,8 +42,14 @@ pub const Platform = struct { pjrt_client: *pjrt.Client, compilation_options: CompilationOptions = .{}, + pub const MAX_NUM_DEVICES: u8 = 8; + pub fn init(target: Target, api: *const pjrt.Api) !Platform { const pjrt_client = try pjrt.Client.init(api, &.{}); + const true_num_devices = pjrt_client.getAddressableDevices(api).len; + if (true_num_devices > MAX_NUM_DEVICES) { + log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES }); + } return .{ .target = target, .pjrt_api = api, @@ -50,7 +59,26 @@ pub const Platform = struct { } pub fn getDevices(self: Platform) []const *const pjrt_core.Device { - return self.pjrt_client.getAddressableDevices(self.pjrt_api); + const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api); + if (all_devices.len > MAX_NUM_DEVICES) { + return all_devices[0..MAX_NUM_DEVICES]; + } + return all_devices; + } + + pub const Sharding = struct { num_replicas: u8, num_partitions: u8 }; + + pub fn sharding(self: Platform) Sharding { + // replicas run the same function but with different inputs, + // while partitions contribute to one evaluation over a shared input. + // Inside an inference process, we generally don't want replicas, + // as it's best to fully isolate replicas on different processes. + // For now we hardcode num_replicas = 1. + const num_devices: u8 = @intCast(self.getDevices().len); + return if (self.compilation_options.sharding_enabled) + .{ .num_replicas = 1, .num_partitions = num_devices } + else + .{ .num_replicas = 1, .num_partitions = 1 }; } pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform { diff --git a/zml/shape.zig b/zml/shape.zig index 2cfc292..2ed4016 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -23,12 +23,14 @@ pub const Shape = struct { pub const DimsArray = std.BoundedArray(i64, MAX_RANK); pub const TagsArray = std.BoundedArray(Tag, MAX_RANK); pub const AxesArray = std.BoundedArray(u3, MAX_RANK); + pub const ShardingInfo = @Vector(MAX_RANK, bool); const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK }; _dtype: DataType, _dims: DimsArray = .{}, _tags: TagsArray = UnknownTags, + _sharding_info: ShardingInfo = @splat(false), pub fn parseDimensions(v: anytype) struct { DimsArray, TagsArray } { const T = @TypeOf(v); @@ -69,7 +71,7 @@ pub const Shape = struct { return .{ dims_, tags_ }; } - meta.compileError("Wrong type, got {}", .{T}); + meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T}); } test parseDimensions { @@ -286,7 +288,7 @@ pub const Shape = struct { return res; } - meta.compileError("Wrong type, got {}", .{T}); + meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T}); } fn axisFromInt(self: Shape, d: isize) u3 { @@ -384,6 +386,9 @@ pub const Shape = struct { } else { try writer.print("{s}{d}", .{ prefix, d }); } + if (self._sharding_info[i]) { + try writer.writeByte('!'); + } } _ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())}); if (!bare_fmt) _ = try writer.write(")"); @@ -664,6 +669,16 @@ pub const Shape = struct { return res; } + pub fn withSharding(self: Shape, axes_: anytype) Shape { + var res = self; + // Reset sharding. + res._sharding_info = @splat(false); + for (self.axes(axes_).constSlice()) |ax| { + res._sharding_info[ax] = true; + } + return res; + } + /// Renames some of the tags in this shape. /// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 }; pub fn rename(self: Shape, renames: anytype) Shape { @@ -690,7 +705,8 @@ pub const Shape = struct { } } - pub fn computeStrides(self: Shape, base_stride: u32) std.BoundedArray(i64, MAX_RANK) { + pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) { + const base_stride = self.dtype().sizeOf(); const rk = self.rank(); var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) }; if (rk == 0) return strides; diff --git a/zml/tensor.zig b/zml/tensor.zig index 4cde57e..dc0012e 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -159,6 +159,12 @@ pub const Tensor = struct { return res; } + pub fn withSharding(self: Tensor, axes_: anytype) Tensor { + var res = self; + res._shape = self._shape.withSharding(axes_); + return res; + } + /// Returns a Tensor with new tag names. pub fn rename(self: Tensor, renames: anytype) Tensor { var res = self; @@ -196,11 +202,6 @@ pub const Tensor = struct { return res; } - /// Returns a slice containing the strides for a Tensor. - pub inline fn computeStrides(self: Tensor) []const i64 { - return self._shape.computeStrides(self.dtype().sizeOf()).constSlice(); - } - var _global_tensor_counter: u64 = 0; /// Internal use