Add fallback for runtimes lacking PJRT_Event by using thread-pool dispatch for buffer copies and treating operations as synchronous when events are absent.

This commit is contained in:
Tarry Singh 2023-05-09 12:44:56 +00:00
parent 672df8fa2f
commit 57130577e9
3 changed files with 49 additions and 25 deletions

View File

@ -284,7 +284,7 @@ pub const Client = opaque {
host_buffer_semantics: HostBufferSemantics, host_buffer_semantics: HostBufferSemantics,
}; };
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, *Event } { pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{ const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
.client = self.inner(), .client = self.inner(),
.data = @ptrCast(@constCast(args.data.ptr)), .data = @ptrCast(@constCast(args.data.ptr)),
@ -300,9 +300,10 @@ pub const Client = opaque {
.done_with_host_buffer = null, .done_with_host_buffer = null,
.buffer = null, .buffer = null,
}); });
return .{ return .{
@ptrCast(ret.buffer.?), @ptrCast(ret.buffer.?),
@ptrCast(ret.done_with_host_buffer.?), @ptrCast(ret.done_with_host_buffer),
}; };
} }
@ -499,7 +500,7 @@ pub const LoadedExecutable = opaque {
num_args: usize, num_args: usize,
arguments: []const [*]const *const Buffer, arguments: []const [*]const *const Buffer,
results: []const [*]*Buffer, results: []const [*]*Buffer,
events: []*Event, events: []?*Event,
non_donatable_input_indices: []const i64 = &.{}, non_donatable_input_indices: []const i64 = &.{},
}) ApiError!void { }) ApiError!void {
var options = pjrtStruct(c.PJRT_ExecuteOptions{ var options = pjrtStruct(c.PJRT_ExecuteOptions{
@ -648,13 +649,13 @@ pub const Buffer = opaque {
return ret.is_on_cpu; return ret.is_on_cpu;
} }
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!*Event { pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{ const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{
.src = self.inner(), .src = self.inner(),
.dst = @ptrCast(dst.ptr), .dst = @ptrCast(dst.ptr),
.dst_size = dst.len, .dst_size = dst.len,
}); });
return @ptrCast(ret.event.?); return @ptrCast(ret.event);
} }
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType { pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {

View File

@ -2,6 +2,7 @@ const std = @import("std");
const testing = std.testing; const testing = std.testing;
const pjrt = @import("pjrt"); const pjrt = @import("pjrt");
const asynk = @import("async");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
@ -53,7 +54,17 @@ pub const Buffer = struct {
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
var events: std.BoundedArray(*pjrt.Event, MAX_NUM_SHARDS) = .{}; const xbufferFromHostBuffer = struct {
fn do(self: *const pjrt.Client, api: *const pjrt.Api, args: pjrt.Client.BufferFromHostBufferArgs) pjrt.ApiError!*pjrt.Buffer {
const buffer, const ev = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self, api, args });
if (ev) |e| {
e.deinit(api);
}
return buffer;
}
}.do;
var frames: std.BoundedArray(asynk.Frame(xbufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
const devices = platform.getDevices(); const devices = platform.getDevices();
for (0..n_partitions) |i| { for (0..n_partitions) |i| {
// If no sharding is found, the given buffer is replicated on all devices. // If no sharding is found, the given buffer is replicated on all devices.
@ -62,21 +73,25 @@ pub const Buffer = struct {
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
} else host_buffer; } else host_buffer;
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ const frame = try asynk.asyncc(xbufferFromHostBuffer, .{
platform.pjrt_client,
platform.pjrt_api,
.{
.data = buf.data, .data = buf.data,
.buffer_type = buffer_type, .buffer_type = buffer_type,
.dims = buf.shape().dims(), .dims = buf.shape().dims(),
.byte_strides = byte_strides, .byte_strides = byte_strides,
.device = devices[i], .device = devices[i],
.host_buffer_semantics = .ImmutableUntilTransferCompletes, .host_buffer_semantics = .ImmutableOnlyDuringCall,
},
}); });
events.appendAssumeCapacity(event); frames.appendAssumeCapacity(frame);
res._shards.appendAssumeCapacity(pjrt_buffer);
} }
for (events.constSlice()) |event| { for (frames.slice()) |*frame| {
try platform.awaitEvent(event); const pjrt_buffer = try frame.await_();
res._shards.appendAssumeCapacity(pjrt_buffer);
} }
return res; return res;
} }
@ -180,8 +195,10 @@ pub const Buffer = struct {
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
var res: T = undefined; var res: T = undefined;
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);
}
return res; return res;
} }
@ -190,8 +207,10 @@ pub const Buffer = struct {
/// The returned `HostBuffer` doesn't own the memory. /// The returned `HostBuffer` doesn't own the memory.
pub fn toHost(self: Buffer, output: []u8) !HostBuffer { pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, output); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output);
if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);
}
return HostBuffer.fromBytes(self.shape(), output); return HostBuffer.fromBytes(self.shape(), output);
} }
@ -200,8 +219,10 @@ pub const Buffer = struct {
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
const output = try HostBuffer.empty(allocator, self.shape()); const output = try HostBuffer.empty(allocator, self.shape());
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);
}
return output; return output;
} }

View File

@ -859,7 +859,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) { pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count); fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
var events: [Platform.MAX_NUM_DEVICES]*pjrt.Event = undefined; var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
const sharding = self.platform().sharding(); const sharding = self.platform().sharding();
self.inner.exe.execute(self.inner.platform.pjrt_api, .{ self.inner.exe.execute(self.inner.platform.pjrt_api, .{
@ -873,7 +873,9 @@ pub fn ExeWithWeights(comptime func: anytype) type {
}) catch unreachable; }) catch unreachable;
for (events[0..sharding.num_partitions]) |e| { for (events[0..sharding.num_partitions]) |e| {
e.await_(self.inner.platform.pjrt_api) catch unreachable; if (e) |ev| {
ev.await_(self.inner.platform.pjrt_api) catch unreachable;
}
} }
var result: Bufferized(Signature.ReturnT) = undefined; var result: Bufferized(Signature.ReturnT) = undefined;