Make `Buffer.from` blocking by default, and add options for async loading and memory placement; adjust the aio, hostbuffer, pjrtx, and tensor implementations accordingly.

This commit is contained in:
Tarry Singh 2024-12-25 17:14:44 +00:00
parent da1fd2d9dc
commit e6286b6097
5 changed files with 96 additions and 37 deletions

View File

@ -631,8 +631,11 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
/// Waits until every `zml.Buffer` reachable inside the given (possibly
/// nested) struct is ready on device, replacing each buffer in place
/// with its awaited value.
pub fn awaitAll(buffers: anytype) !void {
    zml.meta.visit((struct {
        // NOTE(review): `catch unreachable` assumes awaitting a buffer
        // cannot fail here — confirm, otherwise propagate the error.
        fn cb(_: void, buffer: *zml.Buffer) void {
            buffer.* = buffer.awaitt() catch unreachable;
        }
    }).cb, {}, buffers);
}
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
@ -653,7 +656,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
buf_with_metadata._shape = obj._shape;
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
} else {
log.err("Buffer not found: {s}", .{prefix});

View File

@ -51,8 +51,13 @@ pub const Buffer = struct {
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
/// Options for `Buffer.from`-style host-to-device copies:
/// blocking behavior and destination memory placement.
pub const FromOptions = struct {
/// When true (the default) the copy blocks until the transfer
/// completes (the buffer is awaited before `from` returns).
wait: bool = true,
/// Desired destination memory kind; when null the buffer is placed
/// on the target device's default memory.
memory: ?pjrt.Memory.Kind = null,
};
/// Copies the content of the given buffer from host memory to the accelerator memory.
pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer {
pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
var res: Buffer = .{
._api = platform.pjrt_api,
._shape = host_buffer.shape(),
@ -73,7 +78,6 @@ pub const Buffer = struct {
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
const byte_strides = host_buffer.strides();
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
const devices = platform.getDevices();
for (0..n_partitions) |i| {
// If no sharding if found, the given buffer is replicated on all devices.
@ -82,29 +86,50 @@ pub const Buffer = struct {
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
} else host_buffer;
const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{
platform.pjrt_client,
platform.pjrt_api,
pjrt.Client.BufferFromHostBufferArgs{
.data = buf._data,
.buffer_type = buffer_type,
.dims = buf.shape().dims(),
.byte_strides = byte_strides,
.device = devices[i],
.host_buffer_semantics = .ImmutableOnlyDuringCall,
},
});
var args = pjrt.Client.BufferFromHostBufferArgs{
.data = buf._data,
.buffer_type = buffer_type,
.dims = buf.shape().dims(),
.byte_strides = byte_strides,
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
};
if (opts.memory) |memory_kind| {
const memories = try devices[i].addressableMemories(platform.pjrt_api);
const memory = for (memories) |m| {
const kind = m.kind(platform.pjrt_api);
if (kind == memory_kind) break m;
} else return error.NotFound;
args.memory = memory;
} else {
args.device = devices[i];
}
frames.appendAssumeCapacity(frame);
}
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
if (event) |ev| {
ev.deinit(platform.pjrt_api);
}
for (frames.slice()) |*frame| {
const pjrt_buffer = try frame.awaitt();
res._shards.appendAssumeCapacity(pjrt_buffer);
}
if (opts.wait) {
res = try res.awaitt();
}
return res;
}
/// Blocks until every shard of this buffer has signaled readiness,
/// then returns the buffer itself.
pub fn awaitt(self: Buffer) !Buffer {
    const shards = self._shards.constSlice();
    for (shards) |shard| {
        const maybe_ready = shard.getReadyEvent(self._api);
        if (maybe_ready) |event| {
            try event.await_(self._api);
        }
    }
    return self;
}
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
@ -122,20 +147,40 @@ pub const Buffer = struct {
/// Copies the given Zig slice to the accelerator memory and
/// returns a Buffer with the given dimensions (default options:
/// blocking copy, default device memory).
pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
    const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
    return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), .{});
}
/// Copies the given bytes to the accelerator memory and
/// returns a Buffer with the given shape (default options:
/// blocking copy, default device memory).
pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer {
    return from(platform, HostBuffer.fromBytes(sh, data), .{});
}
/// Copies the given Zig array to the accelerator memory and
/// returns a Buffer using the array's shape. Blocks until the
/// transfer completes (the array is a stack temporary).
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
    const host_buffer = HostBuffer.fromArray(&arr);
    return try from(platform, host_buffer, .{ .wait = true });
}
/// Copies the given Zig slice to the accelerator memory and
/// returns a Buffer with the given dimensions, honoring `opts`.
pub fn fromSliceOpts(platform: Platform, dimz: anytype, s: anytype, opts: FromOptions) !Buffer {
    const elem_dtype = DataType.fromSliceElementType(s);
    const shape_ = Shape.init(dimz, elem_dtype);
    const bytes = std.mem.sliceAsBytes(s);
    return try from(platform, HostBuffer.fromBytes(shape_, bytes), opts);
}
/// Copies the given bytes to the accelerator memory and
/// returns a Buffer with the given shape, honoring `opts`.
pub fn fromBytesOpts(platform: Platform, sh: Shape, data: []const u8, opts: FromOptions) !Buffer {
    const host = HostBuffer.fromBytes(sh, data);
    return try from(platform, host, opts);
}
/// Copies the given Zig array to the accelerator memory and
/// returns a Buffer using the array's shape, honoring `opts`.
pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer {
    return try from(platform, HostBuffer.fromArray(&arr), opts);
}
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
@ -150,7 +195,7 @@ pub const Buffer = struct {
/// Creates a rank-0 Buffer holding a single value of the given dtype
/// and copies it to the accelerator, blocking until the transfer
/// completes (the host bytes live on the stack).
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
    const x = dtype_.constant(val);
    const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
    return try from(platform, host_buffer, .{ .wait = true });
}
/// Creates a Buffer with a single element repeated many times.
@ -176,7 +221,7 @@ pub const Buffer = struct {
._strides = @splat(0),
._data = x.constSlice().ptr,
};
return try from(platform, host_buffer);
return try from(platform, host_buffer, .{ .wait = true });
}
// To speed up copies, duplicate the scalar value into a vector,
@ -199,7 +244,7 @@ pub const Buffer = struct {
else => unreachable,
}
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
return try from(platform, host_buffer);
return try from(platform, host_buffer, .{ .wait = true });
}
test constant {
@ -352,6 +397,16 @@ pub const Buffer = struct {
if (self._shards.len == 1) return false;
return @reduce(.Or, self._shape._sharding_info);
}
/// Copies every shard of this buffer into the given memory space and
/// returns a new Buffer wrapping the copies; the original buffer is
/// left untouched.
pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer {
    var copied: Buffer.Shards = .{};
    for (self._shards.slice()) |src_shard| {
        copied.appendAssumeCapacity(try src_shard.copyToMemory(self._api, memory));
    }
    return .{ ._shape = self._shape, ._shards = copied, ._api = self._api };
}
};
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {

View File

@ -158,7 +158,12 @@ pub const HostBuffer = struct {
/// Copies this HostBuffer to the given accelerator, using the default
/// options (blocking copy, default device memory).
pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
    return try self.toDeviceOpts(platform_, .{});
}
/// Copies this HostBuffer to the given accelerator, honoring `opts`
/// (blocking behavior and memory placement).
pub fn toDeviceOpts(self: HostBuffer, platform_: Platform, opts: Buffer.FromOptions) !Buffer {
    const device_buffer = try Buffer.from(platform_, self, opts);
    return device_buffer;
}
/// Interpret the underlying data as a contiguous slice.

View File

@ -59,13 +59,9 @@ pub const Client = opaque {
}
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
/// Starts a host-to-device transfer via the underlying PJRT client.
/// Returns the device buffer together with an optional completion
/// event; when the event is non-null the transfer may still be in
/// flight and the caller must await or deinit the event.
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
    const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
    return .{ @ptrCast(buffer), @ptrCast(event_) };
}
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {

View File

@ -2674,7 +2674,7 @@ pub const Tensor = struct {
// Test with actual values, no batching.
{
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b });
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }), .{})).withTags(.{ .a, .b });
defer a.deinit();
a_host.deinit(std.testing.allocator);
@ -2693,7 +2693,7 @@ pub const Tensor = struct {
// Test with setting individual values (no batching)
{
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
const a = try zml.Buffer.from(platform, a_host);
const a = try zml.Buffer.from(platform, a_host, .{});
defer a.deinit();
a_host.deinit(std.testing.allocator);