From 52ef20f981c6cd1a887ac7b669ee40740b19ba5c Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 26 May 2023 15:54:15 +0000 Subject: [PATCH] zml: reintroduce pjrtx to handle reactor blocking issues in async scenarios, particularly with Events. --- async/threaded.zig | 21 ++++ async/zigcoro.zig | 2 - mlir/dialects/stablehlo.zig | 34 ++++-- zml/buffer.zig | 18 +--- zml/context.zig | 6 +- zml/module.zig | 16 +-- zml/pjrtx.zig | 203 ++++++++++++++++++++++++++++++++++++ zml/platform.zig | 35 +------ 8 files changed, 258 insertions(+), 77 deletions(-) create mode 100644 zml/pjrtx.zig diff --git a/async/threaded.zig b/async/threaded.zig index 9d54392..51065bf 100644 --- a/async/threaded.zig +++ b/async/threaded.zig @@ -93,6 +93,27 @@ pub const AsyncThread = struct { } }; +pub const Notification = struct { + inner: std.Thread.ResetEvent, + + pub fn init() !Notification { + return .{ .inner = .{} }; + } + + pub fn notify(self: *Notification) !void { + self.inner.set(); + } + + pub fn wait(self: *Notification) !void { + self.inner.wait(); + } + + pub fn deinit(self: *Notification) void { + self.inner.set(); + self.* = undefined; + } +}; + pub fn StdIn() !File { return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin"); } diff --git a/async/zigcoro.zig b/async/zigcoro.zig index cc5bc18..bf8b959 100644 --- a/async/zigcoro.zig +++ b/async/zigcoro.zig @@ -370,5 +370,3 @@ pub const Mutex = struct { _ = self.inner.recv(); } }; - -pub const inCoro = libcoro.inCoro; diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index fff3021..0c54835 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -1187,23 +1187,37 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo return context.str; } -pub fn stablehloGetMinimumVersion(writer: anytype) void { - var context = .{ .writer = writer }; - const WriterContext = @TypeOf(context); +pub fn getMinimumVersion() []const u8 { + const state = struct { + var buf: [32]u8 = undefined; + var str: []const u8 = undefined; + var once = std.once(call); - c.stablehloGetMinimumVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + fn call() void { + var stream = std.io.fixedBufferStream(&buf); + var context = .{ .writer = stream.writer() }; + const WriterContext = @TypeOf(context); + + c.stablehloGetMinimumVersion((struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context); + + str = buf[0..stream.pos]; } - }).callback, &context); + }; + + state.once.call(); + return state.str; } -pub fn serializePortableArtifact(module_str: []const u8, target_version: []const u8, writer: anytype) !void { +pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void { var context = .{ .writer = writer }; const WriterContext = @TypeOf(context); - try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(module_str), mlir.stringRef(target_version), (struct { + try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct { pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; diff --git a/zml/buffer.zig b/zml/buffer.zig index 1fbdc30..9dc28c3 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -1,10 +1,10 @@ const std = @import("std"); const testing = std.testing; -const pjrt = @import("pjrt"); +const meta = @import("meta.zig"); +const pjrt = @import("pjrtx.zig"); const asynk = @import("async"); -const meta = @import("meta.zig"); const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; @@ -54,17 +54,7 @@ pub const Buffer = struct { const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); - const xbufferFromHostBuffer = struct { - fn do(self: *const pjrt.Client, api: *const pjrt.Api, args: pjrt.Client.BufferFromHostBufferArgs) pjrt.ApiError!*pjrt.Buffer { - const buffer, const ev = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self, api, args }); - if (ev) |e| { - e.deinit(api); - } - return buffer; - } - }.do; - - var frames: std.BoundedArray(asynk.Frame(xbufferFromHostBuffer), MAX_NUM_SHARDS) = .{}; + var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{}; const devices = platform.getDevices(); for (0..n_partitions) |i| { // If no sharding if found, the given buffer is replicated on all devices. @@ -73,7 +63,7 @@ pub const Buffer = struct { break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); } else host_buffer; - const frame = try asynk.asyncc(xbufferFromHostBuffer, .{ + const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{ platform.pjrt_client, platform.pjrt_api, .{ diff --git a/zml/context.zig b/zml/context.zig index 38cb76e..30c48f3 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -1,14 +1,14 @@ const builtin = @import("builtin"); const std = @import("std"); - -const asynk = @import("async"); const mlir = @import("mlir"); -const pjrt = @import("pjrt"); const c = @import("c"); const runfiles = @import("runfiles"); const runtimes = @import("runtimes"); const platform = @import("platform.zig"); +const pjrt = @import("pjrtx.zig"); + +const available_targets = @import("platform.zig").available_targets; const Target = @import("platform.zig").Target; const Platform = @import("platform.zig").Platform; diff --git a/zml/module.zig b/zml/module.zig index 3410117..bf4976f 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1,6 +1,5 @@ const builtin = @import("builtin"); const std = @import("std"); -const pjrt = @import("pjrt"); const runfiles = @import("runfiles"); @@ -8,6 +7,7 @@ const xla_pb = @import("//xla:xla_proto"); const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); const ops = @import("ops.zig"); +const pjrt = @import("pjrtx.zig"); const protobuf = @import("io/protobuf"); const asynk = @import("async"); const aio = @import("aio.zig"); @@ -1170,19 +1170,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m const options_bytes = try options.encode(arena); - var mlir_bytecode = std.ArrayList(u8).init(arena); - defer mlir_bytecode.deinit(); - // Note: we may need to restore IR downgrade if we need to support old pjrt plugins. - module.op().writeBytecode(mlir_bytecode.writer()); - - const loaded_executable = try asynk.callBlocking(pjrt.Client.compile, .{ - platform.pjrt_client, platform.pjrt_api, .{ - .bytecode = mlir_bytecode.items, - .bytecode_format = .mlir, - .compile_options_pb = options_bytes, - }, - }); - + const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes); errdefer loaded_executable.deinit(); if (platform.compilation_options.cache_location) |compilation_cache_location| { diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig new file mode 100644 index 0000000..61d6bc1 --- /dev/null +++ b/zml/pjrtx.zig @@ -0,0 +1,203 @@ +const builtin = @import("builtin"); +const std = @import("std"); + +const asynk = @import("async"); +const mlir = @import("mlir"); + +const pjrt = @import("pjrt"); +const dtype = @import("dtype.zig"); +const meta = @import("meta.zig"); +const dialects = @import("mlir/dialects"); + +pub const Profiler = pjrt.Profiler; +pub const ApiError = pjrt.ApiError; +pub const ErrorCode = pjrt.ErrorCode; + +const Target = @import("platform.zig").Target; + +const log = std.log.scoped(.zml); + +pub const Buffer = pjrt.Buffer; +pub const BufferType = pjrt.BufferType; +pub const Device = pjrt.Device; +pub const DeviceDescription = pjrt.DeviceDescription; +pub const Api = pjrt.Api; +pub const NamedValue = pjrt.NamedValue; +pub const ClientInitError = pjrt.ClientInitError; +pub const CompileError = std.mem.Allocator.Error || ApiError; +pub const Error = pjrt.Error; +pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; +pub const SerializeResult = pjrt.SerializeResult; +pub const Executable = pjrt.Executable; +pub const ExecuteError = ApiError; + +fn InnerMixin(comptime innerT: type) type { + return struct { + inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT { + return @ptrCast(self); + } + }; +} + +pub const Client = opaque { + const inner = InnerMixin(pjrt.Client).inner; + + pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client { + return @ptrCast(try pjrt.Client.init(api, create_options)); + } + + pub fn deinit(self: *Client, api: *const Api) void { + self.inner().deinit(api); + } + + pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 { + return self.inner().getPlatformName(api); + } + + pub fn getDevices(self: *const Client, api: *const Api) []const *const Device { + return self.inner().getDevices(api); + } + + pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *const Device { + return self.inner().getAddressableDevices(api); + } + + pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs; + pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer { + const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args); + if (event_) |event__| { + const event: *Event = @ptrCast(event__); + try event.await_(api); + } + return buffer; + } + + pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { + return @ptrCast(try asynk.callBlocking(pjrt.Client.deserializeAndLoad, .{ self.inner(), api, bytes })); + } + + pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs; + pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer { + var args_ = args; + args_.on_delete_callback = args_.on_delete_callback orelse &(struct { + fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} + }.call); + const buf = try self.inner().createViewOfDeviceBuffer(api, args_); + return @ptrCast(buf); + } + + fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { + var bytecode = std.ArrayList(u8).init(allocator); + defer bytecode.deinit(); + module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch { + std.debug.print("failed to write module bytecode\n", .{}); + unreachable; + }; + + var serialized_buffer = std.ArrayList(u8).init(allocator); + defer serialized_buffer.deinit(); + dialects.stablehlo.serializePortableArtifact(bytecode.items, dialects.stablehlo.getMinimumVersion(), serialized_buffer.writer()) catch { + std.debug.print("failed to serialize to portable artifact\n", .{}); + unreachable; + }; + + return @ptrCast(try self.inner().compile(api, .{ + .bytecode = serialized_buffer.items, + .bytecode_format = .mlir, + .compile_options_pb = compile_options_pb, + })); + } + + pub fn compile(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { + return try asynk.callBlocking(compileSync, .{ self, api, allocator, module, compile_options_pb }); + } + + /// Returns the Profiler for this API. + /// Not all platform have a profiling api, for those the profiler object will do nothing. + /// Platforms with known profiler extensions: cuda, xpu + pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler { + return self.inner().getProfiler(api, options); + } +}; + +pub const Event = opaque { + pub const inner = InnerMixin(pjrt.Event).inner; + + pub fn deinit(self: *Event, api: *const Api) void { + self.inner().deinit(api); + } + + pub fn isReady(self: *const Event, api: *const Api) bool { + return self.inner().isReady(api); + } + + pub fn getEventError(self: *const Event, api: *const Api) ?*Error { + return self.inner().getEventError(api); + } + + pub fn await_(self: *Event, api: *const Api) !void { + defer self.deinit(api); + + var ctx = struct { + err: ?*pjrt.Error = null, + notif: asynk.Notification, + }{ + .notif = try asynk.Notification.init(), + }; + defer ctx.notif.deinit(); + + try self.inner().onReady(api, &(struct { + fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void { + const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?)); + ctx_.err = err; + ctx_.notif.notify() catch @panic("Unable to notify"); + } + }.call), &ctx); + + try ctx.notif.wait(); + if (ctx.err) |e| { + defer e.deinit(api); + return e.getCode(api).toApiError(); + } + } +}; + +pub const LoadedExecutable = opaque { + const inner = InnerMixin(pjrt.LoadedExecutable).inner; + + pub fn deinit(self: *LoadedExecutable, api: *const Api) void { + self.inner().deinit(api); + } + + pub fn delete(self: *LoadedExecutable, api: *const Api) void { + self.inner().delete(api); + } + + pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool { + return self.inner().isDeleted(api); + } + + pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device { + return self.inner().getAddressableDevices(api); + } + + pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { + arguments: []const [*]const *const Buffer, + num_args: usize, + results: []const [*]*Buffer, + events: []?*Event, + non_donatable_input_indices: []const i64 = &.{}, + }) ExecuteError!void { + try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, .{ + .num_args = args.num_args, + .arguments = @ptrCast(args.arguments), + .results = @ptrCast(args.results), + .events = @ptrCast(args.events), + .non_donatable_input_indices = args.non_donatable_input_indices, + } }); + } + + pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable { + return try self.inner().getExecutable(api); + } +}; diff --git a/zml/platform.zig b/zml/platform.zig index 8ba22ea..d037c2b 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -1,12 +1,12 @@ const builtin = @import("builtin"); const std = @import("std"); -const pjrt = @import("pjrt"); const asynk = @import("async"); const runtimes = @import("runtimes"); const meta = @import("meta.zig"); const module = @import("module.zig"); +const pjrt = @import("pjrtx.zig"); const log = std.log.scoped(.zml); pub const Target = runtimes.Platform; @@ -93,37 +93,4 @@ pub const Platform = struct { pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler { return self.pjrt_client.getProfiler(self.pjrt_api, options); } - - /// Suspend the current co-routine while awaiting for a pjrt event to be over. - pub fn awaitEvent(self: Platform, event: *pjrt.Event) !void { - defer event.deinit(self.pjrt_api); - // If we aren't in a coroutine just use the normal blocking api. - if (!asynk.inCoro()) { - return try event.await_(self.pjrt_api); - } - - var ctx = struct { - err: ?*pjrt.Error = null, - notif: asynk.Notification, - ready: bool = false, - }{ - .notif = try asynk.Notification.init(), - }; - defer ctx.notif.deinit(); - - try event.onReady(self.pjrt_api, &(struct { - fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void { - const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?)); - ctx_.err = err; - @atomicStore(bool, &ctx_.ready, true, .seq_cst); - ctx_.notif.notify() catch @panic("Unable to notify"); - } - }.call), &ctx); - // Suspend - try ctx.notif.wait(); - if (ctx.err) |e| { - defer e.deinit(self.pjrt_api); - return e.getCode(self.pjrt_api).toApiError(); - } - } };