Add fallback for runtimes lacking PJRT_Event by using thread-pool dispatch for buffer copies and treating operations as synchronous when events are absent.

This commit is contained in:
Tarry Singh 2023-05-09 12:44:56 +00:00
parent 672df8fa2f
commit 57130577e9
3 changed files with 49 additions and 25 deletions

View File

@@ -284,7 +284,7 @@ pub const Client = opaque {
host_buffer_semantics: HostBufferSemantics,
};
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, *Event } {
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
.client = self.inner(),
.data = @ptrCast(@constCast(args.data.ptr)),
@@ -300,9 +300,10 @@ pub const Client = opaque {
.done_with_host_buffer = null,
.buffer = null,
});
return .{
@ptrCast(ret.buffer.?),
@ptrCast(ret.done_with_host_buffer.?),
@ptrCast(ret.done_with_host_buffer),
};
}
@@ -499,7 +500,7 @@ pub const LoadedExecutable = opaque {
num_args: usize,
arguments: []const [*]const *const Buffer,
results: []const [*]*Buffer,
events: []*Event,
events: []?*Event,
non_donatable_input_indices: []const i64 = &.{},
}) ApiError!void {
var options = pjrtStruct(c.PJRT_ExecuteOptions{
@@ -648,13 +649,13 @@ pub const Buffer = opaque {
return ret.is_on_cpu;
}
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!*Event {
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{
.src = self.inner(),
.dst = @ptrCast(dst.ptr),
.dst_size = dst.len,
});
return @ptrCast(ret.event.?);
return @ptrCast(ret.event);
}
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {

View File

@@ -2,6 +2,7 @@ const std = @import("std");
const testing = std.testing;
const pjrt = @import("pjrt");
const asynk = @import("async");
const meta = @import("meta.zig");
const Context = @import("context.zig").Context;
@@ -53,7 +54,17 @@ pub const Buffer = struct {
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
var events: std.BoundedArray(*pjrt.Event, MAX_NUM_SHARDS) = .{};
const xbufferFromHostBuffer = struct {
fn do(self: *const pjrt.Client, api: *const pjrt.Api, args: pjrt.Client.BufferFromHostBufferArgs) pjrt.ApiError!*pjrt.Buffer {
const buffer, const ev = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self, api, args });
if (ev) |e| {
e.deinit(api);
}
return buffer;
}
}.do;
var frames: std.BoundedArray(asynk.Frame(xbufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
const devices = platform.getDevices();
for (0..n_partitions) |i| {
// If no sharding is found, the given buffer is replicated on all devices.
@@ -62,21 +73,25 @@ pub const Buffer = struct {
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
} else host_buffer;
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
.data = buf.data,
.buffer_type = buffer_type,
.dims = buf.shape().dims(),
.byte_strides = byte_strides,
.device = devices[i],
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
const frame = try asynk.asyncc(xbufferFromHostBuffer, .{
platform.pjrt_client,
platform.pjrt_api,
.{
.data = buf.data,
.buffer_type = buffer_type,
.dims = buf.shape().dims(),
.byte_strides = byte_strides,
.device = devices[i],
.host_buffer_semantics = .ImmutableOnlyDuringCall,
},
});
events.appendAssumeCapacity(event);
res._shards.appendAssumeCapacity(pjrt_buffer);
frames.appendAssumeCapacity(frame);
}
for (events.constSlice()) |event| {
try platform.awaitEvent(event);
for (frames.slice()) |*frame| {
const pjrt_buffer = try frame.await_();
res._shards.appendAssumeCapacity(pjrt_buffer);
}
return res;
}
@@ -180,8 +195,10 @@ pub const Buffer = struct {
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
var res: T = undefined;
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
try event.await_(self._api);
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
if (maybe_event) |event| {
try event.await_(self._api);
}
return res;
}
@@ -190,8 +207,10 @@ pub const Buffer = struct {
/// The returned `HostBuffer` doesn't own the memory.
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, output);
try event.await_(self._api);
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output);
if (maybe_event) |event| {
try event.await_(self._api);
}
return HostBuffer.fromBytes(self.shape(), output);
}
@@ -200,8 +219,10 @@ pub const Buffer = struct {
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
const output = try HostBuffer.empty(allocator, self.shape());
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
try event.await_(self._api);
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
if (maybe_event) |event| {
try event.await_(self._api);
}
return output;
}

View File

@@ -859,7 +859,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
var events: [Platform.MAX_NUM_DEVICES]*pjrt.Event = undefined;
var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
const sharding = self.platform().sharding();
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
@@ -873,7 +873,9 @@ pub fn ExeWithWeights(comptime func: anytype) type {
}) catch unreachable;
for (events[0..sharding.num_partitions]) |e| {
e.await_(self.inner.platform.pjrt_api) catch unreachable;
if (e) |ev| {
ev.await_(self.inner.platform.pjrt_api) catch unreachable;
}
}
var result: Bufferized(Signature.ReturnT) = undefined;