const std = @import("std");
const stdx = @import("stdx");

const aio = @import("aio.zig");
const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const callback = @import("callback.zig");
const CompilationContext = @import("module.zig").CompilationContext;
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;

const log = std.log.scoped(.@"zml/exe");

test {
    std.testing.refAllDecls(@This());
}

/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// The steps are:
/// * look up the tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to initialize the fields of the model that aren't Tensors, i.e. hyperparameters/config
/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments
pub fn compile(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
) !FnExe(func) {
    return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, "");
}

/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Uses a prefix for looking up model weights in the buffer store.
/// The steps are:
/// * look up the tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to initialize the fields of the model that aren't Tensors, i.e. hyperparameters/config
/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments
pub fn compileWithPrefix(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
    prefix: []const u8,
) !FnExe(func) {
    const ModelT = ModuleSignature(func).ModelT;
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    var model = try aio.populateModelWithPrefix(ModelT, arena, buffer_store, prefix);

    // If the Model has an "init" function, call it with the given parameters.
    if (@hasDecl(ModelT, "init")) {
        // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void.
        @call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args);
    }

    return compileModel(allocator, func, model, args_shapes, platform);
}

/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Generates MLIR by calling `model.forward` with tensors of the given shapes and other arguments.
pub fn compileModel(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    model: ModuleSignature(func).ModelT,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    platform: Platform,
) !FnExe(func) {
    const ModelT = ModuleSignature(func).ModelT;
    const name = @typeName(ModelT) ++ ".forward";
    log.info("Compiling {s} with {f}", .{ name, stdx.fmt.any(args_shapes) });

    var context = try CompilationContext.init(allocator, name, platform);
    defer context.deinit();
    return .{ .inner = try context.compileInternal(allocator, func, .{model} ++ args_shapes) };
}
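// A hedged usage sketch contrasting the entry points above. `MyModel`,
// `config` and `x_shape` are illustrative placeholders, and `buffer_store`
// and `platform` are assumed to have been set up elsewhere.
//
//     // Weights are looked up in the store under the "model" prefix:
//     const exe = try compileWithPrefix(allocator, MyModel.forward, .{config}, .{x_shape}, buffer_store, platform, "model");
//     defer exe.deinit();
//
//     // Alternatively, if a populated `model: MyModel` is already at hand:
//     const exe2 = try compileModel(allocator, MyModel.forward, model, .{x_shape}, platform);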
/// Compiles a function with the given configuration and shapes, for the given platform.
/// Generates MLIR by calling the given function with tensors of the given shapes.
pub fn compileFn(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    args: ShapeOf(stdx.meta.FnArgs(func)),
    platform: Platform,
) !FnExe(func) {
    var pretty_name = try prettyFnName(func, allocator);
    defer pretty_name.deinit(allocator);

    var context = try CompilationContext.init(allocator, pretty_name.items, platform);
    defer context.deinit();
    return .{ .inner = try context.compileInternal(allocator, func, args) };
}

pub fn FnExe(comptime func: anytype) type {
    return Exe(stdx.meta.FnArgs(func), stdx.meta.FnResult(func));
}

/// Represents a ZML model, compiled into a PJRT executable, and ready to call.
/// The buffers for the model weights are saved inside the struct and will be used in `call`.
/// You only need to pass the remaining arguments.
/// Creating a `ModuleExe` is a two-step process:
///
/// ```
/// const exe: zml.FnExe(MyModel.forward) = try zml.compile(allocator, MyModel.forward, init_args, model_shapes, buffer_store, platform);
/// const module: zml.ModuleExe(MyModel.forward) = exe.prepare(model_buffers);
/// ```
pub fn ModuleExe(comptime func: anytype) type {
    const AllArgs = stdx.meta.FnArgs(func);
    const len = @typeInfo(AllArgs).@"struct".fields.len;
    stdx.debug.assertComptime(len > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func});
    return Exe(stdx.meta.Tail(AllArgs), stdx.meta.FnResult(func));
}

// Making this a struct forces all fields to be evaluated on creation,
// which gives a better error stack trace
// than delaying the error to when the object fields are read.
const Sign = struct {
    ModelT: type,
    ArgsT: type,
    ReturnT: type,
};

pub fn ModuleSignature(comptime func: anytype) Sign {
    const AllArgsT = stdx.meta.FnArgs(func);
    const len = @typeInfo(AllArgsT).@"struct".fields.len;
    stdx.debug.assertComptime(len > 0, "ModuleSignature expects a function with at least one argument where the first one is treated as the module, got {}", .{func});
    return .{
        .ModelT = stdx.meta.Head(AllArgsT),
        .ArgsT = stdx.meta.Tail(AllArgsT),
        .ReturnT = stdx.meta.FnResult(func),
    };
}
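test ModuleSignature {
    // Minimal sketch of how ModuleSignature splits a forward-like function
    // into model / args / return parts. The toy `Model` below is hypothetical
    // and this assumes stdx.meta.Head/Tail follow the usual tuple semantics.
    const Model = struct {
        fn forward(self: @This(), x: u32) u32 {
            _ = self;
            return x;
        }
    };
    const sign = comptime ModuleSignature(Model.forward);
    try std.testing.expect(sign.ModelT == Model);
    try std.testing.expect(sign.ReturnT == u32);
}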
/// Represents an MLIR module compiled into a PJRT executable.
/// The BaseExe is a plain old struct and doesn't have information about Zig types.
///
/// It also contains pre-allocated buffers so that we can pass them to PJRT_LoadedExecutable_Execute
/// without allocations.
pub const BaseExe = struct {
    /// The platform for which this module was compiled.
    platform: Platform,
    /// The PJRT executable representing the compiled module.
    exe: *pjrt.LoadedExecutable,
    /// The execution context for this executable.
    execute_context: ?*pjrt.ExecuteContext,
    /// Pre-allocated slice of buffers to use as inputs when the module is called.
    input_per_device: []const [*]*pjrt.Buffer,
    /// Pre-allocated slice of buffers to use as outputs when the module is called.
    output_per_device: []const [*]*pjrt.Buffer,
    /// Number of buffers already fed to the executable.
    ready_buffer_count: u32,
    /// Total number of buffers needed by this executable.
    input_buffer_count: u32,
    input_shapes: []Shape,
    result_shapes: []Shape,
    /// Number of devices used (>1 for a sharded executable).
    num_devices: u8,
    /// Arena allocator backing this struct's memory.
    _arena: std.heap.ArenaAllocator,

    pub fn init(
        parent_allocator: std.mem.Allocator,
        platform: Platform,
        exe: *pjrt.LoadedExecutable,
        args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
    ) !BaseExe {
        var arena = std.heap.ArenaAllocator.init(parent_allocator);
        errdefer arena.deinit();
        const allocator = arena.allocator();

        const n_in = args.input_shapes.len;
        const n_out = args.result_shapes.len;
        const n_devices = args.n_devices;

        // Allocate once for all the *pjrt.Buffer we need to store ...
        const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
        const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });

        // ... and once for all the [*]*pjrt.Buffer.
        const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
        const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });

        for (0..n_devices) |i| {
            input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
            output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
        }

        const all_shapes = try allocator.alloc(Shape, n_in + n_out);
        @memcpy(all_shapes[0..n_in], args.input_shapes);
        @memcpy(all_shapes[n_in..], args.result_shapes);

        var execute_context: ?*pjrt.ExecuteContext = null;
        if (platform.pjrt_api.ffi()) |ffi| {
            execute_context = try platform.pjrt_api.createExecuteContext();
            try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
            // log.info("Created context execution {*} for {*}", .{ execute_context, exe });
        }

        return .{
            .platform = platform,
            .exe = exe,
            .execute_context = execute_context,
            .ready_buffer_count = 0,
            .input_buffer_count = @intCast(n_in),
            .num_devices = args.n_devices,
            .input_per_device = input_per_device,
            .output_per_device = output_per_device,
            .input_shapes = all_shapes[0..n_in],
            .result_shapes = all_shapes[n_in..],
            ._arena = arena,
        };
    }

    pub fn deinit(self: BaseExe) void {
        if (self.execute_context) |ctx| {
            ctx.deinit(self.platform.pjrt_api);
        }
        self._arena.deinit();
    }
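    // Layout sketch for the pre-allocated buffers built in `init` (assuming
    // 2 devices and 3 inputs, which are made-up numbers for illustration):
    //
    //     all_input_buffers: [ d0a0, d0a1, d0a2, d1a0, d1a1, d1a2 ]
    //     input_per_device:  [ ptr to d0a0,      ptr to d1a0      ]
    //
    // so `input_per_device[d][i]` is argument `i` on device `d`, and no
    // further allocation is needed at call time.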
    pub fn call(self: BaseExe) void {
        stdx.debug.assert(self.input_buffer_count == self.ready_buffer_count, "BaseExe isn't ready to be called, expected {} input buffers, got {}", .{ self.input_buffer_count, self.ready_buffer_count });
        return self._unsafeCall();
    }

    pub fn _unsafeCall(self: BaseExe) void {
        var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
        const sharding = self.platform.sharding();
        self.exe.execute(self.platform.pjrt_api, .{
            .arguments = self.input_per_device,
            .num_args = self.input_buffer_count,
            .results = self.output_per_device,
            .events = events[0..sharding.num_partitions],
            // This allows marking a specific buffer as non-donatable,
            // even if it has been marked as "can be donated" during compilation.
            // TODO: expose it ?
            .non_donatable_input_indices = &.{},
            .context = self.execute_context,
        }) catch |err| {
            std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
        };
        // for (events[0..sharding.num_partitions]) |e| {
        //     if (e) |ev| {
        //         ev.await(self.platform.pjrt_api) catch unreachable;
        //     }
        // }
    }

    pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
        const LocalContext = struct {
            index: u32,
            platform: Platform,
            outputs: []const [*]*pjrt.Buffer,
            output_shapes: []Shape,
        };
        var local_ctx: LocalContext = .{
            .index = 0,
            .platform = self.platform,
            .outputs = self.output_per_device,
            .output_shapes = self.result_shapes,
        };
        meta.visit((struct {
            fn cb(ctx: *LocalContext, buffer: *Buffer) void {
                const i = ctx.index;
                ctx.index += 1;
                if (i >= ctx.output_shapes.len) return;
                var shards: Buffer.Shards = .{};
                for (ctx.outputs) |buff| {
                    shards.appendAssumeCapacity(buff[i]);
                }
                buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
            }
        }).cb, &local_ctx, result);
        stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s} contains {} Buffers. Note that modules need to have a comptime-known number of returned tensors.", .{ self.result_shapes.len, @typeName(T), local_ctx.index });
    }

    pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
        stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});
        const pjrt_api = exe.platform.pjrt_api;
        if (pjrt_api.ffi()) |ffi| {
            try callback.addUserData(Callback, pjrt_api, ffi, exe.execute_context.?, op);
        } else {
            stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});
        }
    }

    pub fn serialize(self: BaseExe, writer: anytype) !void {
        var executable = try self.exe.getExecutable(self.platform.pjrt_api);
        var serialize_result = try executable.serialize(self.platform.pjrt_api);
        defer serialize_result.deinit();
        try writer.writeAll(serialize_result.bytes);
    }

    // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self {
    //     const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size);
    //     defer allocator.free(bytes);
    //     return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes);
    // }

    pub fn prepare(self: *BaseExe, x: anytype) void {
        // fillBuffers returns the total number of filled buffers, including
        // the `ready_buffer_count` ones that were already in place.
        self.ready_buffer_count = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
    }

    pub fn getOutputBuffer(self: BaseExe, i: usize) Buffer {
        var shards: Buffer.Shards = .{};
        for (self.output_per_device) |dev_out| {
            shards.appendAssumeCapacity(dev_out[i]);
        }
        return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
    }

    pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
        var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
            .input_shapes = self.input_shapes,
            .result_shapes = self.result_shapes,
            .n_devices = self.num_devices,
        });
        exe.execute_context = self.execute_context;
        return exe;
    }
};
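// Hedged lifecycle sketch for `BaseExe` (`weights` and `input` are
// placeholder Bufferized structs): feed inputs with `prepare` until
// `ready_buffer_count` reaches `input_buffer_count`, then `call` is legal.
//
//     base.prepare(weights); // fills the first input slots
//     base.prepare(input);   // fills the remaining slots
//     base.call();           // asserts all inputs are ready, then executes
//     const out = base.getOutputBuffer(0);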
/// Represents a ZML function, compiled into a PJRT executable.
/// The signature of the Exe reflects the arguments that are needed for `call`.
pub fn Exe(ArgsT: type, ReturnT: type) type {
    return struct {
        const Self = @This();

        /// The raw untyped compiled module.
        inner: BaseExe,

        pub fn deinit(self: Self) void {
            self.inner.deinit();
        }

        /// Hardcodes the first argument of the function to the given buffers.
        /// Returns an Exe with one less argument in `call`.
        /// In functional languages this is known as partial application.
        ///
        /// **Warning:** the new Exe reuses the underlying memory of the previous one.
        /// The caller is responsible for coming up with a strategy to call `deinit` exactly once.
        pub fn prepare(self: Self, first_arg: Bufferized(stdx.meta.Head(ArgsT))) Exe(stdx.meta.Tail(ArgsT), ReturnT) {
            var new: Exe(stdx.meta.Tail(ArgsT), ReturnT) = .{ .inner = self.inner };
            new.inner.prepare(first_arg);
            return new;
        }

        /// For a given customCall inside this executable,
        /// provides a pointer to runtime data.
        /// The caller keeps memory ownership and needs to ensure that the value
        /// stays alive as long as the executable.
        pub fn bind(self: Self, comptime T: type, value: *T) !void {
            try self.inner.bind(T, value);
        }

        pub fn serialize(self: Self, writer: anytype) !void {
            return try self.inner.serialize(writer);
        }

        pub fn platform(self: Self) Platform {
            return self.inner.platform;
        }

        pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
            const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
            std.debug.assert(total_ready == self.inner.input_buffer_count);
            self.inner._unsafeCall();

            var result: Bufferized(ReturnT) = undefined;
            self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
            return result;
        }
    };
}

fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
    var res: [lengths.len][]T = undefined;
    var i: usize = 0;
    inline for (&res, lengths) |*r, len| {
        r.* = buffer[i .. i + len];
        i += len;
    }
    std.debug.assert(i == buffer.len);
    return res;
}
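test splitBuffer {
    // Sanity check: the sub-slices returned by splitBuffer must tile the
    // input buffer exactly, in order.
    var data: [5]u8 = .{ 1, 2, 3, 4, 5 };
    const head, const tail = splitBuffer(u8, &data, .{ 2, 3 });
    try std.testing.expectEqualSlices(u8, &.{ 1, 2 }, head);
    try std.testing.expectEqualSlices(u8, &.{ 3, 4, 5 }, tail);
}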
/// Visits the given struct and fills the per-device `buffers` with the pjrt buffers
/// of each encountered Buffer, starting at index `start`.
/// Returns the total number of filled entries, including the `start` ones.
fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
    const LocalContext = struct {
        index: u32,
        buffers: []const [*]*pjrt.Buffer,
        shapes: []const Shape,
    };
    var context: LocalContext = .{
        .index = start,
        .buffers = buffers,
        .shapes = shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
            // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
            const model_sharding = ctx.buffers.len;
            stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
            stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
            for (buffer._shards.constSlice(), 0..) |shard, d| {
                ctx.buffers[d][ctx.index] = shard;
            }
            ctx.index += 1;
        }
    }).cb, &context, v);
    return context.index;
}

fn prettyFnName(
    comptime func: anytype,
    allocator: std.mem.Allocator,
) !std.ArrayListUnmanaged(u8) {
    const full_noisy_name = @typeName(@TypeOf(func));
    const og_len = full_noisy_name.len;
    const buffer = try allocator.alloc(u8, og_len);
    errdefer comptime unreachable; // No errors below this point.

    var out: []u8 = buffer;
    // Note: "tensor.Tensor." must be stripped before shortening "tensor.Tensor",
    // otherwise the first replacement rewrites it to "Tensor." and the second
    // pattern can never match.
    {
        const verbose = "tensor.Tensor.";
        const compact = "";
        const num_replacements = std.mem.replace(u8, full_noisy_name, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }
    {
        const verbose = "tensor.Tensor";
        const compact = "Tensor";
        const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }
    {
        const verbose = "shape.Shape";
        const compact = "Shape";
        const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }
    return .{ .items = out, .capacity = og_len };
}
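test prettyFnName {
    // Smoke test: the exact @typeName spelling is Zig-version dependent, so
    // only check that a name comes back and fits the original allocation.
    const dummy = struct {
        fn id(x: u8) u8 {
            return x;
        }
    }.id;
    var name = try prettyFnName(dummy, std.testing.allocator);
    defer name.deinit(std.testing.allocator);
    try std.testing.expect(name.items.len > 0);
}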