Update Buffer.from to be blocking by default and add options for async loading and memory placement, adjusting aio, hostbuffer, pjrtx, and tensor implementations.
This commit is contained in:
parent
da1fd2d9dc
commit
e6286b6097
@ -631,8 +631,11 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
|
||||
|
||||
/// deinit all buffers in the given struct
|
||||
pub fn awaitAll(buffers: anytype) !void {
|
||||
// TODO: implement once we have async buffers.
|
||||
_ = buffers;
|
||||
zml.meta.visit((struct {
|
||||
fn cb(_: void, buffer: *zml.Buffer) void {
|
||||
buffer.* = buffer.awaitt() catch unreachable;
|
||||
}
|
||||
}).cb, {}, buffers);
|
||||
}
|
||||
|
||||
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
||||
@ -653,7 +656,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
||||
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
|
||||
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
||||
buf_with_metadata._shape = obj._shape;
|
||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
|
||||
} else {
|
||||
log.err("Buffer not found: {s}", .{prefix});
|
||||
|
||||
|
||||
103
zml/buffer.zig
103
zml/buffer.zig
@ -51,8 +51,13 @@ pub const Buffer = struct {
|
||||
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
||||
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
||||
|
||||
pub const FromOptions = struct {
|
||||
wait: bool = true,
|
||||
memory: ?pjrt.Memory.Kind = null,
|
||||
};
|
||||
|
||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||
pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer {
|
||||
pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
|
||||
var res: Buffer = .{
|
||||
._api = platform.pjrt_api,
|
||||
._shape = host_buffer.shape(),
|
||||
@ -73,7 +78,6 @@ pub const Buffer = struct {
|
||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||
const byte_strides = host_buffer.strides();
|
||||
|
||||
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
||||
const devices = platform.getDevices();
|
||||
for (0..n_partitions) |i| {
|
||||
// If no sharding if found, the given buffer is replicated on all devices.
|
||||
@ -82,29 +86,50 @@ pub const Buffer = struct {
|
||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||
} else host_buffer;
|
||||
|
||||
const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{
|
||||
platform.pjrt_client,
|
||||
platform.pjrt_api,
|
||||
pjrt.Client.BufferFromHostBufferArgs{
|
||||
.data = buf._data,
|
||||
.buffer_type = buffer_type,
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = byte_strides,
|
||||
.device = devices[i],
|
||||
.host_buffer_semantics = .ImmutableOnlyDuringCall,
|
||||
},
|
||||
});
|
||||
var args = pjrt.Client.BufferFromHostBufferArgs{
|
||||
.data = buf._data,
|
||||
.buffer_type = buffer_type,
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = byte_strides,
|
||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||
};
|
||||
if (opts.memory) |memory_kind| {
|
||||
const memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||
const memory = for (memories) |m| {
|
||||
const kind = m.kind(platform.pjrt_api);
|
||||
if (kind == memory_kind) break m;
|
||||
} else return error.NotFound;
|
||||
args.memory = memory;
|
||||
} else {
|
||||
args.device = devices[i];
|
||||
}
|
||||
|
||||
frames.appendAssumeCapacity(frame);
|
||||
}
|
||||
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
||||
|
||||
if (event) |ev| {
|
||||
ev.deinit(platform.pjrt_api);
|
||||
}
|
||||
|
||||
for (frames.slice()) |*frame| {
|
||||
const pjrt_buffer = try frame.awaitt();
|
||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||
}
|
||||
|
||||
if (opts.wait) {
|
||||
res = try res.awaitt();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
pub fn awaitt(self: Buffer) !Buffer {
|
||||
for (self._shards.constSlice()) |buffer| {
|
||||
if (buffer.getReadyEvent(self._api)) |ev| {
|
||||
try ev.await_(self._api);
|
||||
}
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||
@ -122,20 +147,40 @@ pub const Buffer = struct {
|
||||
/// return a Buffer with the given dimensions.
|
||||
pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
|
||||
const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
|
||||
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
|
||||
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), .{});
|
||||
}
|
||||
|
||||
/// Copies the given Zig slice to the accelerator memory and
|
||||
/// return a Buffer with the given dimensions.
|
||||
pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer {
|
||||
return from(platform, HostBuffer.fromBytes(sh, data));
|
||||
return from(platform, HostBuffer.fromBytes(sh, data), .{});
|
||||
}
|
||||
|
||||
/// Copies the given Zig array to the accelerator memory and
|
||||
/// return a Buffer using the array shape.
|
||||
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
|
||||
const host_buffer = HostBuffer.fromArray(&arr);
|
||||
return try from(platform, host_buffer);
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
}
|
||||
|
||||
/// Copies the given Zig slice to the accelerator memory and
|
||||
/// return a Buffer with the given dimensions.
|
||||
pub fn fromSliceOpts(platform: Platform, dimz: anytype, s: anytype, opts: FromOptions) !Buffer {
|
||||
const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
|
||||
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), opts);
|
||||
}
|
||||
|
||||
/// Copies the given Zig slice to the accelerator memory and
|
||||
/// return a Buffer with the given dimensions.
|
||||
pub fn fromBytesOpts(platform: Platform, sh: Shape, data: []const u8, opts: FromOptions) !Buffer {
|
||||
return from(platform, HostBuffer.fromBytes(sh, data), opts);
|
||||
}
|
||||
|
||||
/// Copies the given Zig array to the accelerator memory and
|
||||
/// return a Buffer using the array shape.
|
||||
pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer {
|
||||
const host_buffer = HostBuffer.fromArray(&arr);
|
||||
return try from(platform, host_buffer, opts);
|
||||
}
|
||||
|
||||
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
|
||||
@ -150,7 +195,7 @@ pub const Buffer = struct {
|
||||
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
||||
const x = dtype_.constant(val);
|
||||
const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
|
||||
return try from(platform, host_buffer);
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
}
|
||||
|
||||
/// Creates a Buffer with a single element repeated manytime.
|
||||
@ -176,7 +221,7 @@ pub const Buffer = struct {
|
||||
._strides = @splat(0),
|
||||
._data = x.constSlice().ptr,
|
||||
};
|
||||
return try from(platform, host_buffer);
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
}
|
||||
|
||||
// To speed up copies, duplicate the scalar value into a vector,
|
||||
@ -199,7 +244,7 @@ pub const Buffer = struct {
|
||||
else => unreachable,
|
||||
}
|
||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
||||
return try from(platform, host_buffer);
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
}
|
||||
|
||||
test constant {
|
||||
@ -352,6 +397,16 @@ pub const Buffer = struct {
|
||||
if (self._shards.len == 1) return false;
|
||||
return @reduce(.Or, self._shape._sharding_info);
|
||||
}
|
||||
|
||||
pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer {
|
||||
var new_shards: Buffer.Shards = .{};
|
||||
for (self._shards.slice()) |shard| {
|
||||
const new_shard = try shard.copyToMemory(self._api, memory);
|
||||
new_shards.appendAssumeCapacity(new_shard);
|
||||
}
|
||||
|
||||
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
||||
}
|
||||
};
|
||||
|
||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||
|
||||
@ -158,7 +158,12 @@ pub const HostBuffer = struct {
|
||||
|
||||
/// Copies this HostBuffer to the given accelerator.
|
||||
pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
|
||||
return try Buffer.from(platform_, self);
|
||||
return try self.toDeviceOpts(platform_, .{});
|
||||
}
|
||||
|
||||
/// Copies this HostBuffer to the given accelerator (with options).
|
||||
pub fn toDeviceOpts(self: HostBuffer, platform_: Platform, opts: Buffer.FromOptions) !Buffer {
|
||||
return try Buffer.from(platform_, self, opts);
|
||||
}
|
||||
|
||||
/// Interpret the underlying data as a contiguous slice.
|
||||
|
||||
@ -59,13 +59,9 @@ pub const Client = opaque {
|
||||
}
|
||||
|
||||
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!*Buffer {
|
||||
const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args });
|
||||
if (event_) |event__| {
|
||||
const event: *Event = @ptrCast(event__);
|
||||
try event.await_(api);
|
||||
}
|
||||
return @ptrCast(buffer);
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
|
||||
return .{ @ptrCast(buffer), @ptrCast(event_) };
|
||||
}
|
||||
|
||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||
|
||||
@ -2674,7 +2674,7 @@ pub const Tensor = struct {
|
||||
// Test with actual values, no batching.
|
||||
{
|
||||
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
||||
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b });
|
||||
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }), .{})).withTags(.{ .a, .b });
|
||||
defer a.deinit();
|
||||
a_host.deinit(std.testing.allocator);
|
||||
|
||||
@ -2693,7 +2693,7 @@ pub const Tensor = struct {
|
||||
// Test with setting individual values (no batching)
|
||||
{
|
||||
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
||||
const a = try zml.Buffer.from(platform, a_host);
|
||||
const a = try zml.Buffer.from(platform, a_host, .{});
|
||||
defer a.deinit();
|
||||
a_host.deinit(std.testing.allocator);
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user