Add fallback for runtimes lacking PJRT_Event by using thread-pool dispatch for buffer copies and treating operations as synchronous when events are absent.
This commit is contained in:
parent
672df8fa2f
commit
57130577e9
@ -284,7 +284,7 @@ pub const Client = opaque {
|
||||
host_buffer_semantics: HostBufferSemantics,
|
||||
};
|
||||
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, *Event } {
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
|
||||
.client = self.inner(),
|
||||
.data = @ptrCast(@constCast(args.data.ptr)),
|
||||
@ -300,9 +300,10 @@ pub const Client = opaque {
|
||||
.done_with_host_buffer = null,
|
||||
.buffer = null,
|
||||
});
|
||||
|
||||
return .{
|
||||
@ptrCast(ret.buffer.?),
|
||||
@ptrCast(ret.done_with_host_buffer.?),
|
||||
@ptrCast(ret.done_with_host_buffer),
|
||||
};
|
||||
}
|
||||
|
||||
@ -499,7 +500,7 @@ pub const LoadedExecutable = opaque {
|
||||
num_args: usize,
|
||||
arguments: []const [*]const *const Buffer,
|
||||
results: []const [*]*Buffer,
|
||||
events: []*Event,
|
||||
events: []?*Event,
|
||||
non_donatable_input_indices: []const i64 = &.{},
|
||||
}) ApiError!void {
|
||||
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
||||
@ -648,13 +649,13 @@ pub const Buffer = opaque {
|
||||
return ret.is_on_cpu;
|
||||
}
|
||||
|
||||
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!*Event {
|
||||
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
|
||||
const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{
|
||||
.src = self.inner(),
|
||||
.dst = @ptrCast(dst.ptr),
|
||||
.dst_size = dst.len,
|
||||
});
|
||||
return @ptrCast(ret.event.?);
|
||||
return @ptrCast(ret.event);
|
||||
}
|
||||
|
||||
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
|
||||
|
||||
@ -2,6 +2,7 @@ const std = @import("std");
|
||||
const testing = std.testing;
|
||||
|
||||
const pjrt = @import("pjrt");
|
||||
const asynk = @import("async");
|
||||
|
||||
const meta = @import("meta.zig");
|
||||
const Context = @import("context.zig").Context;
|
||||
@ -53,7 +54,17 @@ pub const Buffer = struct {
|
||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
||||
|
||||
var events: std.BoundedArray(*pjrt.Event, MAX_NUM_SHARDS) = .{};
|
||||
const xbufferFromHostBuffer = struct {
|
||||
fn do(self: *const pjrt.Client, api: *const pjrt.Api, args: pjrt.Client.BufferFromHostBufferArgs) pjrt.ApiError!*pjrt.Buffer {
|
||||
const buffer, const ev = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self, api, args });
|
||||
if (ev) |e| {
|
||||
e.deinit(api);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
}.do;
|
||||
|
||||
var frames: std.BoundedArray(asynk.Frame(xbufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
||||
const devices = platform.getDevices();
|
||||
for (0..n_partitions) |i| {
|
||||
// If no sharding is found, the given buffer is replicated on all devices.
|
||||
@ -62,21 +73,25 @@ pub const Buffer = struct {
|
||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||
} else host_buffer;
|
||||
|
||||
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
|
||||
.data = buf.data,
|
||||
.buffer_type = buffer_type,
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = byte_strides,
|
||||
.device = devices[i],
|
||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||
const frame = try asynk.asyncc(xbufferFromHostBuffer, .{
|
||||
platform.pjrt_client,
|
||||
platform.pjrt_api,
|
||||
.{
|
||||
.data = buf.data,
|
||||
.buffer_type = buffer_type,
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = byte_strides,
|
||||
.device = devices[i],
|
||||
.host_buffer_semantics = .ImmutableOnlyDuringCall,
|
||||
},
|
||||
});
|
||||
|
||||
events.appendAssumeCapacity(event);
|
||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||
frames.appendAssumeCapacity(frame);
|
||||
}
|
||||
|
||||
for (events.constSlice()) |event| {
|
||||
try platform.awaitEvent(event);
|
||||
for (frames.slice()) |*frame| {
|
||||
const pjrt_buffer = try frame.await_();
|
||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -180,8 +195,10 @@ pub const Buffer = struct {
|
||||
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
||||
var res: T = undefined;
|
||||
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||
const event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
||||
try event.await_(self._api);
|
||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
||||
if (maybe_event) |event| {
|
||||
try event.await_(self._api);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -190,8 +207,10 @@ pub const Buffer = struct {
|
||||
/// The returned `HostBuffer` doesn't own the memory.
|
||||
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
|
||||
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||
const event = try self._shards.get(0).toHostBuffer(self._api, output);
|
||||
try event.await_(self._api);
|
||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output);
|
||||
if (maybe_event) |event| {
|
||||
try event.await_(self._api);
|
||||
}
|
||||
return HostBuffer.fromBytes(self.shape(), output);
|
||||
}
|
||||
|
||||
@ -200,8 +219,10 @@ pub const Buffer = struct {
|
||||
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
||||
const output = try HostBuffer.empty(allocator, self.shape());
|
||||
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||
const event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
||||
try event.await_(self._api);
|
||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
||||
if (maybe_event) |event| {
|
||||
try event.await_(self._api);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
@ -859,7 +859,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
||||
|
||||
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
||||
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
|
||||
var events: [Platform.MAX_NUM_DEVICES]*pjrt.Event = undefined;
|
||||
var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
|
||||
const sharding = self.platform().sharding();
|
||||
|
||||
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
|
||||
@ -873,7 +873,9 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
||||
}) catch unreachable;
|
||||
|
||||
for (events[0..sharding.num_partitions]) |e| {
|
||||
e.await_(self.inner.platform.pjrt_api) catch unreachable;
|
||||
if (e) |ev| {
|
||||
ev.await_(self.inner.platform.pjrt_api) catch unreachable;
|
||||
}
|
||||
}
|
||||
|
||||
var result: Bufferized(Signature.ReturnT) = undefined;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user