From 57130577e9d570b301d252275e00a41c70e90e4a Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 9 May 2023 12:44:56 +0000 Subject: [PATCH] =?UTF-8?q?Add=20fallback=20for=20runtimes=20lacking=20PJR?= =?UTF-8?q?T=5FEvent=20by=20using=20thread=E2=80=91pool=20dispatch=20for?= =?UTF-8?q?=20buffer=20copies=20and=20treating=20operations=20as=20synchro?= =?UTF-8?q?nous=20when=20events=20are=20absent.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pjrt/pjrt.zig | 11 +++++----- zml/buffer.zig | 57 ++++++++++++++++++++++++++++++++++---------------- zml/module.zig | 6 ++++-- 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 5a4aaec..30562f3 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -284,7 +284,7 @@ pub const Client = opaque { host_buffer_semantics: HostBufferSemantics, }; - pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, *Event } { + pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{ .client = self.inner(), .data = @ptrCast(@constCast(args.data.ptr)), @@ -300,9 +300,10 @@ pub const Client = opaque { .done_with_host_buffer = null, .buffer = null, }); + return .{ @ptrCast(ret.buffer.?), - @ptrCast(ret.done_with_host_buffer.?), + @ptrCast(ret.done_with_host_buffer), }; } @@ -499,7 +500,7 @@ pub const LoadedExecutable = opaque { num_args: usize, arguments: []const [*]const *const Buffer, results: []const [*]*Buffer, - events: []*Event, + events: []?*Event, non_donatable_input_indices: []const i64 = &.{}, }) ApiError!void { var options = pjrtStruct(c.PJRT_ExecuteOptions{ @@ -648,13 +649,13 @@ pub const Buffer = opaque { return ret.is_on_cpu; } - pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!*Event { + pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event { const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{ .src = self.inner(), .dst = @ptrCast(dst.ptr), .dst_size = dst.len, }); - return @ptrCast(ret.event.?); + return @ptrCast(ret.event); } pub fn getElementType(self: *const Buffer, api: *const Api) BufferType { diff --git a/zml/buffer.zig b/zml/buffer.zig index d093d65..1fbdc30 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -2,6 +2,7 @@ const std = @import("std"); const testing = std.testing; const pjrt = @import("pjrt"); +const asynk = @import("async"); const meta = @import("meta.zig"); const Context = @import("context.zig").Context; @@ -53,7 +54,17 @@ pub const Buffer = struct { const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); - var events: std.BoundedArray(*pjrt.Event, MAX_NUM_SHARDS) = .{}; + const xbufferFromHostBuffer = struct { + fn do(self: *const pjrt.Client, api: *const pjrt.Api, args: pjrt.Client.BufferFromHostBufferArgs) pjrt.ApiError!*pjrt.Buffer { + const buffer, const ev = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self, api, args }); + if (ev) |e| { + e.deinit(api); + } + return buffer; + } + }.do; + + var frames: std.BoundedArray(asynk.Frame(xbufferFromHostBuffer), MAX_NUM_SHARDS) = .{}; const devices = platform.getDevices(); for (0..n_partitions) |i| { // If no sharding if found, the given buffer is replicated on all devices. @@ -62,21 +73,25 @@ pub const Buffer = struct { break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); } else host_buffer; - const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ - .data = buf.data, - .buffer_type = buffer_type, - .dims = buf.shape().dims(), - .byte_strides = byte_strides, - .device = devices[i], - .host_buffer_semantics = .ImmutableUntilTransferCompletes, + const frame = try asynk.asyncc(xbufferFromHostBuffer, .{ + platform.pjrt_client, + platform.pjrt_api, + .{ + .data = buf.data, + .buffer_type = buffer_type, + .dims = buf.shape().dims(), + .byte_strides = byte_strides, + .device = devices[i], + .host_buffer_semantics = .ImmutableOnlyDuringCall, + }, }); - events.appendAssumeCapacity(event); - res._shards.appendAssumeCapacity(pjrt_buffer); + frames.appendAssumeCapacity(frame); } - for (events.constSlice()) |event| { - try platform.awaitEvent(event); + for (frames.slice()) |*frame| { + const pjrt_buffer = try frame.await_(); + res._shards.appendAssumeCapacity(pjrt_buffer); } return res; } @@ -180,8 +195,10 @@ pub const Buffer = struct { meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); var res: T = undefined; meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); - const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); - try event.await_(self._api); + const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); + if (maybe_event) |event| { + try event.await_(self._api); + } return res; } @@ -190,8 +207,10 @@ pub const Buffer = struct { /// The returned `HostBuffer` doesn't own the memory. pub fn toHost(self: Buffer, output: []u8) !HostBuffer { meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); - const event = try self._shards.get(0).toHostBuffer(self._api, output); - try event.await_(self._api); + const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output); + if (maybe_event) |event| { + try event.await_(self._api); + } return HostBuffer.fromBytes(self.shape(), output); } @@ -200,8 +219,10 @@ pub const Buffer = struct { pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { const output = try HostBuffer.empty(allocator, self.shape()); meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); - const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); - try event.await_(self._api); + const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); + if (maybe_event) |event| { + try event.await_(self._api); + } return output; } diff --git a/zml/module.zig b/zml/module.zig index fbc15ce..11cd146 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -859,7 +859,7 @@ pub fn ExeWithWeights(comptime func: anytype) type { pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) { fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count); - var events: [Platform.MAX_NUM_DEVICES]*pjrt.Event = undefined; + var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES; const sharding = self.platform().sharding(); self.inner.exe.execute(self.inner.platform.pjrt_api, .{ @@ -873,7 +873,9 @@ pub fn ExeWithWeights(comptime func: anytype) type { }) catch unreachable; for (events[0..sharding.num_partitions]) |e| { - e.await_(self.inner.platform.pjrt_api) catch unreachable; + if (e) |ev| { + ev.await_(self.inner.platform.pjrt_api) catch unreachable; + } } var result: Bufferized(Signature.ReturnT) = undefined;