Update Buffer.from to be blocking by default and add options for async loading and memory placement, adjusting aio, hostbuffer, pjrtx, and tensor implementations.
This commit is contained in:
parent
da1fd2d9dc
commit
e6286b6097
@ -631,8 +631,11 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
|
|||||||
|
|
||||||
/// Waits until every buffer in the given struct is ready.
|
/// Waits until every buffer in the given struct is ready.
|
||||||
pub fn awaitAll(buffers: anytype) !void {
|
pub fn awaitAll(buffers: anytype) !void {
|
||||||
// TODO: implement once we have async buffers.
|
zml.meta.visit((struct {
|
||||||
_ = buffers;
|
fn cb(_: void, buffer: *zml.Buffer) void {
|
||||||
|
buffer.* = buffer.awaitt() catch unreachable;
|
||||||
|
}
|
||||||
|
}).cb, {}, buffers);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
||||||
@ -653,7 +656,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
|
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
|
||||||
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
||||||
buf_with_metadata._shape = obj._shape;
|
buf_with_metadata._shape = obj._shape;
|
||||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
|
||||||
} else {
|
} else {
|
||||||
log.err("Buffer not found: {s}", .{prefix});
|
log.err("Buffer not found: {s}", .{prefix});
|
||||||
|
|
||||||
|
|||||||
103
zml/buffer.zig
103
zml/buffer.zig
@ -51,8 +51,13 @@ pub const Buffer = struct {
|
|||||||
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
||||||
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
||||||
|
|
||||||
|
pub const FromOptions = struct {
|
||||||
|
wait: bool = true,
|
||||||
|
memory: ?pjrt.Memory.Kind = null,
|
||||||
|
};
|
||||||
|
|
||||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||||
pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer {
|
pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
|
||||||
var res: Buffer = .{
|
var res: Buffer = .{
|
||||||
._api = platform.pjrt_api,
|
._api = platform.pjrt_api,
|
||||||
._shape = host_buffer.shape(),
|
._shape = host_buffer.shape(),
|
||||||
@ -73,7 +78,6 @@ pub const Buffer = struct {
|
|||||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||||
const byte_strides = host_buffer.strides();
|
const byte_strides = host_buffer.strides();
|
||||||
|
|
||||||
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
|
||||||
const devices = platform.getDevices();
|
const devices = platform.getDevices();
|
||||||
for (0..n_partitions) |i| {
|
for (0..n_partitions) |i| {
|
||||||
// If no sharding is found, the given buffer is replicated on all devices.
|
// If no sharding is found, the given buffer is replicated on all devices.
|
||||||
@ -82,29 +86,50 @@ pub const Buffer = struct {
|
|||||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||||
} else host_buffer;
|
} else host_buffer;
|
||||||
|
|
||||||
const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{
|
var args = pjrt.Client.BufferFromHostBufferArgs{
|
||||||
platform.pjrt_client,
|
.data = buf._data,
|
||||||
platform.pjrt_api,
|
.buffer_type = buffer_type,
|
||||||
pjrt.Client.BufferFromHostBufferArgs{
|
.dims = buf.shape().dims(),
|
||||||
.data = buf._data,
|
.byte_strides = byte_strides,
|
||||||
.buffer_type = buffer_type,
|
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||||
.dims = buf.shape().dims(),
|
};
|
||||||
.byte_strides = byte_strides,
|
if (opts.memory) |memory_kind| {
|
||||||
.device = devices[i],
|
const memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||||
.host_buffer_semantics = .ImmutableOnlyDuringCall,
|
const memory = for (memories) |m| {
|
||||||
},
|
const kind = m.kind(platform.pjrt_api);
|
||||||
});
|
if (kind == memory_kind) break m;
|
||||||
|
} else return error.NotFound;
|
||||||
|
args.memory = memory;
|
||||||
|
} else {
|
||||||
|
args.device = devices[i];
|
||||||
|
}
|
||||||
|
|
||||||
frames.appendAssumeCapacity(frame);
|
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
||||||
}
|
|
||||||
|
if (event) |ev| {
|
||||||
|
ev.deinit(platform.pjrt_api);
|
||||||
|
}
|
||||||
|
|
||||||
for (frames.slice()) |*frame| {
|
|
||||||
const pjrt_buffer = try frame.awaitt();
|
|
||||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (opts.wait) {
|
||||||
|
res = try res.awaitt();
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn awaitt(self: Buffer) !Buffer {
|
||||||
|
for (self._shards.constSlice()) |buffer| {
|
||||||
|
if (buffer.getReadyEvent(self._api)) |ev| {
|
||||||
|
try ev.await_(self._api);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
|
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||||
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||||
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||||
@ -122,20 +147,40 @@ pub const Buffer = struct {
|
|||||||
/// return a Buffer with the given dimensions.
|
/// return a Buffer with the given dimensions.
|
||||||
pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
|
pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
|
||||||
const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
|
const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
|
||||||
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
|
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), .{});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copies the given Zig slice to the accelerator memory and
|
/// Copies the given Zig slice to the accelerator memory and
|
||||||
/// return a Buffer with the given dimensions.
|
/// return a Buffer with the given dimensions.
|
||||||
pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer {
|
pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer {
|
||||||
return from(platform, HostBuffer.fromBytes(sh, data));
|
return from(platform, HostBuffer.fromBytes(sh, data), .{});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copies the given Zig array to the accelerator memory and
|
/// Copies the given Zig array to the accelerator memory and
|
||||||
/// return a Buffer using the array shape.
|
/// return a Buffer using the array shape.
|
||||||
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
|
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
|
||||||
const host_buffer = HostBuffer.fromArray(&arr);
|
const host_buffer = HostBuffer.fromArray(&arr);
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer, .{ .wait = true });
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copies the given Zig slice to the accelerator memory and
|
||||||
|
/// return a Buffer with the given dimensions.
|
||||||
|
pub fn fromSliceOpts(platform: Platform, dimz: anytype, s: anytype, opts: FromOptions) !Buffer {
|
||||||
|
const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
|
||||||
|
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), opts);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copies the given Zig slice to the accelerator memory and
|
||||||
|
/// return a Buffer with the given dimensions.
|
||||||
|
pub fn fromBytesOpts(platform: Platform, sh: Shape, data: []const u8, opts: FromOptions) !Buffer {
|
||||||
|
return from(platform, HostBuffer.fromBytes(sh, data), opts);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copies the given Zig array to the accelerator memory and
|
||||||
|
/// return a Buffer using the array shape.
|
||||||
|
pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer {
|
||||||
|
const host_buffer = HostBuffer.fromArray(&arr);
|
||||||
|
return try from(platform, host_buffer, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
|
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
|
||||||
@ -150,7 +195,7 @@ pub const Buffer = struct {
|
|||||||
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
||||||
const x = dtype_.constant(val);
|
const x = dtype_.constant(val);
|
||||||
const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
|
const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer, .{ .wait = true });
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a Buffer with a single element repeated many times.
|
/// Creates a Buffer with a single element repeated many times.
|
||||||
@ -176,7 +221,7 @@ pub const Buffer = struct {
|
|||||||
._strides = @splat(0),
|
._strides = @splat(0),
|
||||||
._data = x.constSlice().ptr,
|
._data = x.constSlice().ptr,
|
||||||
};
|
};
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer, .{ .wait = true });
|
||||||
}
|
}
|
||||||
|
|
||||||
// To speed up copies, duplicate the scalar value into a vector,
|
// To speed up copies, duplicate the scalar value into a vector,
|
||||||
@ -199,7 +244,7 @@ pub const Buffer = struct {
|
|||||||
else => unreachable,
|
else => unreachable,
|
||||||
}
|
}
|
||||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer, .{ .wait = true });
|
||||||
}
|
}
|
||||||
|
|
||||||
test constant {
|
test constant {
|
||||||
@ -352,6 +397,16 @@ pub const Buffer = struct {
|
|||||||
if (self._shards.len == 1) return false;
|
if (self._shards.len == 1) return false;
|
||||||
return @reduce(.Or, self._shape._sharding_info);
|
return @reduce(.Or, self._shape._sharding_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer {
|
||||||
|
var new_shards: Buffer.Shards = .{};
|
||||||
|
for (self._shards.slice()) |shard| {
|
||||||
|
const new_shard = try shard.copyToMemory(self._api, memory);
|
||||||
|
new_shards.appendAssumeCapacity(new_shard);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||||
|
|||||||
@ -158,7 +158,12 @@ pub const HostBuffer = struct {
|
|||||||
|
|
||||||
/// Copies this HostBuffer to the given accelerator.
|
/// Copies this HostBuffer to the given accelerator.
|
||||||
pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
|
pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
|
||||||
return try Buffer.from(platform_, self);
|
return try self.toDeviceOpts(platform_, .{});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copies this HostBuffer to the given accelerator (with options).
|
||||||
|
pub fn toDeviceOpts(self: HostBuffer, platform_: Platform, opts: Buffer.FromOptions) !Buffer {
|
||||||
|
return try Buffer.from(platform_, self, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Interpret the underlying data as a contiguous slice.
|
/// Interpret the underlying data as a contiguous slice.
|
||||||
|
|||||||
@ -59,13 +59,9 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
||||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!*Buffer {
|
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||||
const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args });
|
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
|
||||||
if (event_) |event__| {
|
return .{ @ptrCast(buffer), @ptrCast(event_) };
|
||||||
const event: *Event = @ptrCast(event__);
|
|
||||||
try event.await_(api);
|
|
||||||
}
|
|
||||||
return @ptrCast(buffer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||||
|
|||||||
@ -2674,7 +2674,7 @@ pub const Tensor = struct {
|
|||||||
// Test with actual values, no batching.
|
// Test with actual values, no batching.
|
||||||
{
|
{
|
||||||
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
||||||
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b });
|
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }), .{})).withTags(.{ .a, .b });
|
||||||
defer a.deinit();
|
defer a.deinit();
|
||||||
a_host.deinit(std.testing.allocator);
|
a_host.deinit(std.testing.allocator);
|
||||||
|
|
||||||
@ -2693,7 +2693,7 @@ pub const Tensor = struct {
|
|||||||
// Test with setting individual values (no batching)
|
// Test with setting individual values (no batching)
|
||||||
{
|
{
|
||||||
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
||||||
const a = try zml.Buffer.from(platform, a_host);
|
const a = try zml.Buffer.from(platform, a_host, .{});
|
||||||
defer a.deinit();
|
defer a.deinit();
|
||||||
a_host.deinit(std.testing.allocator);
|
a_host.deinit(std.testing.allocator);
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user