diff --git a/async/async.zig b/async/async.zig index ae77cf4..8e8cf74 100644 --- a/async/async.zig +++ b/async/async.zig @@ -398,3 +398,5 @@ pub const Mutex = struct { _ = self.inner.recv(); } }; + +pub const inCoro = libcoro.inCoro; diff --git a/zml/aio.zig b/zml/aio.zig index c7318f5..3a5ca4d 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -2,7 +2,6 @@ const builtin = @import("builtin"); const asynk = @import("async"); const std = @import("std"); const zml = @import("zml.zig"); -const pjrt = @import("pjrtx.zig"); const c = @import("c"); const posix = @import("posix.zig"); diff --git a/zml/buffer.zig b/zml/buffer.zig index ad1798a..02d4bbf 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -1,9 +1,9 @@ const std = @import("std"); const testing = std.testing; -const meta = @import("meta.zig"); const pjrt = @import("pjrt"); +const meta = @import("meta.zig"); const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; @@ -53,6 +53,7 @@ pub const Buffer = struct { const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); + var events: std.BoundedArray(*pjrt.Event, MAX_NUM_SHARDS) = .{}; const devices = platform.getDevices(); for (0..n_partitions) |i| { // If no sharding if found, the given buffer is replicated on all devices. @@ -61,7 +62,7 @@ pub const Buffer = struct { break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); } else host_buffer; - const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ + const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ .data = buf.data, .buffer_type = buffer_type, .dims = buf.shape().dims(), @@ -70,8 +71,13 @@ pub const Buffer = struct { .host_buffer_semantics = .ImmutableUntilTransferCompletes, }); + events.appendAssumeCapacity(event); res._shards.appendAssumeCapacity(pjrt_buffer); } + + for (events.constSlice()) |event| { + try platform.awaitEvent(event); + } return res; } diff --git a/zml/context.zig b/zml/context.zig index 4139972..6666589 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -1,18 +1,17 @@ const builtin = @import("builtin"); const std = @import("std"); -const mlir = @import("mlir"); + const asynk = @import("async"); +const mlir = @import("mlir"); +const pjrt = @import("pjrt"); const platform = @import("platform.zig"); -const pjrtx = @import("pjrtx.zig"); - -const available_targets = @import("platform.zig").available_targets; const Target = @import("platform.zig").Target; const Platform = @import("platform.zig").Platform; const log = std.log.scoped(.zml); -const PjrtApiMap = std.EnumArray(Target, ?*const pjrtx.Api); +const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); const PlatformsMap = std.EnumArray(Target, ?Platform); /// Every program using ZML must start with a `zml.Context.init(.{});` @@ -27,7 +26,7 @@ pub const Context = struct { fn call() void { inline for (platform.available_targets) |t| { if (canLoad(t)) { - if (pjrtx.Api.loadFrom(platformToLibrary(t))) |api| { + if (pjrt.Api.loadFrom(platformToLibrary(t))) |api| { Context.apis.set(t, api); } else |_| {} } @@ -107,7 +106,7 @@ pub const Context = struct { return std.mem.eql(u8, &buf, GoogleComputeEngine); } - pub fn pjrtApi(target: Target) *const pjrtx.Api { + pub fn pjrtApi(target: Target) *const pjrt.Api { return Context.apis.get(target).?; } diff --git a/zml/module.zig b/zml/module.zig index e8927e3..2cd64d0 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); +const pjrt = @import("pjrt"); const runfiles = @import("runfiles"); @@ -7,7 +8,6 @@ const xla_pb = @import("//xla:xla_proto"); const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); const ops = @import("ops.zig"); -const pjrt = @import("pjrtx.zig"); const protobuf = @import("io/protobuf"); const asynk = @import("async"); const aio = @import("aio.zig"); @@ -1118,8 +1118,21 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m } const options_bytes = try options.encode(arena); - const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes); - errdefer unreachable; // errdefer loaded_executable.deinit(); + + var mlir_bytecode = std.ArrayList(u8).init(arena); + defer mlir_bytecode.deinit(); + // Note: we may need to restore IR downgrade if we need to support old pjrt plugins. + module.op().writeBytecode(mlir_bytecode.writer()); + + const loaded_executable = try asynk.call(pjrt.Client.compile, .{ + platform.pjrt_client, platform.pjrt_api, .{ + .bytecode = mlir_bytecode.items, + .bytecode_format = .mlir, + .compile_options_pb = options_bytes, + }, + }); + + errdefer loaded_executable.deinit(); if (platform.compilation_options.cache_location) |compilation_cache_location| { log.debug("Storing module to {s}", .{compilation_cache_location}); diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig deleted file mode 100644 index 2f04c22..0000000 --- a/zml/pjrtx.zig +++ /dev/null @@ -1,270 +0,0 @@ -const builtin = @import("builtin"); -const std = @import("std"); - -const mlir = @import("mlir"); -const pjrt = @import("pjrt"); -const dtype = @import("dtype.zig"); -const meta = @import("meta.zig"); -const asynk = @import("async"); - -pub const Profiler = pjrt.Profiler; -pub const ApiError = pjrt.ApiError; -pub const ErrorCode = pjrt.ErrorCode; - -const Target = @import("platform.zig").Target; - -const log = std.log.scoped(.zml); - -pub const Buffer = pjrt.Buffer; -pub const Device = pjrt.Device; -pub const DeviceDescription = pjrt.DeviceDescription; -pub const Api = pjrt.Api; -pub const NamedValue = pjrt.NamedValue; -pub const ClientInitError = pjrt.ClientInitError; -pub const CompileError = std.mem.Allocator.Error || ApiError; -pub const Error = pjrt.Error; -pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; -pub const SerializeResult = pjrt.SerializeResult; -pub const Executable = pjrt.Executable; -pub const ExecuteError = ApiError; - -test { - std.testing.refAllDecls(Client); - std.testing.refAllDecls(Event); - std.testing.refAllDecls(LoadedExecutable); -} - -fn InnerMixin(comptime innerT: type) type { - return struct { - inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT { - return @ptrCast(self); - } - }; -} - -pub const Client = opaque { - const inner = InnerMixin(pjrt.Client).inner; - - pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client { - return @ptrCast(try pjrt.Client.init(api, create_options)); - } - - pub fn deinit(self: *Client, api: *const Api) void { - self.inner().deinit(api); - } - - pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 { - return self.inner().getPlatformName(api); - } - - pub fn getDevices(self: *const Client, api: *const Api) []const *const Device { - return self.inner().getDevices(api); - } - - pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *const Device { - return self.inner().getAddressableDevices(api); - } - - pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs; - pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer { - const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args); - const event: *Event = @ptrCast(event_); - try event.await_(api); - return buffer; - } - - pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { - return @ptrCast(try asynk.call(pjrt.Client.deserializeAndLoad, .{ self.inner(), api, bytes })); - } - - pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs; - pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer { - var args_ = args; - args_.on_delete_callback = args_.on_delete_callback orelse &(struct { - fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} - }.call); - const buf = try self.inner().createViewOfDeviceBuffer(api, args_); - return @ptrCast(buf); - } - - fn downgradeStableHLO(self: Client, operation: mlir.Operation) mlir.Operation { - var cloned = operation.clone() catch unreachable; - cloned.walk(.pre_order, .{ .api_version = self.getApiVersion() }, struct { - const OpsStaticMap = std.StaticStringMap([]const [:0]const u8); - const convertPre40Ops = OpsStaticMap.initComptime(.{ - .{ "stablehlo.broadcast", &.{"broadcast_sizes"} }, - .{ "stablehlo.dynamic_slice", &.{"slice_sizes"} }, - .{ "stablehlo.fft", &.{"fft_length"} }, - .{ "stablehlo.pad", &.{ "edge_padding_low", "edge_padding_high", "interior_padding" } }, - .{ "stablehlo.reverse", &.{"dimensions"} }, - .{ "stablehlo.slice", &.{ "start_indices", "limit_indices", "strides" } }, - .{ "stablehlo.transpose", &.{"permutation"} }, - }); - const convertOps = OpsStaticMap.initComptime(.{ - .{ "stablehlo.broadcast_in_dim", &.{"broadcast_dimensions"} }, - .{ "stablehlo.convolution", &.{ "window_strides", "rhs_dilation", "lhs_dilation", "window_reversal" } }, - .{ "stablehlo.dynamic_broadcast_in_dim", &.{ "broadcast_dimensions", "known_expanding_dimensions", "known_nonexpanding_dimensions" } }, - .{ "stablehlo.dynamic_convolution", &.{ "window_strides", "rhs_dilation", "lhs_dilation", "window_reversal" } }, - .{ "stablehlo.gather", &.{"slice_sizes"} }, - .{ "stablehlo.map", &.{"dimensions"} }, - .{ "stablehlo.reduce", &.{"dimensions"} }, - .{ "stablehlo.reduce_window", &.{ "window_dimensions", "window_strides", "base_dilations", "window_dilations" } }, - .{ "stablehlo.select_and_scatter", &.{ "window_dimensions", "window_strides" } }, - }); - - fn convert(map: OpsStaticMap, op: mlir.Operation) void { - if (map.get(op.name().str())) |attrs| { - for (attrs) |attr_name| { - if (op.getAttributeByName(attr_name)) |attr| { - if (attr.as(mlir.DenseArrayAttribute(.bool))) |attr_| { - op.setAttributeByName(attr_name, attr_.toElements().as(mlir.Attribute).?); - } else if (attr.as(mlir.DenseArrayAttribute(.i64))) |attr_| { - op.setAttributeByName(attr_name, attr_.toElements().as(mlir.Attribute).?); - } - } - } - } - } - - fn walk(wctx: anytype, op: mlir.Operation) mlir.Operation.WalkResult { - // Keep in sync with https://github.com/openxla/xla/blob/a05ff095226aa2301903c2b475017b248d2c5ef3/xla/pjrt/mlir_to_hlo.cc#L101 - if (wctx.api_version.minor < 40) { - convert(convertPre40Ops, op); - } - convert(convertOps, op); - - return .advance; - } - }.walk); - return cloned; - } - - fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { - var buffer = std.ArrayList(u8).init(allocator); - defer buffer.deinit(); - // Note: we may need to restore IR downgrade if we need to support old pjrt plugins. - module.op().writeBytecode(buffer.writer()); - - return @ptrCast(try self.inner().compile(api, .{ - .bytecode = buffer.items, - .bytecode_format = .mlir, - .compile_options_pb = compile_options_pb, - })); - } - - fn compileSync2(self: *const Client, api: *const Api, module: []const u8, compile_options_pb: []const u8) CompileError!*LoadedExecutable { - return @ptrCast(try self.inner().compile(api, .{ - .bytecode = module, - .bytecode_format = .mlir, - .compile_options_pb = compile_options_pb, - })); - } - - pub fn compile(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { - return try asynk.call(compileSync, .{ self, api, allocator, module, compile_options_pb }); - } - - pub fn compile2(self: *const Client, api: *const Api, module: []const u8, compile_options_pb: []const u8) CompileError!*LoadedExecutable { - return try asynk.call(compileSync2, .{ self, api, module, compile_options_pb }); - } - - /// Returns the Profiler for this API. - /// Not all platform have a profiling api, for those the profiler object will do nothing. - /// Platforms with known profiler extensions: cuda, xpu - pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler { - return self.inner().getProfiler(api, options); - } -}; - -pub const Event = opaque { - pub const inner = InnerMixin(pjrt.Event).inner; - - pub fn deinit(self: *Event, api: *const Api) void { - self.inner().deinit(api); - } - - pub fn isReady(self: *const Event, api: *const Api) bool { - return self.inner().isReady(api); - } - - pub fn getEventError(self: *const Event, api: *const Api) ?*Error { - return self.inner().getEventError(api); - } - - pub fn await_(self: *Event, api: *const Api) !void { - defer self.deinit(api); - try self.inner().await_(api); - - var ctx = struct { - err: ?*pjrt.Error = null, - notif: asynk.Notification, - ready: bool = false, - }{ - .notif = try asynk.Notification.init(), - }; - defer ctx.notif.deinit(); - - try self.inner().onReady(api, &(struct { - fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void { - const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?)); - ctx_.err = err; - @atomicStore(bool, &ctx_.ready, true, .seq_cst); - ctx_.notif.notify() catch @panic("Unable to notify"); - } - }.call), &ctx); - - while (!ctx.ready) { - try ctx.notif.wait(); - } - if (ctx.err) |e| { - defer e.deinit(api); - return e.getCode(api).toApiError(); - } - } -}; - -pub const LoadedExecutable = opaque { - const inner = InnerMixin(pjrt.LoadedExecutable).inner; - - // pub fn deinit(self: *LoadedExecutable, api: *const Api) void { - // self.inner().deinit(api); - // } - - pub fn delete(self: *LoadedExecutable, api: *const Api) void { - self.inner().delete(api); - } - - pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool { - return self.inner().isDeleted(api); - } - - // TODO fix me - // pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device { - // return self.inner().getAddressableDevices(api); - // } - - pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { - arguments: []const [*]const *const Buffer, - num_args: usize, - results: []const [*]*Buffer, - events: []*Event, - non_donatable_input_indices: []const i64 = &.{}, - }) ExecuteError!void { - try self.inner().execute(api, .{ - .num_args = args.num_args, - .arguments = @ptrCast(args.arguments), - .results = @ptrCast(args.results), - .events = @ptrCast(args.events), - .non_donatable_input_indices = args.non_donatable_input_indices, - }); - - for (args.events) |event| { - // TODO(Corentin): Maybe better handle the error here. - event.await_(api) catch return error.Unknown; - } - } - - pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable { - return try self.inner().getExecutable(api); - } -}; diff --git a/zml/platform.zig b/zml/platform.zig index 586a2fa..a8ec82f 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -1,11 +1,11 @@ const builtin = @import("builtin"); const std = @import("std"); -const aio = @import("aio.zig"); +const pjrt = @import("pjrt"); +const asynk = @import("async"); + const meta = @import("meta.zig"); const module = @import("module.zig"); -const pjrt = @import("pjrtx.zig"); -const pjrt_core = @import("pjrt"); const log = std.log.scoped(.zml); pub const Target = enum { @@ -58,7 +58,7 @@ pub const Platform = struct { }; } - pub fn getDevices(self: Platform) []const *const pjrt_core.Device { + pub fn getDevices(self: Platform) []const *const pjrt.Device { const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api); if (all_devices.len > MAX_NUM_DEVICES) { return all_devices[0..MAX_NUM_DEVICES]; @@ -94,7 +94,40 @@ pub const Platform = struct { /// Returns the Profiler for this API. /// Not all platform have a profiling api, for those the profiler object will do nothing. /// Platforms with known profiler extensions: cuda, xpu - pub fn getProfiler(self: Platform, options: pjrt_core.Profiler.Options) pjrt_core.Profiler { + pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler { return self.pjrt_client.getProfiler(self.pjrt_api, options); } + + /// Suspend the current co-routine while awaiting for a pjrt event to be over. + pub fn awaitEvent(self: Platform, event: *pjrt.Event) !void { + defer event.deinit(self.pjrt_api); + // If we aren't in a coroutine just use the normal blocking api. + if (!asynk.inCoro()) { + return try event.await_(self.pjrt_api); + } + + var ctx = struct { + err: ?*pjrt.Error = null, + notif: asynk.Notification, + ready: bool = false, + }{ + .notif = try asynk.Notification.init(), + }; + defer ctx.notif.deinit(); + + try event.onReady(self.pjrt_api, &(struct { + fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void { + const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?)); + ctx_.err = err; + @atomicStore(bool, &ctx_.ready, true, .seq_cst); + ctx_.notif.notify() catch @panic("Unable to notify"); + } + }.call), &ctx); + // Suspend + try ctx.notif.wait(); + if (ctx.err) |e| { + defer e.deinit(self.pjrt_api); + return e.getCode(self.pjrt_api).toApiError(); + } + } }; diff --git a/zml/tensor.zig b/zml/tensor.zig index dc0012e..514e986 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -3,7 +3,6 @@ const std = @import("std"); const assert = std.debug.assert; const testing = std.testing; -const pjrt = @import("pjrtx.zig"); const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); const ops = @import("ops.zig"); @@ -3487,46 +3486,6 @@ test "argMax" { } } -fn dynamicSlice1d() void { - const zml = @import("zml.zig"); - const platform = zml.testing.env(); - var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator); - defer arena_state.deinit(); - const allocator = arena_state.allocator(); - const T = f32; - - { - defer _ = arena_state.reset(.retain_capacity); - const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }); - const z = try zml.Buffer.scalar(platform, 4, .i32); - var comp = zml.module.CompilationContext.init(allocator, "test", platform, .{}); - defer comp.deinit(); - var x_tensor = x.shape(); - var args: struct { i8, u63, zml.Shape } = .{ 0, 2, z.shape() }; - var dynamicSlice = try zml.compileRaw(allocator, &comp, Tensor.dynamicSlice1d, &x_tensor, &args); - - var res: [1]*pjrt.Buffer = undefined; - dynamicSlice.call(&.{ x._data, z._data }, &res); - try testing.expectEqual([2]T{ 4, 5 }, try zml.Buffer.fromPjrtBuffer(platform, res[0]).getValue([2]T)); - } - - { - // Strided - var x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } }); - var z = try zml.Buffer.scalar(platform, 3, .i32); - - var comp = zml.module.CompilationContext.init(allocator, "test", platform, .{}); - defer comp.deinit(); - var x_tensor = x.shape(); - var args: struct { i8, u63, zml.Tensor } = .{ 1, 2, z.shape() }; - var dynamicSlice = try zml.compileRaw(allocator, &comp, Tensor.dynamicSlice1d, &x_tensor, &args); - - var res: [1]*pjrt.Buffer = undefined; - dynamicSlice.call(&.{ x._data, z._data }, &res); - try testing.expectEqualSlices(T, &.{ 3, 4, 8, 9 }, &(try zml.Buffer.fromPjrtBuffer(platform, res[0]).getValue([4]T))); - } -} - test "dynamicUpdateSlice1d" { const zml = @import("zml.zig"); const platform = zml.testing.env();