From e6286b6097919497f881a012e636da1fc3b855c7 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 25 Dec 2024 17:14:44 +0000 Subject: [PATCH] Update Buffer.from to be blocking by default and add options for async loading and memory placement, adjusting aio, hostbuffer, pjrtx, and tensor implementations. --- zml/aio.zig | 9 ++-- zml/buffer.zig | 103 ++++++++++++++++++++++++++++++++++----------- zml/hostbuffer.zig | 7 ++- zml/pjrtx.zig | 10 ++--- zml/tensor.zig | 4 +- 5 files changed, 96 insertions(+), 37 deletions(-) diff --git a/zml/aio.zig b/zml/aio.zig index 8ffa372..2c370a1 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -631,8 +631,11 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo /// deinit all buffers in the given struct pub fn awaitAll(buffers: anytype) !void { - // TODO: implement once we have async buffers. - _ = buffers; + zml.meta.visit((struct { + fn cb(_: void, buffer: *zml.Buffer) void { + buffer.* = buffer.awaitt() catch unreachable; + } + }).cb, {}, buffers); } fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void { @@ -653,7 +656,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape }); stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); buf_with_metadata._shape = obj._shape; - obj.* = try zml.Buffer.from(platform, buf_with_metadata); + obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{}); } else { log.err("Buffer not found: {s}", .{prefix}); diff --git a/zml/buffer.zig b/zml/buffer.zig index 4828844..9d79c0b 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -51,8 +51,13 @@ pub const Buffer = struct { pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES; pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS); + pub const FromOptions = struct { + wait: bool = true, + memory: ?pjrt.Memory.Kind = null, + }; + /// Copies the content of the given buffer from host memory to the accelerator memory. - pub fn from(platform: Platform, host_buffer: HostBuffer) !Buffer { + pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer { var res: Buffer = .{ ._api = platform.pjrt_api, ._shape = host_buffer.shape(), @@ -73,7 +78,6 @@ pub const Buffer = struct { const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); const byte_strides = host_buffer.strides(); - var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{}; const devices = platform.getDevices(); for (0..n_partitions) |i| { // If no sharding if found, the given buffer is replicated on all devices. @@ -82,29 +86,50 @@ pub const Buffer = struct { break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); } else host_buffer; - const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{ - platform.pjrt_client, - platform.pjrt_api, - pjrt.Client.BufferFromHostBufferArgs{ - .data = buf._data, - .buffer_type = buffer_type, - .dims = buf.shape().dims(), - .byte_strides = byte_strides, - .device = devices[i], - .host_buffer_semantics = .ImmutableOnlyDuringCall, - }, - }); + var args = pjrt.Client.BufferFromHostBufferArgs{ + .data = buf._data, + .buffer_type = buffer_type, + .dims = buf.shape().dims(), + .byte_strides = byte_strides, + .host_buffer_semantics = .ImmutableUntilTransferCompletes, + }; + if (opts.memory) |memory_kind| { + const memories = try devices[i].addressableMemories(platform.pjrt_api); + const memory = for (memories) |m| { + const kind = m.kind(platform.pjrt_api); + if (kind == memory_kind) break m; + } else return error.NotFound; + args.memory = memory; + } else { + args.device = devices[i]; + } - frames.appendAssumeCapacity(frame); - } + const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args); + + if (event) |ev| { + ev.deinit(platform.pjrt_api); + } - for (frames.slice()) |*frame| { - const pjrt_buffer = try frame.awaitt(); res._shards.appendAssumeCapacity(pjrt_buffer); } + + if (opts.wait) { + res = try res.awaitt(); + } + return res; } + pub fn awaitt(self: Buffer) !Buffer { + for (self._shards.constSlice()) |buffer| { + if (buffer.getReadyEvent(self._api)) |ev| { + try ev.await_(self._api); + } + } + + return self; + } + /// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`. pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer { stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); @@ -122,20 +147,40 @@ pub const Buffer = struct { /// return a Buffer with the given dimensions. pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer { const sh = Shape.init(dimz, DataType.fromSliceElementType(s)); - return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s))); + return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), .{}); } /// Copies the given Zig slice to the accelerator memory and /// return a Buffer with the given dimensions. pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer { - return from(platform, HostBuffer.fromBytes(sh, data)); + return from(platform, HostBuffer.fromBytes(sh, data), .{}); } /// Copies the given Zig array to the accelerator memory and /// return a Buffer using the array shape. pub fn fromArray(platform: Platform, arr: anytype) !Buffer { const host_buffer = HostBuffer.fromArray(&arr); - return try from(platform, host_buffer); + return try from(platform, host_buffer, .{ .wait = true }); + } + + /// Copies the given Zig slice to the accelerator memory and + /// return a Buffer with the given dimensions. + pub fn fromSliceOpts(platform: Platform, dimz: anytype, s: anytype, opts: FromOptions) !Buffer { + const sh = Shape.init(dimz, DataType.fromSliceElementType(s)); + return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)), opts); + } + + /// Copies the given Zig slice to the accelerator memory and + /// return a Buffer with the given dimensions. + pub fn fromBytesOpts(platform: Platform, sh: Shape, data: []const u8, opts: FromOptions) !Buffer { + return from(platform, HostBuffer.fromBytes(sh, data), opts); + } + + /// Copies the given Zig array to the accelerator memory and + /// return a Buffer using the array shape. + pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer { + const host_buffer = HostBuffer.fromArray(&arr); + return try from(platform, host_buffer, opts); } pub fn asPinnedHostBuffer(self: Buffer) HostBuffer { @@ -150,7 +195,7 @@ pub const Buffer = struct { pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer { const x = dtype_.constant(val); const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice()); - return try from(platform, host_buffer); + return try from(platform, host_buffer, .{ .wait = true }); } /// Creates a Buffer with a single element repeated manytime. @@ -176,7 +221,7 @@ pub const Buffer = struct { ._strides = @splat(0), ._data = x.constSlice().ptr, }; - return try from(platform, host_buffer); + return try from(platform, host_buffer, .{ .wait = true }); } // To speed up copies, duplicate the scalar value into a vector, @@ -199,7 +244,7 @@ pub const Buffer = struct { else => unreachable, } const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes }; - return try from(platform, host_buffer); + return try from(platform, host_buffer, .{ .wait = true }); } test constant { @@ -352,6 +397,16 @@ pub const Buffer = struct { if (self._shards.len == 1) return false; return @reduce(.Or, self._shape._sharding_info); } + + pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer { + var new_shards: Buffer.Shards = .{}; + for (self._shards.slice()) |shard| { + const new_shard = try shard.copyToMemory(self._api, memory); + new_shards.appendAssumeCapacity(new_shard); + } + + return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api }; + } }; pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 5e8697c..6e038a3 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -158,7 +158,12 @@ pub const HostBuffer = struct { /// Copies this HostBuffer to the given accelerator. pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer { - return try Buffer.from(platform_, self); + return try self.toDeviceOpts(platform_, .{}); + } + + /// Copies this HostBuffer to the given accelerator (with options). + pub fn toDeviceOpts(self: HostBuffer, platform_: Platform, opts: Buffer.FromOptions) !Buffer { + return try Buffer.from(platform_, self, opts); } /// Interpret the underlying data as a contiguous slice. diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index dd1a563..13ef6b6 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -59,13 +59,9 @@ pub const Client = opaque { } pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs; - pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!*Buffer { - const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args }); - if (event_) |event__| { - const event: *Event = @ptrCast(event__); - try event.await_(api); - } - return @ptrCast(buffer); + pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { + const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args); + return .{ @ptrCast(buffer), @ptrCast(event_) }; } pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { diff --git a/zml/tensor.zig b/zml/tensor.zig index 51517ef..02d7386 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -2674,7 +2674,7 @@ pub const Tensor = struct { // Test with actual values, no batching. { const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32); - const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b }); + const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }), .{})).withTags(.{ .a, .b }); defer a.deinit(); a_host.deinit(std.testing.allocator); @@ -2693,7 +2693,7 @@ pub const Tensor = struct { // Test with setting individual values (no batching) { const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32); - const a = try zml.Buffer.from(platform, a_host); + const a = try zml.Buffer.from(platform, a_host, .{}); defer a.deinit(); a_host.deinit(std.testing.allocator);