Add in-process sharding support across core ZML components (platform, shape, tensor, MLIR generation, buffers, and PJRT integration)
This commit is contained in:
parent
cad1a688da
commit
2f129f76c9
@ -47,7 +47,7 @@ pub fn MlirTypeMethods(comptime InnerT: type) type {
|
|||||||
/// Alternative to MlirWrapperType
|
/// Alternative to MlirWrapperType
|
||||||
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
|
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
|
||||||
|
|
||||||
fn MlirHelpersMethods(comptime OuterT: type) type {
|
fn MlirHelpersMethods(OuterT: type) type {
|
||||||
switch (@typeInfo(OuterT)) {
|
switch (@typeInfo(OuterT)) {
|
||||||
.Struct => |info| {
|
.Struct => |info| {
|
||||||
if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT));
|
if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT));
|
||||||
@ -382,6 +382,10 @@ pub const StringAttribute = struct {
|
|||||||
pub fn value(self: Self) []const u8 {
|
pub fn value(self: Self) []const u8 {
|
||||||
return fromStringRef(c.mlirStringAttrGetValue(self.inner()));
|
return fromStringRef(c.mlirStringAttrGetValue(self.inner()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asAttr(self: StringAttribute) Attribute {
|
||||||
|
return .{ ._inner = self._inner };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const UnitAttribute = struct {
|
pub const UnitAttribute = struct {
|
||||||
@ -493,6 +497,10 @@ pub fn IntegerAttribute(comptime it: IntegerTypes) type {
|
|||||||
pub fn get(value: IntAttr) ZigType {
|
pub fn get(value: IntAttr) ZigType {
|
||||||
return @intCast(getter(value.inner()));
|
return @intCast(getter(value.inner()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asAttr(self: IntAttr) Attribute {
|
||||||
|
return .{ ._inner = self._inner };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -731,23 +739,29 @@ pub const DictionaryAttribute = struct {
|
|||||||
.equal_fn = c.mlirAttributeEqual,
|
.equal_fn = c.mlirAttributeEqual,
|
||||||
});
|
});
|
||||||
|
|
||||||
const Self = DictionaryAttribute;
|
pub fn init(ctx: Context, attributes: []const NamedAttribute) DictionaryAttribute {
|
||||||
|
return DictionaryAttribute.wrap(c.mlirDictionaryAttrGet(
|
||||||
pub fn init(ctx: Context, attributes: []const NamedAttribute) Self {
|
ctx.inner(),
|
||||||
return Self.wrap(c.mlirDictionaryAttrGet(ctx.inner(), @intCast(attributes.len), @ptrCast(attributes.ptr)));
|
@intCast(attributes.len),
|
||||||
|
@ptrCast(attributes.ptr),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn size(self: Self) usize {
|
pub fn size(self: DictionaryAttribute) usize {
|
||||||
return @intCast(c.mlirDictionaryAttrGetNumElements(self.inner()));
|
return @intCast(c.mlirDictionaryAttrGetNumElements(self.inner()));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(self: Self, pos: usize) NamedAttribute {
|
pub fn get(self: DictionaryAttribute, pos: usize) NamedAttribute {
|
||||||
return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos)));
|
return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos)));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute {
|
pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?NamedAttribute {
|
||||||
return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
|
return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asAttr(self: DictionaryAttribute) Attribute {
|
||||||
|
return .{ ._inner = self._inner };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Operation = struct {
|
pub const Operation = struct {
|
||||||
|
|||||||
@ -727,11 +727,11 @@ pub const Event = opaque {
|
|||||||
return ret.is_ready;
|
return ret.is_ready;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error {
|
pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
|
||||||
const ret = try api.call(.PJRT_Event_Error, .{
|
var args: Api.CallFnArgType(.PJRT_Event_Error) = .{ .event = self.inner() };
|
||||||
.event = self.inner(),
|
args = pjrtStruct(args);
|
||||||
});
|
const result: ?*c.PJRT_Error = api.inner.PJRT_Event_Error.?(&args);
|
||||||
return @ptrCast(ret);
|
return @ptrCast(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn await_(self: *const Event, api: *const Api) ApiError!void {
|
pub fn await_(self: *const Event, api: *const Api) ApiError!void {
|
||||||
|
|||||||
61
zml/aio.zig
61
zml/aio.zig
@ -398,7 +398,12 @@ pub fn loadModelBuffers(
|
|||||||
) !zml.Bufferized(Model) {
|
) !zml.Bufferized(Model) {
|
||||||
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
||||||
}
|
}
|
||||||
|
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
|
||||||
|
/// For details about bufferization, see the documentation of Bufferized(T).
|
||||||
|
///
|
||||||
|
/// This will represent the weights of the model, loaded on a specific platform.
|
||||||
|
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
|
||||||
|
/// `module.ExeWithWeights` ready to be called.
|
||||||
pub fn loadModelBuffersWithPrefix(
|
pub fn loadModelBuffersWithPrefix(
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
model: Model,
|
model: Model,
|
||||||
@ -408,12 +413,12 @@ pub fn loadModelBuffersWithPrefix(
|
|||||||
prefix: []const u8,
|
prefix: []const u8,
|
||||||
) !zml.Bufferized(Model) {
|
) !zml.Bufferized(Model) {
|
||||||
// Allocate the bufferized version.
|
// Allocate the bufferized version.
|
||||||
// We set fields to undefined, cause visitStructAndLoadBuffer is responsible
|
// We copy the shape and leave the other fields undefined: visitStructAndLoadBuffer is responsible
|
||||||
// to write them just afterward.
|
// to write them just afterward.
|
||||||
var res: zml.Bufferized(Model) = undefined;
|
var res: zml.Bufferized(Model) = undefined;
|
||||||
try zml.meta.mapAlloc(struct {
|
try zml.meta.mapAlloc(struct {
|
||||||
pub fn initBuffer(_: void, _: zml.Tensor) zml.Buffer {
|
pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
|
||||||
return undefined;
|
return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined };
|
||||||
}
|
}
|
||||||
}.initBuffer, allocator, {}, model, &res);
|
}.initBuffer, allocator, {}, model, &res);
|
||||||
|
|
||||||
@ -425,32 +430,6 @@ pub fn loadModelBuffersWithPrefix(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
|
|
||||||
/// For details about bufferization, see the documentation of Bufferized(T).
|
|
||||||
///
|
|
||||||
/// This will represent the weights of the model, loaded on a specific platform.
|
|
||||||
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
|
|
||||||
/// `module.ExeWithWeights` ready to be called.
|
|
||||||
pub fn loadBuffersFromModelWithPrefix(comptime Model: type, model: Model, buffer_store: BufferStore, allocator: std.mem.Allocator, prefix: []const u8, platform: zml.Platform) !zml.Bufferized(Model) {
|
|
||||||
|
|
||||||
// Allocate the bufferized version.
|
|
||||||
// We set fields to undefined, cause visitStructAndLoadBuffer is responsible
|
|
||||||
// to write them just afterward.
|
|
||||||
var res: zml.Bufferized(Model) = undefined;
|
|
||||||
try zml.meta.mapAlloc(struct {
|
|
||||||
pub fn initBuffer(_: void, _: zml.Tensor) zml.Buffer {
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}.initBuffer, allocator, {}, model, &res);
|
|
||||||
|
|
||||||
var prefix_builder: PrefixBuilder = .{};
|
|
||||||
defer prefix_builder.deinit(allocator);
|
|
||||||
try prefix_builder.push(allocator, prefix);
|
|
||||||
|
|
||||||
try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Takes a bufferized version of a `model`, ie a mirror struct of the `model`, and deinit all the
|
/// Takes a bufferized version of a `model`, ie a mirror struct of the `model`, and deinit all the
|
||||||
/// Buffer found.
|
/// Buffer found.
|
||||||
pub fn unloadBuffers(model: anytype) void {
|
pub fn unloadBuffers(model: anytype) void {
|
||||||
@ -474,7 +453,12 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
const prefix = prefix_builder.data.items;
|
const prefix = prefix_builder.data.items;
|
||||||
if (T == zml.Buffer) {
|
if (T == zml.Buffer) {
|
||||||
return if (buffer_store.get(prefix)) |host_buffer| {
|
return if (buffer_store.get(prefix)) |host_buffer| {
|
||||||
obj.* = try zml.Buffer.from(platform, host_buffer);
|
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
||||||
|
var buf_with_metadata = host_buffer;
|
||||||
|
log.warn("loading {s} ({})", .{ prefix, obj._shape });
|
||||||
|
zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
||||||
|
buf_with_metadata._shape = obj._shape;
|
||||||
|
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
||||||
} else {
|
} else {
|
||||||
return error.BufferNotFound;
|
return error.BufferNotFound;
|
||||||
};
|
};
|
||||||
@ -484,10 +468,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
.Pointer => |ptr_info| {
|
.Pointer => |ptr_info| {
|
||||||
if (ptr_info.size == .Slice) {
|
if (ptr_info.size == .Slice) {
|
||||||
for (obj.*, 0..) |*value, i| {
|
for (obj.*, 0..) |*value, i| {
|
||||||
var buffer: [100]u8 = undefined;
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
const new_prefix = std.fmt.bufPrint(&buffer, "{d}", .{i}) catch unreachable;
|
|
||||||
|
|
||||||
try prefix_builder.push(allocator, new_prefix);
|
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
|
||||||
@ -502,13 +483,9 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Optional => |opt_info| {
|
.Optional => {
|
||||||
var child = @as(opt_info.child, undefined);
|
if (obj.*) |*obj_val| {
|
||||||
if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) {
|
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform);
|
||||||
obj.* = child;
|
|
||||||
} else |err| switch (err) {
|
|
||||||
error.BufferNotFound => {},
|
|
||||||
else => return err,
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
else => {},
|
else => {},
|
||||||
|
|||||||
158
zml/buffer.zig
158
zml/buffer.zig
@ -3,7 +3,6 @@ const testing = std.testing;
|
|||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
const pjrtx = @import("pjrtx.zig");
|
|
||||||
|
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
@ -13,9 +12,12 @@ const Platform = @import("platform.zig").Platform;
|
|||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
|
||||||
test {
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(Buffer);
|
std.testing.refAllDecls(Buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
|
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
|
||||||
///
|
///
|
||||||
/// * contains a handle that the ZML runtime can use to convert into a physical address, but there is no guarantee this address is visible from the CPU.
|
/// * contains a handle that the ZML runtime can use to convert into a physical address, but there is no guarantee this address is visible from the CPU.
|
||||||
@ -23,33 +25,70 @@ test {
|
|||||||
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
||||||
pub const Buffer = struct {
|
pub const Buffer = struct {
|
||||||
_shape: Shape,
|
_shape: Shape,
|
||||||
_shards: Shape = undefined,
|
_api: *const pjrt.Api,
|
||||||
_platform: Platform,
|
_shards: Shards,
|
||||||
_data: *pjrtx.Buffer,
|
|
||||||
|
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
||||||
|
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
||||||
|
|
||||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||||
pub fn from(platform: Platform, buf: HostBuffer) !Buffer {
|
pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer {
|
||||||
|
var res: Buffer = .{
|
||||||
|
._api = platform.pjrt_api,
|
||||||
|
._shape = host_buffer.shape(),
|
||||||
|
._shards = .{},
|
||||||
|
};
|
||||||
|
|
||||||
|
// We shard only on the first axis so that the chunks are still contiguous.
|
||||||
|
// TODO: support more advanced sharding specs
|
||||||
|
meta.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
|
||||||
|
const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info);
|
||||||
|
const n_partitions = platform.sharding().num_partitions;
|
||||||
|
const chunk_size = if (sharding_ax) |ax| cs: {
|
||||||
|
// This kind of sharding error should be detected earlier on.
|
||||||
|
meta.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
|
||||||
|
break :cs @divExact(host_buffer.dim(ax), n_partitions);
|
||||||
|
} else 0;
|
||||||
|
|
||||||
|
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||||
|
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
||||||
|
|
||||||
|
const devices = platform.getDevices();
|
||||||
|
for (0..n_partitions) |i| {
|
||||||
|
// If no sharding is found, the given buffer is replicated on all devices.
|
||||||
|
const buf = if (sharding_ax) |ax| buf: {
|
||||||
|
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||||
|
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||||
|
} else host_buffer;
|
||||||
|
|
||||||
const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
|
const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
|
||||||
.data = buf.data,
|
.data = buf.data,
|
||||||
.buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
|
.buffer_type = buffer_type,
|
||||||
.dims = buf.shape().dims(),
|
.dims = buf.shape().dims(),
|
||||||
.byte_strides = buf.strides(),
|
.byte_strides = byte_strides,
|
||||||
.device = platform.getDevices()[0],
|
.device = devices[i],
|
||||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||||
});
|
});
|
||||||
return .{
|
|
||||||
._platform = platform,
|
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
._shape = buf.shape(),
|
}
|
||||||
._data = pjrt_buffer,
|
return res;
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wraps a pre-existing `pjrt.Buffer` into a `zml.Buffer`.
|
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||||
pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer {
|
pub fn fromPjrtBuffers(platform: Platform, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||||
|
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||||
|
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
|
||||||
|
var shards: Shards = .{};
|
||||||
|
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
||||||
return .{
|
return .{
|
||||||
._platform = platform,
|
._api = platform.pjrt_api,
|
||||||
._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer),
|
._shape = Shape.init(
|
||||||
._data = pjrt_buffer,
|
// NOTE: this shape is read from the first shard only; it presumably doesn't account for sharded axes — TODO confirm.
|
||||||
|
pjrt_buffers[0].getDimensions(platform.pjrt_api),
|
||||||
|
dtypeFromBufferType(pjrt_buffers[0].getElementType(platform.pjrt_api)),
|
||||||
|
),
|
||||||
|
._shards = shards,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,8 +151,9 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
||||||
.data = buf.data,
|
.data = buf.data,
|
||||||
.element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
|
.element_type = bufferTypeFromDtype(buf.shape().dtype()),
|
||||||
.dims = buf.shape().dims(),
|
.dims = buf.shape().dims(),
|
||||||
|
// TODO: split in shards
|
||||||
.device = platform.getDevices()[0],
|
.device = platform.getDevices()[0],
|
||||||
.layout = .{
|
.layout = .{
|
||||||
.Tiled = .{
|
.Tiled = .{
|
||||||
@ -124,10 +164,12 @@ pub const Buffer = struct {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
var shards: Shards = .{};
|
||||||
|
shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
return .{
|
return .{
|
||||||
._platform = platform,
|
._api = platform.pjrt_api,
|
||||||
._shape = buf.shape(),
|
._shape = buf.shape(),
|
||||||
._data = pjrt_buffer,
|
._shards = shards,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,7 +177,9 @@ pub const Buffer = struct {
|
|||||||
pub fn getValue(self: Buffer, T: type) !T {
|
pub fn getValue(self: Buffer, T: type) !T {
|
||||||
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
||||||
var res: T = undefined;
|
var res: T = undefined;
|
||||||
try self._data.toHostBuffer(self._platform.pjrt_api, std.mem.asBytes(&res));
|
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
|
const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
||||||
|
try event.await_(self._api);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,7 +187,9 @@ pub const Buffer = struct {
|
|||||||
/// and return a new `HostBuffer` object with the same shape.
|
/// and return a new `HostBuffer` object with the same shape.
|
||||||
/// The returned `HostBuffer` doesn't own the memory.
|
/// The returned `HostBuffer` doesn't own the memory.
|
||||||
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
|
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
|
||||||
try self._data.toHostBuffer(self._platform.pjrt_api, output);
|
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
|
const event = try self._shards.get(0).toHostBuffer(self._api, output);
|
||||||
|
try event.await_(self._api);
|
||||||
return HostBuffer.fromBytes(self.shape(), output);
|
return HostBuffer.fromBytes(self.shape(), output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,7 +197,9 @@ pub const Buffer = struct {
|
|||||||
/// The returned `HostBuffer` does own the memory.
|
/// The returned `HostBuffer` does own the memory.
|
||||||
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
||||||
const output = try HostBuffer.empty(allocator, self.shape());
|
const output = try HostBuffer.empty(allocator, self.shape());
|
||||||
try self._data.toHostBuffer(self._platform.pjrt_api, @constCast(output.data));
|
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
|
const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
||||||
|
try event.await_(self._api);
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,7 +207,9 @@ pub const Buffer = struct {
|
|||||||
/// Depending on the platform, the memory is typically not released to the OS
|
/// Depending on the platform, the memory is typically not released to the OS
|
||||||
/// but just marked as available in the memory pool.
|
/// but just marked as available in the memory pool.
|
||||||
pub fn deinit(self: *const Buffer) void {
|
pub fn deinit(self: *const Buffer) void {
|
||||||
self._data.deinit(self._platform.pjrt_api);
|
for (self._shards.constSlice()) |buffer| {
|
||||||
|
buffer.deinit(self._api);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This Buffer shape.
|
/// This Buffer shape.
|
||||||
@ -202,9 +252,41 @@ pub const Buffer = struct {
|
|||||||
try writer.print("Buffer({_})", .{self._shape});
|
try writer.print("Buffer({_})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape {
|
fn hasShardedAxis(self: Buffer) bool {
|
||||||
const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) {
|
if (self._shards.len == 1) return false;
|
||||||
// Please keep the list exhaustive and in the same order as in DataType.
|
return @reduce(.Or, self._shape._sharding_info);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||||
|
return switch (dt) {
|
||||||
|
.bool => .PRED,
|
||||||
|
.f8e4m3b11fnuz => .F8E4M3B11FNUZ,
|
||||||
|
.f8e4m3fn => .F8E4M3FN,
|
||||||
|
.f8e4m3fnuz => .F8E4M3FNUZ,
|
||||||
|
.f8e5m2 => .F8E5M2,
|
||||||
|
.f8e5m2fnuz => .F8E5M2FNUZ,
|
||||||
|
.bf16 => .BF16,
|
||||||
|
.f16 => .F16,
|
||||||
|
.f32 => .F32,
|
||||||
|
.f64 => .F64,
|
||||||
|
.i8 => .S8,
|
||||||
|
.i4 => .S4,
|
||||||
|
.i16 => .S16,
|
||||||
|
.i32 => .S32,
|
||||||
|
.i64 => .S64,
|
||||||
|
.u4 => .U4,
|
||||||
|
.u8 => .U8,
|
||||||
|
.u16 => .U16,
|
||||||
|
.u32 => .U32,
|
||||||
|
.u64 => .U64,
|
||||||
|
.c64 => .C64,
|
||||||
|
.c128 => .C128,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
|
||||||
|
return switch (pjrt_type) {
|
||||||
.PRED => .bool,
|
.PRED => .bool,
|
||||||
.F8E4M3B11FNUZ => .f8e4m3b11fnuz,
|
.F8E4M3B11FNUZ => .f8e4m3b11fnuz,
|
||||||
.F8E4M3FN => .f8e4m3fn,
|
.F8E4M3FN => .f8e4m3fn,
|
||||||
@ -215,8 +297,8 @@ pub const Buffer = struct {
|
|||||||
.F16 => .f16,
|
.F16 => .f16,
|
||||||
.F32 => .f32,
|
.F32 => .f32,
|
||||||
.F64 => .f64,
|
.F64 => .f64,
|
||||||
.S4 => .i4,
|
|
||||||
.S8 => .i8,
|
.S8 => .i8,
|
||||||
|
.S4 => .i4,
|
||||||
.S16 => .i16,
|
.S16 => .i16,
|
||||||
.S32 => .i32,
|
.S32 => .i32,
|
||||||
.S64 => .i64,
|
.S64 => .i64,
|
||||||
@ -227,9 +309,19 @@ pub const Buffer = struct {
|
|||||||
.U64 => .u64,
|
.U64 => .u64,
|
||||||
.C64 => .c64,
|
.C64 => .c64,
|
||||||
.C128 => .c128,
|
.C128 => .c128,
|
||||||
.INVALID => @panic("Can't handle INVALID Pjrt buffers."),
|
.INVALID => @panic("Found an invalid pjrt buffer"),
|
||||||
};
|
};
|
||||||
|
|
||||||
return Shape.init(buf.getDimensions(platform.pjrt_api), dt);
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
test bufferTypeFromDtype {
|
||||||
|
inline for (@typeInfo(DataType).Enum.fields) |field| {
|
||||||
|
const dt: DataType = @enumFromInt(field.value);
|
||||||
|
try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline for (@typeInfo(pjrt.BufferType).Enum.fields) |field| {
|
||||||
|
const dt: pjrt.BufferType = @enumFromInt(field.value);
|
||||||
|
if (dt == .INVALID) continue;
|
||||||
|
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -147,9 +147,15 @@ pub const HostBuffer = struct {
|
|||||||
return try Buffer.from(platform_, self);
|
return try Buffer.from(platform_, self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Interpret the underlying data as a contiguous slice.
|
||||||
|
/// WARNING: It's only valid if the buffer is contiguous.
|
||||||
|
/// Strided buffers can't use this method.
|
||||||
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
||||||
if (DataType.fromZigType(T) != self.dtype()) {
|
if (DataType.fromZigType(T) != self.dtype()) {
|
||||||
std.debug.panic("Can't reinterpret HostBuffer({_}) as {s}", .{ self.shape(), @typeName(T) });
|
std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) });
|
||||||
|
}
|
||||||
|
if (!self.isContiguous()) {
|
||||||
|
std.debug.panic("{} isn't contiguous", .{self});
|
||||||
}
|
}
|
||||||
const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr));
|
const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr));
|
||||||
return ptr[0..self._shape.count()];
|
return ptr[0..self._shape.count()];
|
||||||
@ -180,16 +186,65 @@ pub const HostBuffer = struct {
|
|||||||
return self._shape.count();
|
return self._shape.count();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dim(self: HostBuffer, axis: anytype) i64 {
|
pub fn dim(self: HostBuffer, axis_: anytype) i64 {
|
||||||
return self._shape.dim(axis);
|
return self._shape.dim(axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn axis(self: HostBuffer, axis_: anytype) u3 {
|
||||||
|
return self._shape.axis(axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isContiguous(self: HostBuffer) bool {
|
||||||
|
const strd = self._strides orelse return true;
|
||||||
|
const cont_strides = self._shape.computeStrides();
|
||||||
|
return std.mem.eql(i64, strd[0..self.rank()], cont_strides.constSlice());
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
||||||
meta.assert(self._strides == null, "reshape expects a contiguous tensor, got: {}", .{self});
|
meta.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.reshape(shape_);
|
res._shape = self._shape.reshape(shape_);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const Slice = struct {
|
||||||
|
start: i64 = 0,
|
||||||
|
end: ?i64 = null,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Slices the input HostBuffer over the given axis_ using the given parameters.
|
||||||
|
pub fn slice1d(self: HostBuffer, axis_: anytype, s: Slice) HostBuffer {
|
||||||
|
const ax = self._shape.axis(axis_);
|
||||||
|
const d = self.dim(ax);
|
||||||
|
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
||||||
|
var end = s.end orelse d;
|
||||||
|
if (end < 0) end += d;
|
||||||
|
meta.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
|
||||||
|
meta.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
|
||||||
|
meta.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
|
||||||
|
|
||||||
|
// If strides weren't set it means original buffer is contiguous.
|
||||||
|
// But it won't be anymore after slicing. The strides don't change though.
|
||||||
|
const _strides = self._strides orelse self._shape.computeStrides().buffer;
|
||||||
|
const offset: usize = @intCast(start * _strides[ax]);
|
||||||
|
return .{
|
||||||
|
._shape = self.shape().set(ax, end - start),
|
||||||
|
.data = self.data[offset..],
|
||||||
|
._strides = _strides,
|
||||||
|
._memory = .unmanaged,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn format(
|
||||||
|
self: HostBuffer,
|
||||||
|
comptime fmt: []const u8,
|
||||||
|
options: std.fmt.FormatOptions,
|
||||||
|
writer: anytype,
|
||||||
|
) !void {
|
||||||
|
_ = fmt;
|
||||||
|
_ = options;
|
||||||
|
try writer.print("HostBuffer(.{_})", .{self._shape});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
fn parseArrayInfo(T: type) Shape {
|
fn parseArrayInfo(T: type) Shape {
|
||||||
|
|||||||
33
zml/meta.zig
33
zml/meta.zig
@ -663,3 +663,36 @@ test zip {
|
|||||||
const a_sum: A = try zip(Sum.call, testing.allocator, &[_]A{ a0, a1 }, .{});
|
const a_sum: A = try zip(Sum.call, testing.allocator, &[_]A{ a0, a1 }, .{});
|
||||||
try testing.expectEqual(A{ .a = 5, .b = .{ 7, 9 } }, a_sum);
|
try testing.expectEqual(A{ .a = 5, .b = .{ 7, 9 } }, a_sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
|
/// finds all X in the given object, and writes the result of func(X) into an arraylist.
|
||||||
|
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(FnResult(func)), obj: anytype) error{OutOfMemory}!void {
|
||||||
|
assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
||||||
|
const LocalContext = struct {
|
||||||
|
func_ctx: _CollectCtx(func),
|
||||||
|
out: *std.ArrayList(FnResult(func)),
|
||||||
|
oom: bool = false,
|
||||||
|
};
|
||||||
|
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
||||||
|
visit((struct {
|
||||||
|
fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void {
|
||||||
|
if (ctx.oom) return;
|
||||||
|
const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
|
||||||
|
ctx.out.append(res) catch {
|
||||||
|
ctx.oom = true;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}).cb, &context, obj);
|
||||||
|
if (context.oom) return error.OutOfMemory;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _CollectCtx(func: anytype) type {
|
||||||
|
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
||||||
|
if (params.len == 1) return void;
|
||||||
|
return params[0].type orelse @compileError("anytype not supported in collect");
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _CollectArg(func: anytype) type {
|
||||||
|
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
||||||
|
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
||||||
|
}
|
||||||
|
|||||||
286
zml/module.zig
286
zml/module.zig
@ -49,9 +49,11 @@ pub const CompilationContext = struct {
|
|||||||
_unique_id: u64 = 10000,
|
_unique_id: u64 = 10000,
|
||||||
_tracer: Tracer,
|
_tracer: Tracer,
|
||||||
|
|
||||||
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
|
||||||
threadlocal var _current: ?*CompilationContext = null;
|
threadlocal var _current: ?*CompilationContext = null;
|
||||||
|
|
||||||
|
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
||||||
|
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
|
||||||
|
|
||||||
pub fn init(allocator: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
|
pub fn init(allocator: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
|
||||||
const mlir_registry = mlir.Registry.init() catch unreachable;
|
const mlir_registry = mlir.Registry.init() catch unreachable;
|
||||||
inline for (.{ "func", "stablehlo" }) |d| {
|
inline for (.{ "func", "stablehlo" }) |d| {
|
||||||
@ -181,7 +183,7 @@ pub const CompilationContext = struct {
|
|||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
model: *const ModuleSignature(func).ModelT,
|
model: *const ModuleSignature(func).ModelT,
|
||||||
args: *const ModuleSignature(func).ArgsT,
|
args: *const ModuleSignature(func).ArgsT,
|
||||||
opts: struct { add_donations_attributes: bool = false },
|
opts: struct { add_donations_attributes: bool = false, sharding: bool = true },
|
||||||
) error{OutOfMemory}!MlirFn {
|
) error{OutOfMemory}!MlirFn {
|
||||||
const frame = self._tracer.frameStart("generateBytecode.emit");
|
const frame = self._tracer.frameStart("generateBytecode.emit");
|
||||||
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
|
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
|
||||||
@ -197,17 +199,25 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
const tensor_count = model_tensor_count + args_tensor_count;
|
const tensor_count = model_tensor_count + args_tensor_count;
|
||||||
|
|
||||||
const loc = self.mlirCtx().location(@src());
|
const mlir_ctx = self.mlirCtx();
|
||||||
|
const loc = mlir_ctx.location(@src());
|
||||||
|
|
||||||
const locations = try arena.alloc(mlir.Location, tensor_count);
|
const locations = try arena.alloc(mlir.Location, tensor_count);
|
||||||
for (locations) |*l| l.* = mlir.Location.unknown(self.mlirCtx());
|
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
||||||
var input_types = try arena.alloc(mlir.Type, tensor_count);
|
|
||||||
fillMlirTypes(model, self.mlirCtx(), input_types[0..model_tensor_count]);
|
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
|
||||||
fillMlirTypes(args, self.mlirCtx(), input_types[model_tensor_count..]);
|
meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable;
|
||||||
|
meta.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{});
|
||||||
|
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
||||||
|
meta.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
||||||
|
|
||||||
|
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
||||||
|
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
|
||||||
|
|
||||||
// Note: this isn't strictly necessary. We call `countTensor` on `fn_res`.
|
// Note: this isn't strictly necessary. We call `countTensor` on `fn_res`.
|
||||||
// But it forces user to have simpler function.
|
// But it forces user to have simpler function.
|
||||||
const out_tensor_count = comptime ops.staticCountTensors(ModuleSignature(func).ReturnT) orelse @compileError("Can't use " ++ @typeName(ModuleSignature(func).ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
|
const out_tensor_count = comptime ops.staticCountTensors(ModuleSignature(func).ReturnT) orelse @compileError("Can't use " ++ @typeName(ModuleSignature(func).ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
|
||||||
|
// Those are returned to caller so we don't put them in the arena.
|
||||||
const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count);
|
const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count);
|
||||||
const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count);
|
const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count);
|
||||||
const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count);
|
const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count);
|
||||||
@ -234,20 +244,25 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
||||||
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
||||||
const fn_ret = dialect.func.return_(self.mlirCtx(), &fn_res_values, loc);
|
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
||||||
fn_body.addOperationsRecursive(fn_ret);
|
fn_body.addOperationsRecursive(fn_ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
|
||||||
|
@memset(arg_attrs, .{});
|
||||||
|
|
||||||
// Donations attributes only make sense on the main function.
|
// Donations attributes only make sense on the main function.
|
||||||
const attrs: []const mlir.Attribute = if (opts.add_donations_attributes)
|
if (opts.add_donations_attributes) {
|
||||||
try self.addDonationsAttribute(arena, fn_res_donations, tensor_count)
|
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
||||||
else
|
}
|
||||||
&.{};
|
if (opts.sharding) {
|
||||||
|
self.addShardingAttributes(arg_attrs, input_shapes.items);
|
||||||
|
}
|
||||||
|
|
||||||
const mlir_fn = dialect.func.func(self.mlirCtx(), .{
|
const mlir_fn = dialect.func.func(self.mlirCtx(), .{
|
||||||
.sym_name = fn_name,
|
.sym_name = fn_name,
|
||||||
.args = input_types[0..],
|
.args = input_types,
|
||||||
.arg_attrs = attrs,
|
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
|
||||||
.results = fn_res_types,
|
.results = fn_res_types,
|
||||||
.block = fn_body,
|
.block = fn_body,
|
||||||
.location = loc,
|
.location = loc,
|
||||||
@ -277,12 +292,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
/// Given a list of donations mapping output buffers to input buffers,
|
/// Given a list of donations mapping output buffers to input buffers,
|
||||||
/// generate donation attribute for each `n_args` input argument.
|
/// generate donation attribute for each `n_args` input argument.
|
||||||
fn addDonationsAttribute(self: *const CompilationContext, allocator: std.mem.Allocator, donations: []const Tensor._Donation, n_args: usize) ![]mlir.Attribute {
|
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
|
||||||
const empty = mlir.DictionaryAttribute.init(self.mlirCtx(), &.{}).as(mlir.Attribute).?;
|
|
||||||
|
|
||||||
const arg_attrs = try allocator.alloc(mlir.Attribute, n_args);
|
|
||||||
@memset(arg_attrs, empty);
|
|
||||||
|
|
||||||
var n_donations: usize = 0;
|
var n_donations: usize = 0;
|
||||||
for (donations, 0..) |donation, index| {
|
for (donations, 0..) |donation, index| {
|
||||||
switch (donation) {
|
switch (donation) {
|
||||||
@ -293,23 +303,23 @@ pub const CompilationContext = struct {
|
|||||||
.input_buffer => {},
|
.input_buffer => {},
|
||||||
.arg => |a| {
|
.arg => |a| {
|
||||||
n_donations += 1;
|
n_donations += 1;
|
||||||
meta.assert(arg_attrs[a].eql(empty), "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, arg_attrs[a] });
|
// This will break the day we write another attribute before donation.
|
||||||
arg_attrs[a] = mlir.DictionaryAttribute.init(self.mlirCtx(), &.{
|
// When that time comes, do a fancier lookup here to check if an argument
|
||||||
|
// is donated twice.
|
||||||
|
meta.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
|
||||||
|
attributes[a].appendAssumeCapacity(
|
||||||
mlir.NamedAttribute.init(
|
mlir.NamedAttribute.init(
|
||||||
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
||||||
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?,
|
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?,
|
||||||
),
|
),
|
||||||
}).as(mlir.Attribute).?;
|
);
|
||||||
// log.debug("attribute: {}", .{arg_attrs[a]});
|
// log.debug("attribute: {}", .{attributes[a].constSlice()});
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_donations == 0) return &.{};
|
|
||||||
return arg_attrs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
test addDonationsAttribute {
|
test addDonationsAttributes {
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||||
@ -343,12 +353,81 @@ pub const CompilationContext = struct {
|
|||||||
// `%arg0` is the bias of the model, `%arg1` is `x`.
|
// `%arg0` is the bias of the model, `%arg1` is `x`.
|
||||||
try std.testing.expectEqual(1, f.n_model);
|
try std.testing.expectEqual(1, f.n_model);
|
||||||
try std.testing.expectEqual(1, f.n_args);
|
try std.testing.expectEqual(1, f.n_args);
|
||||||
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "%arg1: tensor<8xf16> {tf.aliasing_output = 0 : i32}") != null) catch |err| {
|
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| {
|
||||||
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
|
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Appends an `mhlo.layout_mode = "default"` attribute and an `mhlo.sharding` attribute
/// to the attribute list of every argument, the latter derived from the matching shape.
///
/// Does nothing when the platform's `compilation_options.sharding_enabled` is false.
/// `attributes` and `shapes` are walked in lockstep; each list must have room
/// for two more entries since entries are appended without capacity checks.
fn addShardingAttributes(self: CompilationContext, attributes: []AttributeList, shapes: []const Shape) void {
    const mlir_ctx = self.mlirCtx();
    if (!self._platform.compilation_options.sharding_enabled) return;

    const partition_count = self._platform.sharding().num_partitions;
    const layout_mode_attr = mlir.NamedAttribute.init(
        mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
        mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
    );

    // Scratch buffer reused for the textual sharding of each argument.
    var sharding_repr: std.BoundedArray(u8, 128) = .{};
    for (attributes, shapes) |*arg_attributes, arg_shape| {
        arg_attributes.appendAssumeCapacity(layout_mode_attr);

        sharding_repr.len = 0;
        writeShardingRepresentation(arg_shape, partition_count, sharding_repr.writer()) catch unreachable;
        const sharding_attr = mlir.NamedAttribute.init(
            mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
            mlir.StringAttribute.init(mlir_ctx, sharding_repr.constSlice()).asAttr(),
        );
        arg_attributes.appendAssumeCapacity(sharding_attr);
    }
}
|
||||||
|
|
||||||
|
/// Writes the textual sharding of `shape` for `num_partitions` partitions to `writer`.
///
/// Writes `{replicated}` when no axis of the shape is flagged as sharded or when there
/// is a single partition. Otherwise writes `{devices=[d0,d1,...]<=[N]}` where each `di`
/// is `num_partitions` for a sharded axis and 1 for the others, and `N` is `num_partitions`.
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
    // `_sharding_info` is bit-cast to a u8 mask: one bit per axis flag.
    const sharded_axes_mask: u8 = @bitCast(shape._sharding_info);
    const is_replicated = @popCount(sharded_axes_mask) == 0 or num_partitions == 1;
    if (is_replicated) return writer.writeAll("{replicated}");

    try writer.writeAll("{devices=[");
    const rank = shape.rank();
    for (0..rank) |axis| {
        if (axis > 0) try writer.writeByte(',');
        const tile_count: u8 = if (shape._sharding_info[axis]) num_partitions else 1;
        try writer.print("{d}", .{tile_count});
    }
    try writer.print("]<=[{d}]}}", .{num_partitions});
}
|
||||||
|
|
||||||
|
test writeShardingRepresentation {
    const x = Shape.init(.{ 16, 8 }, .f32);

    const Case = struct {
        shape: Shape,
        num_partitions: u8,
        expected: []const u8,
    };
    const cases = [_]Case{
        // By default tensors are replicated.
        .{ .shape = x, .num_partitions = 4, .expected = "{replicated}" },
        // Shard along first axis.
        .{ .shape = x.withSharding(.{0}), .num_partitions = 4, .expected = "{devices=[4,1]<=[4]}" },
        // Also shard along second axis.
        .{ .shape = x.withSharding(.{ 0, 1 }), .num_partitions = 2, .expected = "{devices=[2,2]<=[2]}" },
    };

    for (cases) |case| {
        var rule: [64]u8 = undefined;
        var fbs = std.io.fixedBufferStream(&rule);
        try writeShardingRepresentation(case.shape, case.num_partitions, fbs.writer());
        try std.testing.expectEqualStrings(case.expected, fbs.getWritten());
    }
}
|
||||||
|
|
||||||
|
/// Turns each per-argument list of named attributes into one MLIR dictionary attribute.
/// The returned slice has one entry per input list and is allocated with `allocator`;
/// the caller owns it.
fn finalizeAttributeList(allocator: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute {
    const dict_attrs = try allocator.alloc(mlir.Attribute, attributes.len);
    for (attributes, 0..) |*named_attrs, arg_index| {
        const dict = mlir.DictionaryAttribute.init(mlir_ctx, named_attrs.constSlice());
        dict_attrs[arg_index] = dict.asAttr();
    }
    return dict_attrs;
}
|
||||||
|
|
||||||
/// Generates an MLIR `func.call` of the given function.
|
/// Generates an MLIR `func.call` of the given function.
|
||||||
/// If the function has not been seen yet, we generate MLIR for it,
|
/// If the function has not been seen yet, we generate MLIR for it,
|
||||||
/// in a independent function.
|
/// in a independent function.
|
||||||
@ -565,47 +644,57 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Visits `v` and stores the PJRT buffers of every Buffer found into the per-device tables.
///
/// `buffers` holds one table per device. The k-th Buffer visited writes its shard `d`
/// into `buffers[d][start + k]`. Every Buffer must have exactly `buffers.len` shards,
/// and exactly `len` Buffers must be visited.
fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u32) void {
    const FillState = struct {
        next_slot: u32,
        per_device: []const [*]*pjrt.Buffer,

        fn onBuffer(state: *@This(), buffer: *const Buffer) void {
            // NOTE: a check that the underlying PJRT buffer has not been donated is currently disabled here.
            const model_sharding = state.per_device.len;
            meta.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, state.per_device.len });

            const slot = state.next_slot;
            for (buffer._shards.constSlice(), 0..) |shard, device| {
                state.per_device[device][slot] = shard;
            }
            state.next_slot = slot + 1;
        }
    };

    var fill_state: FillState = .{
        .next_slot = start,
        .per_device = buffers,
    };
    meta.visit(FillState.onBuffer, &fill_state, v);
    assert(fill_state.next_slot == start + len);
}
|
||||||
|
|
||||||
/// Visit the given struct and override each Buffer found by a new one built from the provided PJRT buffers.
///
/// `buffers` holds one table per device (callers pass `output_per_device`): the i-th Buffer
/// visited in `v` is rebuilt from `buffers[d][i]` for every device `d`.
/// `expected_count` is the number of Buffers `v` is expected to contain. Buffers visited
/// beyond that count are left untouched, and the mismatch is reported by the final assertion.
pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32) void {
    const LocalContext = struct {
        index: u32,
        platform: Platform,
        buffers: []const [*]*pjrt.Buffer,
        expected_count: u32,
    };
    var local_ctx: LocalContext = .{
        .index = 0,
        .platform = platform,
        .buffers = buffers,
        .expected_count = expected_count,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *Buffer) void {
            const i = ctx.index;
            // Keep counting past `expected_count` so the final assertion can report the real total.
            ctx.index += 1;
            if (i >= ctx.expected_count) return;

            // Gather the i-th buffer of every device: those are the shards of one logical Buffer.
            var shards: Buffer.Shards = .{};
            for (ctx.buffers) |buff| {
                shards.appendAssumeCapacity(buff[i]);
            }
            buffer.* = Buffer.fromPjrtBuffers(ctx.platform, shards.constSlice());
        }
    }).cb, &local_ctx, v);
    // `buffers.len` is the number of devices, not the number of tensors returned by PJRT,
    // so the count reported in the message is `expected_count`.
    meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ expected_count, @typeName(@TypeOf(v)), local_ctx.index });
}
|
||||||
|
|
||||||
/// Visit the given struct and assign op results to each tensor found.
|
/// Visit the given struct and assign op results to each tensor found.
|
||||||
@ -637,11 +726,13 @@ const BaseExe = struct {
|
|||||||
/// The PJRT executable representing the compiled module.
|
/// The PJRT executable representing the compiled module.
|
||||||
exe: *pjrt.LoadedExecutable,
|
exe: *pjrt.LoadedExecutable,
|
||||||
/// Number of buffers in the model.
|
/// Number of buffers in the model.
|
||||||
model_buffer_count: usize,
|
model_buffer_count: u32,
|
||||||
/// Number of buffers in the arguments.
|
/// Number of buffers in the arguments.
|
||||||
args_buffer_count: usize,
|
args_buffer_count: u32,
|
||||||
/// Number of buffers in result.
|
/// Number of buffers in result.
|
||||||
result_buffer_count: usize,
|
result_buffer_count: u32,
|
||||||
|
/// Num devices used (>1 for sharded executable)
|
||||||
|
num_devices: u8,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Represents a ZML model, compiled into a PJRT executable.
|
/// Represents a ZML model, compiled into a PJRT executable.
|
||||||
@ -674,34 +765,56 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
/// The raw untyped compiled module.
|
/// The raw untyped compiled module.
|
||||||
inner: BaseExe,
|
inner: BaseExe,
|
||||||
|
|
||||||
/// The allocator used for bookkeeping.
|
|
||||||
allocator: std.mem.Allocator,
|
|
||||||
|
|
||||||
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
||||||
input_buffers: []*pjrt.Buffer,
|
input_per_device: []const [*]*pjrt.Buffer,
|
||||||
|
|
||||||
/// Pre-allocated slice of buffers to use as outputs when the module is called.
|
/// Pre-allocated slice of buffers to use as outputs when the module is called.
|
||||||
output_buffers: []*pjrt.Buffer,
|
output_per_device: []const [*]*pjrt.Buffer,
|
||||||
|
|
||||||
|
/// Internal memory slice used.
|
||||||
|
_all_buffers: []*pjrt.Buffer,
|
||||||
|
_all_per_device: [][*]*pjrt.Buffer,
|
||||||
|
|
||||||
|
/// And the allocator backing `_all_buffers` and `_all_per_device`.
|
||||||
|
_allocator: std.mem.Allocator,
|
||||||
|
|
||||||
/// Creates an executable bound to the weights of `model`.
///
/// Allocates, with `allocator`, the tables of PJRT buffer pointers used as inputs and
/// outputs by `call` (one table per device), and fills the model part of the input tables.
/// The returned value owns these allocations; release them with `deinit`.
/// Returns an error if an allocation fails.
pub fn initFromModel(allocator: std.mem.Allocator, inner: BaseExe, model: Bufferized(Signature.ModelT)) !Self {
    // Widen to usize before doing any arithmetic: `num_devices` is a u8, and an expression
    // such as `2 * n_devices` evaluated in u8 would overflow from 128 devices onwards.
    const n_devices: usize = inner.num_devices;
    const n_input_buffers: usize = inner.model_buffer_count + inner.args_buffer_count;
    const n_output_buffers: usize = inner.result_buffer_count;

    // Allocate once for all the *pjrt.Buffer we need to store ...
    const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_input_buffers + n_output_buffers) * n_devices);
    errdefer allocator.free(all_buffers);
    const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_input_buffers * n_devices, n_output_buffers * n_devices });

    // ... and once for all the [*]*pjrt.Buffer.
    const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
    errdefer allocator.free(all_per_device);
    const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });

    // Device `i` gets its own contiguous run of input slots and of output slots.
    for (0..n_devices) |i| {
        input_per_device[i] = all_input_buffers[i * n_input_buffers ..].ptr;
        output_per_device[i] = all_output_buffers[i * n_output_buffers ..].ptr;
    }

    fillBuffers(&model, input_per_device, 0, inner.model_buffer_count);
    // Note: all_output_buffers is left undefined, it will be written to in `call`.

    return .{
        .inner = inner,
        .input_per_device = input_per_device,
        .output_per_device = output_per_device,
        ._all_buffers = all_buffers,
        ._all_per_device = all_per_device,
        ._allocator = allocator,
    };
}
|
||||||
|
|
||||||
/// Releases the buffer tables allocated by `initFromModel`.
pub fn deinit(self: Self) void {
    const backing_allocator = self._allocator;
    // Released in the reverse order of their allocation.
    backing_allocator.free(self._all_per_device);
    backing_allocator.free(self._all_buffers);
}
|
||||||
|
|
||||||
pub fn platform(self: Self) Platform {
|
pub fn platform(self: Self) Platform {
|
||||||
@ -709,13 +822,13 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
||||||
fillBuffers(&args, self.input_buffers[self.inner.model_buffer_count..][0..self.inner.args_buffer_count]);
|
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
|
||||||
var event: [1]*pjrt.Event = undefined;
|
var event: [1]*pjrt.Event = undefined;
|
||||||
|
|
||||||
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
|
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
|
||||||
.arguments = &.{self.input_buffers.ptr},
|
.arguments = self.input_per_device,
|
||||||
.num_args = self.input_buffers.len,
|
.num_args = self.inner.args_buffer_count + self.inner.model_buffer_count,
|
||||||
.results = &.{self.output_buffers.ptr},
|
.results = self.output_per_device,
|
||||||
.events = &event,
|
.events = &event,
|
||||||
// TODO: this allows to tell a specific buffer shouldn't be donated,
|
// TODO: this allows to tell a specific buffer shouldn't be donated,
|
||||||
// even if it has been marked as "can be donated" during compilation.
|
// even if it has been marked as "can be donated" during compilation.
|
||||||
@ -723,7 +836,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
}) catch unreachable;
|
}) catch unreachable;
|
||||||
|
|
||||||
var result: Bufferized(Signature.ReturnT) = undefined;
|
var result: Bufferized(Signature.ReturnT) = undefined;
|
||||||
assignRawBuffers(&result, self.inner.platform, self.output_buffers);
|
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_count);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -750,6 +863,11 @@ fn compileInternal(
|
|||||||
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
|
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
|
||||||
context._module.getBody().appendOperation(f.mlir_fn);
|
context._module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
|
const sharding = context._platform.sharding();
|
||||||
|
const mlir_ctx = context._mlir_ctx;
|
||||||
|
context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
|
||||||
|
context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
|
||||||
|
|
||||||
const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| {
|
const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| {
|
||||||
log.err(
|
log.err(
|
||||||
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
|
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
|
||||||
@ -771,7 +889,8 @@ fn compileInternal(
|
|||||||
.exe = loaded_executable,
|
.exe = loaded_executable,
|
||||||
.model_buffer_count = f.n_model,
|
.model_buffer_count = f.n_model,
|
||||||
.args_buffer_count = f.n_args,
|
.args_buffer_count = f.n_args,
|
||||||
.result_buffer_count = f.res_types.len,
|
.result_buffer_count = @intCast(f.res_types.len),
|
||||||
|
.num_devices = sharding.num_replicas * sharding.num_partitions,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -932,10 +1051,12 @@ fn loadOrCompilePjrtExecutable(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, module_hash: u64) !*pjrt.LoadedExecutable {
|
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, module_hash: u64) !*pjrt.LoadedExecutable {
|
||||||
|
const sharding = platform.sharding();
|
||||||
var options: xla_pb.CompileOptionsProto = .{
|
var options: xla_pb.CompileOptionsProto = .{
|
||||||
.executable_build_options = .{
|
.executable_build_options = .{
|
||||||
.num_replicas = 1,
|
.num_replicas = sharding.num_replicas,
|
||||||
.num_partitions = 1,
|
.num_partitions = sharding.num_partitions,
|
||||||
|
.use_spmd_partitioning = sharding.num_partitions > 1 or sharding.num_replicas > 1,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
// Let the arena deinit, zig-protobuf deinit is very slow.
|
// Let the arena deinit, zig-protobuf deinit is very slow.
|
||||||
@ -1372,3 +1493,14 @@ fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void {
|
|||||||
hash(hasher, element, strat);
|
hash(hasher, element, strat);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Splits `buffer` into consecutive, non-overlapping sub-slices.
/// `lengths` is a tuple giving the length of each sub-slice, in order;
/// the lengths must add up to `buffer.len` (checked with an assertion).
/// The returned slices alias `buffer`; nothing is allocated.
fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
    var parts: [lengths.len][]T = undefined;
    var remaining = buffer;
    inline for (lengths, 0..) |part_len, part_index| {
        parts[part_index] = remaining[0..part_len];
        remaining = remaining[part_len..];
    }
    std.debug.assert(remaining.len == 0);
    return parts;
}
|
||||||
|
|||||||
@ -15,6 +15,7 @@ const Target = @import("platform.zig").Target;
|
|||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
|
pub const Buffer = pjrt.Buffer;
|
||||||
pub const Device = pjrt.Device;
|
pub const Device = pjrt.Device;
|
||||||
pub const DeviceDescription = pjrt.DeviceDescription;
|
pub const DeviceDescription = pjrt.DeviceDescription;
|
||||||
pub const Api = pjrt.Api;
|
pub const Api = pjrt.Api;
|
||||||
@ -27,6 +28,12 @@ pub const SerializeResult = pjrt.SerializeResult;
|
|||||||
pub const Executable = pjrt.Executable;
|
pub const Executable = pjrt.Executable;
|
||||||
pub const ExecuteError = ApiError;
|
pub const ExecuteError = ApiError;
|
||||||
|
|
||||||
|
test {
    // Reference every declaration of the wrapper types so they all get analyzed by the test build.
    inline for (.{ Client, Event, LoadedExecutable }) |Wrapper| {
        std.testing.refAllDecls(Wrapper);
    }
}
|
||||||
|
|
||||||
fn InnerMixin(comptime innerT: type) type {
|
fn InnerMixin(comptime innerT: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT {
|
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT {
|
||||||
@ -63,7 +70,7 @@ pub const Client = opaque {
|
|||||||
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
|
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
|
||||||
const event: *Event = @ptrCast(event_);
|
const event: *Event = @ptrCast(event_);
|
||||||
try event.await_(api);
|
try event.await_(api);
|
||||||
return @ptrCast(buffer);
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||||
@ -169,75 +176,6 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Opaque handle layered over `pjrt.Buffer`.
/// Every method reaches the wrapped value through `inner()` and forwards the call to it.
pub const Buffer = opaque {
    const inner = InnerMixin(pjrt.Buffer).inner;

    // Re-exported PJRT types, so callers can name them through `Buffer`.
    pub const BufferType = pjrt.BufferType;
    pub const HostBufferSemantics = pjrt.HostBufferSemantics;
    pub const MemoryLayoutType = pjrt.MemoryLayoutType;

    /// Maps a `dtype.DataType` tag to the corresponding PJRT `BufferType` tag.
    pub fn BufferTypeFromDType(dt: dtype.DataType) BufferType {
        return switch (dt) {
            // Predicate.
            .bool => .PRED,

            // 8-bit floats.
            .f8e4m3b11fnuz => .F8E4M3B11FNUZ,
            .f8e4m3fn => .F8E4M3FN,
            .f8e4m3fnuz => .F8E4M3FNUZ,
            .f8e5m2 => .F8E5M2,
            .f8e5m2fnuz => .F8E5M2FNUZ,

            // Wider floats.
            .bf16 => .BF16,
            .f16 => .F16,
            .f32 => .F32,
            .f64 => .F64,

            // Signed integers.
            .i4 => .S4,
            .i8 => .S8,
            .i16 => .S16,
            .i32 => .S32,
            .i64 => .S64,

            // Unsigned integers.
            .u4 => .U4,
            .u8 => .U8,
            .u16 => .U16,
            .u32 => .U32,
            .u64 => .U64,

            // Complex numbers.
            .c64 => .C64,
            .c128 => .C128,
        };
    }

    /// Forwards to `deinit` of the wrapped buffer.
    pub fn deinit(self: *Buffer, api: *const Api) void {
        self.inner().deinit(api);
    }

    /// Forwards to `delete` of the wrapped buffer.
    pub fn delete(self: *Buffer, api: *const Api) void {
        self.inner().delete(api);
    }

    /// Calls `toHostBuffer` on the wrapped buffer with `dst`, then awaits the event it returns.
    /// An error from either step is propagated to the caller.
    pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!void {
        var transfer_event = try self.inner().toHostBuffer(api, dst);
        try transfer_event.await_(api);
    }

    /// Forwards to `getDimensions` of the wrapped buffer.
    pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
        return self.inner().getDimensions(api);
    }

    /// Forwards to `getElementType` of the wrapped buffer.
    pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
        return self.inner().getElementType(api);
    }

    /// Forwards to `isDeleted` of the wrapped buffer.
    pub fn isDeleted(self: *const Buffer, api: *const Api) bool {
        return self.inner().isDeleted(api);
    }

    /// Asks the wrapped buffer for its device and returns it cast to this module's `Device` pointer type.
    pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device {
        const raw_device = try self.inner().getDevice(api);
        return @ptrCast(raw_device);
    }

    /// Forwards to `getOpaqueDeviceMemoryDataPointer` of the wrapped buffer.
    pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque {
        return self.inner().getOpaqueDeviceMemoryDataPointer(api);
    }
};
|
|
||||||
|
|
||||||
pub const Event = opaque {
|
pub const Event = opaque {
|
||||||
pub const inner = InnerMixin(pjrt.Event).inner;
|
pub const inner = InnerMixin(pjrt.Event).inner;
|
||||||
|
|
||||||
@ -249,7 +187,7 @@ pub const Event = opaque {
|
|||||||
return self.inner().isReady(api);
|
return self.inner().isReady(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error {
|
pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
|
||||||
return self.inner().getEventError(api);
|
return self.inner().getEventError(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,9 +226,9 @@ pub const Event = opaque {
|
|||||||
pub const LoadedExecutable = opaque {
|
pub const LoadedExecutable = opaque {
|
||||||
const inner = InnerMixin(pjrt.LoadedExecutable).inner;
|
const inner = InnerMixin(pjrt.LoadedExecutable).inner;
|
||||||
|
|
||||||
pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
|
// pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
|
||||||
self.inner().deinit(api);
|
// self.inner().deinit(api);
|
||||||
}
|
// }
|
||||||
|
|
||||||
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
||||||
self.inner().delete(api);
|
self.inner().delete(api);
|
||||||
@ -300,9 +238,10 @@ pub const LoadedExecutable = opaque {
|
|||||||
return self.inner().isDeleted(api);
|
return self.inner().isDeleted(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
|
// TODO fix me
|
||||||
return self.inner().getAddressableDevices(api);
|
// pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
|
||||||
}
|
// return self.inner().getAddressableDevices(api);
|
||||||
|
// }
|
||||||
|
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
||||||
arguments: []const [*]const *const Buffer,
|
arguments: []const [*]const *const Buffer,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ const meta = @import("meta.zig");
|
|||||||
const module = @import("module.zig");
|
const module = @import("module.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
const pjrt_core = @import("pjrt");
|
const pjrt_core = @import("pjrt");
|
||||||
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
pub const Target = enum {
|
pub const Target = enum {
|
||||||
cpu,
|
cpu,
|
||||||
@ -31,6 +32,8 @@ pub const CompilationOptions = struct {
|
|||||||
xla_dump_to: ?[]const u8 = null,
|
xla_dump_to: ?[]const u8 = null,
|
||||||
xla_dump_fusion_visualization: bool = false,
|
xla_dump_fusion_visualization: bool = false,
|
||||||
cache_location: ?[]const u8 = null,
|
cache_location: ?[]const u8 = null,
|
||||||
|
sharding_enabled: bool = false,
|
||||||
|
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Platform = struct {
|
pub const Platform = struct {
|
||||||
@ -39,8 +42,14 @@ pub const Platform = struct {
|
|||||||
pjrt_client: *pjrt.Client,
|
pjrt_client: *pjrt.Client,
|
||||||
compilation_options: CompilationOptions = .{},
|
compilation_options: CompilationOptions = .{},
|
||||||
|
|
||||||
|
pub const MAX_NUM_DEVICES: u8 = 8;
|
||||||
|
|
||||||
pub fn init(target: Target, api: *const pjrt.Api) !Platform {
|
pub fn init(target: Target, api: *const pjrt.Api) !Platform {
|
||||||
const pjrt_client = try pjrt.Client.init(api, &.{});
|
const pjrt_client = try pjrt.Client.init(api, &.{});
|
||||||
|
const true_num_devices = pjrt_client.getAddressableDevices(api).len;
|
||||||
|
if (true_num_devices > MAX_NUM_DEVICES) {
|
||||||
|
log.warn("platform {} got {} devices, but ZML only support up to {} devices. Some devices won't be used.", .{ target, true_num_devices, MAX_NUM_DEVICES });
|
||||||
|
}
|
||||||
return .{
|
return .{
|
||||||
.target = target,
|
.target = target,
|
||||||
.pjrt_api = api,
|
.pjrt_api = api,
|
||||||
@ -50,7 +59,26 @@ pub const Platform = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Lists the addressable devices of this platform's PJRT client,
/// keeping at most the first `MAX_NUM_DEVICES` of them.
pub fn getDevices(self: Platform) []const *const pjrt_core.Device {
    const addressable = self.pjrt_client.getAddressableDevices(self.pjrt_api);
    // Anything past the limit is dropped from the returned slice.
    return addressable[0..@min(addressable.len, MAX_NUM_DEVICES)];
}

/// Replica and partition counts, as returned by `sharding`.
pub const Sharding = struct { num_replicas: u8, num_partitions: u8 };

/// Returns the replica/partition layout for this platform.
/// With `compilation_options.sharding_enabled` set there is one partition per
/// device returned by `getDevices`; otherwise there is a single partition.
pub fn sharding(self: Platform) Sharding {
    // replicas run the same function but with different inputs,
    // while partitions contribute to one evaluation over a shared input.
    // Inside an inference process, we generally don't want replicas,
    // as it's best to fully isolate replicas on different processes.
    // For now we hardcode num_replicas = 1.
    const device_count: u8 = @intCast(self.getDevices().len);
    if (self.compilation_options.sharding_enabled) {
        return .{ .num_replicas = 1, .num_partitions = device_count };
    }
    return .{ .num_replicas = 1, .num_partitions = 1 };
}
|
||||||
|
|
||||||
pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform {
|
pub fn withCompilationOptions(self: Platform, opts: CompilationOptions) Platform {
|
||||||
|
|||||||
@ -23,12 +23,14 @@ pub const Shape = struct {
|
|||||||
pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
|
pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
|
||||||
pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
|
pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
|
||||||
pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
|
pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
|
||||||
|
pub const ShardingInfo = @Vector(MAX_RANK, bool);
|
||||||
|
|
||||||
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
|
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
|
||||||
|
|
||||||
_dtype: DataType,
|
_dtype: DataType,
|
||||||
_dims: DimsArray = .{},
|
_dims: DimsArray = .{},
|
||||||
_tags: TagsArray = UnknownTags,
|
_tags: TagsArray = UnknownTags,
|
||||||
|
_sharding_info: ShardingInfo = @splat(false),
|
||||||
|
|
||||||
pub fn parseDimensions(v: anytype) struct { DimsArray, TagsArray } {
|
pub fn parseDimensions(v: anytype) struct { DimsArray, TagsArray } {
|
||||||
const T = @TypeOf(v);
|
const T = @TypeOf(v);
|
||||||
@ -69,7 +71,7 @@ pub const Shape = struct {
|
|||||||
return .{ dims_, tags_ };
|
return .{ dims_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Wrong type, got {}", .{T});
|
meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
test parseDimensions {
|
test parseDimensions {
|
||||||
@ -286,7 +288,7 @@ pub const Shape = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Wrong type, got {}", .{T});
|
meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn axisFromInt(self: Shape, d: isize) u3 {
|
fn axisFromInt(self: Shape, d: isize) u3 {
|
||||||
@ -384,6 +386,9 @@ pub const Shape = struct {
|
|||||||
} else {
|
} else {
|
||||||
try writer.print("{s}{d}", .{ prefix, d });
|
try writer.print("{s}{d}", .{ prefix, d });
|
||||||
}
|
}
|
||||||
|
if (self._sharding_info[i]) {
|
||||||
|
try writer.writeByte('!');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())});
|
_ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())});
|
||||||
if (!bare_fmt) _ = try writer.write(")");
|
if (!bare_fmt) _ = try writer.write(")");
|
||||||
@ -664,6 +669,16 @@ pub const Shape = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a copy of this shape in which exactly the axes named by `axes_` are flagged as sharded.
/// Flags already set on `self` are cleared first, so the result depends only on `axes_`.
/// `axes_` is resolved through `self.axes`.
pub fn withSharding(self: Shape, axes_: anytype) Shape {
    const selected = self.axes(axes_);

    var sharded = self;
    // Start from a clean slate: no axis is sharded.
    sharded._sharding_info = @splat(false);
    for (selected.constSlice()) |axis| sharded._sharding_info[axis] = true;

    return sharded;
}
|
||||||
|
|
||||||
/// Renames some of the tags in this shape.
|
/// Renames some of the tags in this shape.
|
||||||
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
|
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
|
||||||
pub fn rename(self: Shape, renames: anytype) Shape {
|
pub fn rename(self: Shape, renames: anytype) Shape {
|
||||||
@ -690,7 +705,8 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn computeStrides(self: Shape, base_stride: u32) std.BoundedArray(i64, MAX_RANK) {
|
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
|
||||||
|
const base_stride = self.dtype().sizeOf();
|
||||||
const rk = self.rank();
|
const rk = self.rank();
|
||||||
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) };
|
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) };
|
||||||
if (rk == 0) return strides;
|
if (rk == 0) return strides;
|
||||||
|
|||||||
@ -159,6 +159,12 @@ pub const Tensor = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a copy of this tensor whose shape is `self._shape.withSharding(axes_)`.
/// Every other field is carried over unchanged.
pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
    const sharded_shape = self._shape.withSharding(axes_);

    var out = self;
    out._shape = sharded_shape;
    return out;
}
|
||||||
|
|
||||||
/// Returns a Tensor with new tag names.
|
/// Returns a Tensor with new tag names.
|
||||||
pub fn rename(self: Tensor, renames: anytype) Tensor {
|
pub fn rename(self: Tensor, renames: anytype) Tensor {
|
||||||
var res = self;
|
var res = self;
|
||||||
@ -196,11 +202,6 @@ pub const Tensor = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a slice containing the strides for a Tensor.
/// The strides are computed by `Shape.computeStrides`, with `self.dtype().sizeOf()` passed as the base stride.
/// NOTE(review): the returned slice is taken with `constSlice()` on the array value that
/// `Shape.computeStrides` returns, so it points into a temporary. Presumably `inline` is what
/// keeps that storage usable by the caller; confirm before storing the slice anywhere.
|

||||||
pub inline fn computeStrides(self: Tensor) []const i64 {
|

||||||
return self._shape.computeStrides(self.dtype().sizeOf()).constSlice();
|

||||||
}
|
|
||||||
|
|
||||||
var _global_tensor_counter: u64 = 0;
|
var _global_tensor_counter: u64 = 0;
|
||||||
|
|
||||||
/// Internal use
|
/// Internal use
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user