zml: reintroduce pjrtx to handle reactor blocking issues in async scenarios, particularly with Events.
This commit is contained in:
parent
c68ec4bc5c
commit
52ef20f981
@ -93,6 +93,27 @@ pub const AsyncThread = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const Notification = struct {
|
||||||
|
inner: std.Thread.ResetEvent,
|
||||||
|
|
||||||
|
pub fn init() !Notification {
|
||||||
|
return .{ .inner = .{} };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn notify(self: *Notification) !void {
|
||||||
|
self.inner.set();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wait(self: *Notification) !void {
|
||||||
|
self.inner.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deinit(self: *Notification) void {
|
||||||
|
self.inner.set();
|
||||||
|
self.* = undefined;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
pub fn StdIn() !File {
|
pub fn StdIn() !File {
|
||||||
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
|
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -370,5 +370,3 @@ pub const Mutex = struct {
|
|||||||
_ = self.inner.recv();
|
_ = self.inner.recv();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const inCoro = libcoro.inCoro;
|
|
||||||
|
|||||||
@ -1187,8 +1187,15 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
|
|||||||
return context.str;
|
return context.str;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stablehloGetMinimumVersion(writer: anytype) void {
|
pub fn getMinimumVersion() []const u8 {
|
||||||
var context = .{ .writer = writer };
|
const state = struct {
|
||||||
|
var buf: [32]u8 = undefined;
|
||||||
|
var str: []const u8 = undefined;
|
||||||
|
var once = std.once(call);
|
||||||
|
|
||||||
|
fn call() void {
|
||||||
|
var stream = std.io.fixedBufferStream(&buf);
|
||||||
|
var context = .{ .writer = stream.writer() };
|
||||||
const WriterContext = @TypeOf(context);
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
c.stablehloGetMinimumVersion((struct {
|
c.stablehloGetMinimumVersion((struct {
|
||||||
@ -1197,13 +1204,20 @@ pub fn stablehloGetMinimumVersion(writer: anytype) void {
|
|||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
}
|
}
|
||||||
}).callback, &context);
|
}).callback, &context);
|
||||||
|
|
||||||
|
str = buf[0..stream.pos];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
state.once.call();
|
||||||
|
return state.str;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn serializePortableArtifact(module_str: []const u8, target_version: []const u8, writer: anytype) !void {
|
pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void {
|
||||||
var context = .{ .writer = writer };
|
var context = .{ .writer = writer };
|
||||||
const WriterContext = @TypeOf(context);
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(module_str), mlir.stringRef(target_version), (struct {
|
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
const pjrt = @import("pjrt");
|
const meta = @import("meta.zig");
|
||||||
|
const pjrt = @import("pjrtx.zig");
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
@ -54,17 +54,7 @@ pub const Buffer = struct {
|
|||||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||||
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
||||||
|
|
||||||
const xbufferFromHostBuffer = struct {
|
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
||||||
fn do(self: *const pjrt.Client, api: *const pjrt.Api, args: pjrt.Client.BufferFromHostBufferArgs) pjrt.ApiError!*pjrt.Buffer {
|
|
||||||
const buffer, const ev = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self, api, args });
|
|
||||||
if (ev) |e| {
|
|
||||||
e.deinit(api);
|
|
||||||
}
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
}.do;
|
|
||||||
|
|
||||||
var frames: std.BoundedArray(asynk.Frame(xbufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
|
||||||
const devices = platform.getDevices();
|
const devices = platform.getDevices();
|
||||||
for (0..n_partitions) |i| {
|
for (0..n_partitions) |i| {
|
||||||
// If no sharding if found, the given buffer is replicated on all devices.
|
// If no sharding if found, the given buffer is replicated on all devices.
|
||||||
@ -73,7 +63,7 @@ pub const Buffer = struct {
|
|||||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||||
} else host_buffer;
|
} else host_buffer;
|
||||||
|
|
||||||
const frame = try asynk.asyncc(xbufferFromHostBuffer, .{
|
const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{
|
||||||
platform.pjrt_client,
|
platform.pjrt_client,
|
||||||
platform.pjrt_api,
|
platform.pjrt_api,
|
||||||
.{
|
.{
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const asynk = @import("async");
|
|
||||||
const mlir = @import("mlir");
|
const mlir = @import("mlir");
|
||||||
const pjrt = @import("pjrt");
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
const runtimes = @import("runtimes");
|
const runtimes = @import("runtimes");
|
||||||
|
|
||||||
const platform = @import("platform.zig");
|
const platform = @import("platform.zig");
|
||||||
|
const pjrt = @import("pjrtx.zig");
|
||||||
|
|
||||||
|
const available_targets = @import("platform.zig").available_targets;
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const pjrt = @import("pjrt");
|
|
||||||
|
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
|
|
||||||
@ -8,6 +7,7 @@ const xla_pb = @import("//xla:xla_proto");
|
|||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
const ops = @import("ops.zig");
|
const ops = @import("ops.zig");
|
||||||
|
const pjrt = @import("pjrtx.zig");
|
||||||
const protobuf = @import("io/protobuf");
|
const protobuf = @import("io/protobuf");
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const aio = @import("aio.zig");
|
const aio = @import("aio.zig");
|
||||||
@ -1170,19 +1170,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
|
|
||||||
const options_bytes = try options.encode(arena);
|
const options_bytes = try options.encode(arena);
|
||||||
|
|
||||||
var mlir_bytecode = std.ArrayList(u8).init(arena);
|
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
|
||||||
defer mlir_bytecode.deinit();
|
|
||||||
// Note: we may need to restore IR downgrade if we need to support old pjrt plugins.
|
|
||||||
module.op().writeBytecode(mlir_bytecode.writer());
|
|
||||||
|
|
||||||
const loaded_executable = try asynk.callBlocking(pjrt.Client.compile, .{
|
|
||||||
platform.pjrt_client, platform.pjrt_api, .{
|
|
||||||
.bytecode = mlir_bytecode.items,
|
|
||||||
.bytecode_format = .mlir,
|
|
||||||
.compile_options_pb = options_bytes,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
errdefer loaded_executable.deinit();
|
errdefer loaded_executable.deinit();
|
||||||
|
|
||||||
if (platform.compilation_options.cache_location) |compilation_cache_location| {
|
if (platform.compilation_options.cache_location) |compilation_cache_location| {
|
||||||
|
|||||||
203
zml/pjrtx.zig
Normal file
203
zml/pjrtx.zig
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
const builtin = @import("builtin");
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
const asynk = @import("async");
|
||||||
|
const mlir = @import("mlir");
|
||||||
|
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const dtype = @import("dtype.zig");
|
||||||
|
const meta = @import("meta.zig");
|
||||||
|
const dialects = @import("mlir/dialects");
|
||||||
|
|
||||||
|
pub const Profiler = pjrt.Profiler;
|
||||||
|
pub const ApiError = pjrt.ApiError;
|
||||||
|
pub const ErrorCode = pjrt.ErrorCode;
|
||||||
|
|
||||||
|
const Target = @import("platform.zig").Target;
|
||||||
|
|
||||||
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
|
pub const Buffer = pjrt.Buffer;
|
||||||
|
pub const BufferType = pjrt.BufferType;
|
||||||
|
pub const Device = pjrt.Device;
|
||||||
|
pub const DeviceDescription = pjrt.DeviceDescription;
|
||||||
|
pub const Api = pjrt.Api;
|
||||||
|
pub const NamedValue = pjrt.NamedValue;
|
||||||
|
pub const ClientInitError = pjrt.ClientInitError;
|
||||||
|
pub const CompileError = std.mem.Allocator.Error || ApiError;
|
||||||
|
pub const Error = pjrt.Error;
|
||||||
|
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
|
||||||
|
pub const SerializeResult = pjrt.SerializeResult;
|
||||||
|
pub const Executable = pjrt.Executable;
|
||||||
|
pub const ExecuteError = ApiError;
|
||||||
|
|
||||||
|
fn InnerMixin(comptime innerT: type) type {
|
||||||
|
return struct {
|
||||||
|
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT {
|
||||||
|
return @ptrCast(self);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const Client = opaque {
|
||||||
|
const inner = InnerMixin(pjrt.Client).inner;
|
||||||
|
|
||||||
|
pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client {
|
||||||
|
return @ptrCast(try pjrt.Client.init(api, create_options));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deinit(self: *Client, api: *const Api) void {
|
||||||
|
self.inner().deinit(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 {
|
||||||
|
return self.inner().getPlatformName(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getDevices(self: *const Client, api: *const Api) []const *const Device {
|
||||||
|
return self.inner().getDevices(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *const Device {
|
||||||
|
return self.inner().getAddressableDevices(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
||||||
|
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer {
|
||||||
|
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
|
||||||
|
if (event_) |event__| {
|
||||||
|
const event: *Event = @ptrCast(event__);
|
||||||
|
try event.await_(api);
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||||
|
return @ptrCast(try asynk.callBlocking(pjrt.Client.deserializeAndLoad, .{ self.inner(), api, bytes }));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs;
|
||||||
|
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
|
||||||
|
var args_ = args;
|
||||||
|
args_.on_delete_callback = args_.on_delete_callback orelse &(struct {
|
||||||
|
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
||||||
|
}.call);
|
||||||
|
const buf = try self.inner().createViewOfDeviceBuffer(api, args_);
|
||||||
|
return @ptrCast(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
||||||
|
var bytecode = std.ArrayList(u8).init(allocator);
|
||||||
|
defer bytecode.deinit();
|
||||||
|
module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch {
|
||||||
|
std.debug.print("failed to write module bytecode\n", .{});
|
||||||
|
unreachable;
|
||||||
|
};
|
||||||
|
|
||||||
|
var serialized_buffer = std.ArrayList(u8).init(allocator);
|
||||||
|
defer serialized_buffer.deinit();
|
||||||
|
dialects.stablehlo.serializePortableArtifact(bytecode.items, dialects.stablehlo.getMinimumVersion(), serialized_buffer.writer()) catch {
|
||||||
|
std.debug.print("failed to serialize to portable artifact\n", .{});
|
||||||
|
unreachable;
|
||||||
|
};
|
||||||
|
|
||||||
|
return @ptrCast(try self.inner().compile(api, .{
|
||||||
|
.bytecode = serialized_buffer.items,
|
||||||
|
.bytecode_format = .mlir,
|
||||||
|
.compile_options_pb = compile_options_pb,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn compile(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
||||||
|
return try asynk.callBlocking(compileSync, .{ self, api, allocator, module, compile_options_pb });
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the Profiler for this API.
|
||||||
|
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
||||||
|
/// Platforms with known profiler extensions: cuda, xpu
|
||||||
|
pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler {
|
||||||
|
return self.inner().getProfiler(api, options);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const Event = opaque {
|
||||||
|
pub const inner = InnerMixin(pjrt.Event).inner;
|
||||||
|
|
||||||
|
pub fn deinit(self: *Event, api: *const Api) void {
|
||||||
|
self.inner().deinit(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isReady(self: *const Event, api: *const Api) bool {
|
||||||
|
return self.inner().isReady(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
|
||||||
|
return self.inner().getEventError(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn await_(self: *Event, api: *const Api) !void {
|
||||||
|
defer self.deinit(api);
|
||||||
|
|
||||||
|
var ctx = struct {
|
||||||
|
err: ?*pjrt.Error = null,
|
||||||
|
notif: asynk.Notification,
|
||||||
|
}{
|
||||||
|
.notif = try asynk.Notification.init(),
|
||||||
|
};
|
||||||
|
defer ctx.notif.deinit();
|
||||||
|
|
||||||
|
try self.inner().onReady(api, &(struct {
|
||||||
|
fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
|
||||||
|
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
|
||||||
|
ctx_.err = err;
|
||||||
|
ctx_.notif.notify() catch @panic("Unable to notify");
|
||||||
|
}
|
||||||
|
}.call), &ctx);
|
||||||
|
|
||||||
|
try ctx.notif.wait();
|
||||||
|
if (ctx.err) |e| {
|
||||||
|
defer e.deinit(api);
|
||||||
|
return e.getCode(api).toApiError();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const LoadedExecutable = opaque {
|
||||||
|
const inner = InnerMixin(pjrt.LoadedExecutable).inner;
|
||||||
|
|
||||||
|
pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
|
||||||
|
self.inner().deinit(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
||||||
|
self.inner().delete(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool {
|
||||||
|
return self.inner().isDeleted(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
|
||||||
|
return self.inner().getAddressableDevices(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
||||||
|
arguments: []const [*]const *const Buffer,
|
||||||
|
num_args: usize,
|
||||||
|
results: []const [*]*Buffer,
|
||||||
|
events: []?*Event,
|
||||||
|
non_donatable_input_indices: []const i64 = &.{},
|
||||||
|
}) ExecuteError!void {
|
||||||
|
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, .{
|
||||||
|
.num_args = args.num_args,
|
||||||
|
.arguments = @ptrCast(args.arguments),
|
||||||
|
.results = @ptrCast(args.results),
|
||||||
|
.events = @ptrCast(args.events),
|
||||||
|
.non_donatable_input_indices = args.non_donatable_input_indices,
|
||||||
|
} });
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
||||||
|
return try self.inner().getExecutable(api);
|
||||||
|
}
|
||||||
|
};
|
||||||
@ -1,12 +1,12 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const pjrt = @import("pjrt");
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const runtimes = @import("runtimes");
|
const runtimes = @import("runtimes");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const module = @import("module.zig");
|
const module = @import("module.zig");
|
||||||
|
const pjrt = @import("pjrtx.zig");
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
pub const Target = runtimes.Platform;
|
pub const Target = runtimes.Platform;
|
||||||
@ -93,37 +93,4 @@ pub const Platform = struct {
|
|||||||
pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler {
|
pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler {
|
||||||
return self.pjrt_client.getProfiler(self.pjrt_api, options);
|
return self.pjrt_client.getProfiler(self.pjrt_api, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Suspend the current co-routine while awaiting for a pjrt event to be over.
|
|
||||||
pub fn awaitEvent(self: Platform, event: *pjrt.Event) !void {
|
|
||||||
defer event.deinit(self.pjrt_api);
|
|
||||||
// If we aren't in a coroutine just use the normal blocking api.
|
|
||||||
if (!asynk.inCoro()) {
|
|
||||||
return try event.await_(self.pjrt_api);
|
|
||||||
}
|
|
||||||
|
|
||||||
var ctx = struct {
|
|
||||||
err: ?*pjrt.Error = null,
|
|
||||||
notif: asynk.Notification,
|
|
||||||
ready: bool = false,
|
|
||||||
}{
|
|
||||||
.notif = try asynk.Notification.init(),
|
|
||||||
};
|
|
||||||
defer ctx.notif.deinit();
|
|
||||||
|
|
||||||
try event.onReady(self.pjrt_api, &(struct {
|
|
||||||
fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
|
|
||||||
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
|
|
||||||
ctx_.err = err;
|
|
||||||
@atomicStore(bool, &ctx_.ready, true, .seq_cst);
|
|
||||||
ctx_.notif.notify() catch @panic("Unable to notify");
|
|
||||||
}
|
|
||||||
}.call), &ctx);
|
|
||||||
// Suspend
|
|
||||||
try ctx.notif.wait();
|
|
||||||
if (ctx.err) |e| {
|
|
||||||
defer e.deinit(self.pjrt_api);
|
|
||||||
return e.getCode(self.pjrt_api).toApiError();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user