From 7d36913b318180257e469d7192976bbfa082fe72 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 13 Oct 2023 16:08:08 +0000 Subject: [PATCH] Refactor ZML API: move compile, compileFn and related types to `exe.zig`, update `BaseExe` allocation and inline caching in `compileInternal`, and clean up supporting modules (`func.zig`, `meta.zig`, `signature.zig`, `cuda.zig`, `testing.zig`, `zml.zig`). --- mlir/dialects/func.zig | 2 +- stdx/meta.zig | 22 + stdx/signature.zig | 34 +- zml/exe.zig | 356 ++++++++++++++++ zml/module.zig | 920 +++++++++++------------------------------ zml/nn/cuda.zig | 2 +- zml/testing.zig | 4 +- zml/zml.zig | 9 +- 8 files changed, 654 insertions(+), 695 deletions(-) create mode 100644 zml/exe.zig diff --git a/mlir/dialects/func.zig b/mlir/dialects/func.zig index 96ba1d4..9e7c281 100644 --- a/mlir/dialects/func.zig +++ b/mlir/dialects/func.zig @@ -4,7 +4,7 @@ const mlir = @import("mlir"); pub fn func( ctx: mlir.Context, args: struct { - sym_name: [:0]const u8, + sym_name: []const u8, args: []const mlir.Type, arg_attrs: []const mlir.Attribute = &.{}, results: []const mlir.Type, diff --git a/stdx/meta.zig b/stdx/meta.zig index 2eb69fe..40b33bf 100644 --- a/stdx/meta.zig +++ b/stdx/meta.zig @@ -156,3 +156,25 @@ pub fn FnArgs(comptime func: anytype) type { pub fn FnResult(comptime func: anytype) type { return FnSignature(func, null).ReturnT; } + +pub fn Head(Tuple: type) type { + return switch (@typeInfo(Tuple)) { + .Struct => |struct_info| { + if (struct_info.fields.len == 0) @compileError("Can't take the head of an empty tuple"); + return struct_info.fields[0].type; + }, + else => @compileError("Head only works on tuple types"), + }; +} + +pub fn Tail(Tuple: type) type { + return switch (@typeInfo(Tuple)) { + .Struct => |struct_info| { + if (struct_info.fields.len == 0) @compileError("Can't take the tail of an empty tuple"); + var types: [struct_info.fields.len - 1]type = undefined; + for (struct_info.fields[1..], 0..) |field, i| types[i] = field.type; + return std.meta.Tuple(&types); + }, + else => @compileError("Tail only works on tuple types"), + }; +} diff --git a/stdx/signature.zig b/stdx/signature.zig index 48aa4e0..7058efd 100644 --- a/stdx/signature.zig +++ b/stdx/signature.zig @@ -1,8 +1,8 @@ const std = @import("std"); -const compileError = @import("meta.zig").compileError; +const compileError = @import("debug.zig").compileError; -pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type { +pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type { const params = @typeInfo(funcT).Fn.params; if (params.len == 0) { return @TypeOf(.{}); @@ -12,8 +12,11 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type { return std.meta.ArgsTuple(funcT); } - const args = std.meta.fields(argsT orelse compileError("generic function requires an explicit ArgsTuple", .{})); + const args = std.meta.fields(ArgsT orelse compileError("generic function requires an explicit ArgsTuple", .{})); var tuple_fields: [params.len]std.builtin.Type.StructField = undefined; + if (params.len != args.len) { + compileError("function {} expected {} args, got {}", .{ funcT, params.len, args.len }); + } inline for (params, args, 0..) 
|param, arg, i| { if (param.type == null) { tuple_fields[i] = arg; @@ -44,15 +47,30 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type { }); } -pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type { - return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT)); +pub fn FnSignature(comptime func: anytype, comptime ArgsT: ?type) type { + const n_params = switch (@typeInfo(@TypeOf(func))) { + .Fn => |fn_info| fn_info.params.len, + else => compileError("FnSignature expects a function as first argument, got: {}", .{@TypeOf(func)}), + }; + if (ArgsT != null) { + const n_args = switch (@typeInfo(ArgsT.?)) { + .Struct => |struct_info| struct_info.fields.len, + else => compileError("function {} needs to be called with a tuple of args", .{@TypeOf(func)}), + }; + if (n_params != n_args) { + compileError("function {} expected {} args, got {}", .{ @TypeOf(func), n_params, n_args }); + } + } + return FnSignatureX(func, ArgsTuple(@TypeOf(func), ArgsT)); } -fn FnSignatureX(comptime func: anytype, comptime argsT: type) type { +// TODO: I think this should return a struct instead of returning a type: +// a struct gives a better error stacktrace, because here the error is delayed to when the fields are read. +fn FnSignatureX(comptime func: anytype, comptime ArgsT_: type) type { return struct { pub const FuncT = @TypeOf(func); - pub const ArgsT = argsT; - pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined))); + pub const ArgsT = ArgsT_; + pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT_, undefined))); + pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) { .ErrorUnion => |u| u.payload, else => ReturnT, diff --git a/zml/exe.zig b/zml/exe.zig new file mode 100644 index 0000000..12a3fec --- /dev/null +++ b/zml/exe.zig @@ -0,0 +1,356 @@ +const std = @import("std"); +const stdx = @import("stdx"); + +const aio = @import("aio.zig"); +const meta = @import("meta.zig"); +const pjrt = @import("pjrtx.zig"); + +const Buffer = @import("buffer.zig").Buffer; +const Bufferized = @import("tensor.zig").Bufferized; +const CompilationContext = @import("module.zig").CompilationContext; +const Platform = @import("platform.zig").Platform; +const Shape = @import("shape.zig").Shape; +const ShapeOf = @import("tensor.zig").ShapeOf; + +const log = std.log.scoped(.zml); + +test { + std.testing.refAllDecls(@This()); +} + +/// Compiles a Model struct with the given configuration and shapes, for the given platform. +/// The steps are: +/// * look up tensors available in the store and create a `model: Model` struct with them +/// * call `model.init(init_args)` to initialize fields of the model that aren't Tensor, i.e. hyperparameters/config +/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments +pub fn compile( + allocator: std.mem.Allocator, + comptime func: anytype, + init_args: anytype, + args_shapes: ShapeOf(ModuleSignature(func).ArgsT), + buffer_store: aio.BufferStore, + platform: Platform, +) !FnExe(func) { + const ModelT = ModuleSignature(func).ModelT; + + var arena_state = std.heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + var model = try aio.populateModel(ModelT, arena, buffer_store); + + // If the Model has an "init" function, call it with the given parameters. + if (@hasDecl(ModelT, "init")) { + // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void. 
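+        // For example, with a hypothetical `pub fn init(self: *ModelT, config: Config)`, `init_args` would be `.{config}` (`Config` is illustrative, not a type in this codebase).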
+ @call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args); + } + + return compileModel(allocator, func, model, args_shapes, platform); +} + +/// Compiles a Model struct with the given configuration and shapes, for the given platform. +/// Generates MLIR by calling `model.forward` with tensors of the given shapes and other arguments. +pub fn compileModel( + allocator: std.mem.Allocator, + comptime func: anytype, + model: ModuleSignature(func).ModelT, + args_shapes: ShapeOf(ModuleSignature(func).ArgsT), + platform: Platform, +) !FnExe(func) { + const ModelT = ModuleSignature(func).ModelT; + const name = @typeName(ModelT) ++ ".forward"; + log.info("Compiling {s} with {}", .{ name, args_shapes }); + + var context = try CompilationContext.init(allocator, name, platform); + defer context.deinit(); + + return .{ .inner = try context.compileInternal(allocator, func, .{model} ++ args_shapes) }; +} + +/// Compiles a function with the given configuration and shapes, for the given platform. +/// Generates MLIR by calling the given function with tensors of the given shapes. +pub fn compileFn( + allocator: std.mem.Allocator, + comptime func: anytype, + args: ShapeOf(stdx.meta.FnArgs(func)), + platform: Platform, +) !FnExe(func) { + const name = @typeName(@TypeOf(func)); + var context = try CompilationContext.init(allocator, name, platform); + defer context.deinit(); + + return .{ .inner = try context.compileInternal(allocator, func, args) }; +} + +pub fn FnExe(comptime func: anytype) type { + return Exe(stdx.meta.FnArgs(func), stdx.meta.FnResult(func)); +} + +/// Represents a ZML model, compiled into a PJRT executable, and ready to call. +/// The buffers for the model weights are saved inside the struct and will be used in `call`. +/// You only need to pass the remaining arguments. +/// Creating a `ModuleExe` is a two-step process: +/// +/// ``` +/// const exe: zml.FnExe(MyModel.forward) = try zml.compile(allocator, MyModel.forward, init_args, model_shapes, buffer_store, platform); +/// const module: zml.ModuleExe(MyModel.forward) = exe.prepare(model_buffers); +/// ``` +pub fn ModuleExe(comptime func: anytype) type { + const AllArgs = stdx.meta.FnArgs(func); + const len = @typeInfo(AllArgs).Struct.fields.len; + stdx.debug.assertComptime(len > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func}); + return Exe(stdx.meta.Tail(AllArgs), stdx.meta.FnResult(func)); +} + +// making this a struct forces all fields to be evaluated on creation, +// which gives a better error stacktrace +// than delaying the error to when the object fields are read. +const Sign = struct { + ModelT: type, + ArgsT: type, + ReturnT: type, +}; + +pub fn ModuleSignature(comptime func: anytype) Sign { + const AllArgsT = stdx.meta.FnArgs(func); + const len = @typeInfo(AllArgsT).Struct.fields.len; + stdx.debug.assertComptime(len > 0, "ModuleSignature expects a function with at least one argument where the first one is treated as the module, got {}", .{func}); + + return .{ + .ModelT = stdx.meta.Head(AllArgsT), + .ArgsT = stdx.meta.Tail(AllArgsT), + .ReturnT = stdx.meta.FnResult(func), + }; +} + +/// Represents an MLIR module compiled into a PJRT executable. +/// The BaseExe is a plain old struct and doesn't have information about Zig types. +/// +/// It also contains pre-allocated buffers so that we can pass them to PJRT_LoadedExecutable_Execute +/// without allocations. +pub const BaseExe = struct { + /// The platform for which this module was compiled. 
+ platform: Platform, + + /// The PJRT executable representing the compiled module. + exe: *pjrt.LoadedExecutable, + + /// Pre-allocated slice of buffers to use as inputs when the module is called. + input_per_device: []const [*]*pjrt.Buffer, + + /// Pre-allocated slice of buffers to use as outputs when the module is called. + output_per_device: []const [*]*pjrt.Buffer, + + /// Number of buffers already fed to the executable. + ready_buffer_count: u32, + + /// Total number of buffers needed by this executable. + input_buffer_count: u32, + + result_shapes: []Shape, + + /// Num devices used (>1 for sharded executable) + num_devices: u8, + + /// Arena allocator backing this struct's memory. + _arena: std.heap.ArenaAllocator, + + pub fn init(parent_allocator: std.mem.Allocator, platform: Platform, exe: *pjrt.LoadedExecutable, args: struct { n_in: u32, result_shapes: []const Shape, n_devices: u8 }) !BaseExe { + var arena = std.heap.ArenaAllocator.init(parent_allocator); + errdefer arena.deinit(); + const allocator = arena.allocator(); + const n_out = args.result_shapes.len; + const n_devices = args.n_devices; + // Allocate once for all the *pjrt.Buffer we need to store ... + const all_buffers = try allocator.alloc(*pjrt.Buffer, (args.n_in + n_out) * n_devices); + const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ args.n_in * n_devices, n_out * n_devices }); + + // ... and once for all the [*]*pjrt.Buffer. + const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices); + const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices }); + + for (0..n_devices) |i| { + input_per_device[i] = all_input_buffers[i * args.n_in ..].ptr; + output_per_device[i] = all_output_buffers[i * n_out ..].ptr; + } + + return .{ + .platform = platform, + .exe = exe, + .ready_buffer_count = 0, + .input_buffer_count = args.n_in, + .num_devices = args.n_devices, + .input_per_device = input_per_device, + .output_per_device = output_per_device, + .result_shapes = try allocator.dupe(Shape, args.result_shapes), + ._arena = arena, + }; + } + + pub fn deinit(self: BaseExe) void { + self._arena.deinit(); + } + + pub fn call(self: BaseExe) void { + stdx.debug.assert(self.input_buffer_count == self.ready_buffer_count, "BaseExe isn't ready to be called, expected {} input buffers, got {}", .{ self.input_buffer_count, self.ready_buffer_count }); + return self._unsafeCall(); + } + + pub fn _unsafeCall(self: BaseExe) void { + var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES; + const sharding = self.platform.sharding(); + + self.exe.execute(self.platform.pjrt_api, .{ + .arguments = self.input_per_device, + .num_args = self.input_buffer_count, + .results = self.output_per_device, + .events = events[0..sharding.num_partitions], + // this allows telling that a specific buffer shouldn't be donated, + // even if it has been marked as "can be donated" during compilation. + // TODO: expose it ? 
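+            // (An empty slice excludes nothing, so donation decisions made at compile time apply unchanged.)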
+ .non_donatable_input_indices = &.{}, + }) catch unreachable; + + for (events[0..sharding.num_partitions]) |e| { + if (e) |ev| { + ev.await_(self.platform.pjrt_api) catch unreachable; + } + } + } + + pub fn serialize(self: BaseExe, writer: anytype) !void { + var executable = try self.exe.getExecutable(self.platform.pjrt_api); + var serialize_result = try executable.serialize(self.platform.pjrt_api); + defer serialize_result.deinit(); + try writer.writeAll(serialize_result.bytes); + } + + // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self { + // const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size); + // defer allocator.free(bytes); + // return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes); + // } + + pub fn prepare(self: *BaseExe, x: anytype) void { + // fillBuffers returns the new total of ready buffers (start index plus tensors written). + self.ready_buffer_count = fillBuffers(&x, self.input_per_device, self.ready_buffer_count); + } + + pub fn getOutputBuffer(self: BaseExe, i: usize) Buffer { + var shards: Buffer.Shards = .{}; + for (self.output_per_device) |dev_out| { + shards.appendAssumeCapacity(dev_out[i]); + } + + const out_shape = self.result_shapes[i]; + return Buffer.fromPjrtBuffers(self.platform, out_shape, shards.constSlice()); + } +}; + +/// Represents a ZML function, compiled into a PJRT executable. +/// The signature of the Exe reflects the arguments that are needed for `call`. +pub fn Exe(ArgsT: type, ReturnT: type) type { + return struct { + const Self = @This(); + + /// The raw untyped compiled module. + inner: BaseExe, + + pub fn deinit(self: Self) void { + self.inner.deinit(); + } + + /// Hardcode the first argument of the function to the given buffers. + /// Returns an Exe with one less argument in `call`. + /// In functional languages this is known as partial application. + /// + /// **Warning:** the new Exe reuses the underlying memory of the previous one. + /// The caller is responsible for ensuring `deinit` is called exactly once. + pub fn prepare(self: Self, first_arg: Bufferized(stdx.meta.Head(ArgsT))) Exe(stdx.meta.Tail(ArgsT), ReturnT) { + var new: Exe(stdx.meta.Tail(ArgsT), ReturnT) = .{ .inner = self.inner }; + new.inner.prepare(first_arg); + return new; + } + + pub fn serialize(self: Self, writer: anytype) !void { + return try self.inner.serialize(writer); + } + + pub fn platform(self: Self) Platform { + return self.inner.platform; + } + + pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) { + const total_ready = fillBuffers(&args, self.inner.input_per_device, self.inner.ready_buffer_count); + std.debug.assert(total_ready == self.inner.input_buffer_count); + self.inner._unsafeCall(); + var result: Bufferized(ReturnT) = undefined; + assignRawBuffers(&result, self.inner.platform, self.inner.output_per_device, self.inner.result_shapes); + return result; + } + }; +} + +fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T { + var res: [lengths.len][]T = undefined; + var i: usize = 0; + inline for (&res, lengths) |*r, len| { + r.* = buffer[i .. i + len]; + i += len; + } + std.debug.assert(i == buffer.len); + return res; +} + +/// Visit the given struct and fill the `buffers` slice with the buffers associated with each encountered Tensor. 
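+/// Returns the index one past the last buffer written, i.e. `start` plus the number of tensors visited; callers use it to track how many inputs are ready.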
+fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32) u32 { + const LocalContext = struct { + index: u32, + buffers: []const [*]*pjrt.Buffer, + }; + var context: LocalContext = .{ + .index = start, + .buffers = buffers, + }; + meta.visit((struct { + fn cb(ctx: *LocalContext, buffer: *const Buffer) void { + // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); + const model_sharding = ctx.buffers.len; + stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); + for (buffer._shards.constSlice(), 0..) |shard, d| { + ctx.buffers[d][ctx.index] = shard; + } + ctx.index += 1; + } + }).cb, &context, v); + return context.index; +} + +/// Visit the given struct and override each Buffer by creating a new one from the provided PJRT buffers. +fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape) void { + const LocalContext = struct { + index: u32, + platform: Platform, + buffers: []const [*]*pjrt.Buffer, + buffer_shapes: []Shape, + }; + var local_ctx: LocalContext = .{ + .index = 0, + .platform = platform, + .buffers = buffers, + .buffer_shapes = buffer_shapes, + }; + meta.visit((struct { + fn cb(ctx: *LocalContext, buffer: *Buffer) void { + const i = ctx.index; + ctx.index += 1; + if (i >= ctx.buffer_shapes.len) return; + + var shards: Buffer.Shards = .{}; + for (ctx.buffers) |buff| { + shards.appendAssumeCapacity(buff[i]); + } + buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice()); + } + }).cb, &local_ctx, v); + stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s} contains {} Buffers. 
Note that modules need to have a comptime-known number of returned tensors.", .{ buffer_shapes.len, @typeName(@TypeOf(v)), local_ctx.index }); +} diff --git a/zml/module.zig b/zml/module.zig index e2c621c..c1bd7ac 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1,8 +1,8 @@ +const std = @import("std"); + const asynk = @import("async"); -const builtin = @import("builtin"); const dialect = @import("mlir/dialects"); const runfiles = @import("runfiles"); -const std = @import("std"); const stdx = @import("stdx"); const xla_pb = @import("//xla:xla_proto"); @@ -10,11 +10,10 @@ const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); const ops = @import("ops.zig"); const pjrt = @import("pjrtx.zig"); -const aio = @import("aio.zig"); +const BaseExe = @import("exe.zig").BaseExe; const Buffer = @import("buffer.zig").Buffer; const Bufferized = @import("tensor.zig").Bufferized; -const Context = @import("context.zig").Context; const Location = mlir.Location; const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; @@ -23,7 +22,6 @@ const Target = @import("platform.zig").Target; const Tensor = @import("tensor.zig").Tensor; const Tracer = @import("tools/tracer.zig").Tracer; -const assert = std.debug.assert; const log = std.log.scoped(.@"zml/module"); test { @@ -70,8 +68,18 @@ const Block = union(BlockKind) { } }; +pub const MlirFn = struct { + name: []const u8, + num_args: u32, + res_types: []mlir.Type, + res_shapes: []Shape, + res_donations: []Tensor._Donation, + mlir_fn: mlir.Operation, +}; + pub const CompilationContext = struct { _platform: Platform, + _name: []const u8, _mlir_ctx: mlir.Context, _mlir_registry: mlir.Registry, @@ -81,6 +89,7 @@ pub const CompilationContext = struct { _blocks: std.BoundedArray(Block, 64), _fn_cache: FnCache, + // TODO: make this an arena, that way it's fine if ops allocate inside it. _allocator: std.mem.Allocator, _buffer_to_arg: TensorToBlockArg = .{}, @@ -115,6 +124,7 @@ pub const CompilationContext = struct { return .{ ._platform = platform, + ._name = name, ._mlir_ctx = mlir_ctx, ._mlir_registry = mlir_registry, ._mlir_canonicalizer = canonicalizer, @@ -139,7 +149,7 @@ pub const CompilationContext = struct { } pub fn deactivate(self: *CompilationContext) void { - assert(_current != null and _current.? == self); + std.debug.assert(_current != null and _current.? == self); _current = self._previous; self._previous = null; } @@ -152,15 +162,103 @@ pub const CompilationContext = struct { return self._platform.target; } - pub fn targetIs(self: *const CompilationContext, value: Target) bool { - return self.target() == value; - } - pub fn mlirCtx(self: *const CompilationContext) mlir.Context { return self._mlir_ctx; } - pub fn currentBlock(self: *const CompilationContext) ?Block { + /// Compiles the given function with the given arguments. + /// This is the untyped API and is not meant to be used directly. + /// + /// * allocator is used to allocate the returned `BaseExe`. + /// * args can contain a mix of tensors and shapes, allowing you to pass a "model struct" containing tensors. 
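+    /// + /// Rough usage sketch (this is what the typed wrappers `compile` and `compileFn` in exe.zig do for you; `myFn` and `arg_shapes` are placeholders): + /// ``` + /// var context = try CompilationContext.init(allocator, "main", platform); + /// defer context.deinit(); + /// const base_exe = try context.compileInternal(allocator, myFn, arg_shapes); + /// ```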
+ pub fn compileInternal( + self: *CompilationContext, + allocator: std.mem.Allocator, + comptime func: anytype, + args: anytype, + ) !BaseExe { + var arena_state = std.heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + var timer = std.time.Timer.start() catch null; + const tensor_args = self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args); + // Run in a dedicated thread because compilation relies on `threadlocal`. + const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ self, arena, "main", func, &tensor_args }); + const module = self._module; + module.getBody().appendOperation(f.mlir_fn); + + const sharding = self._platform.sharding(); + const mlir_ctx = self._mlir_ctx; + module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); + module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); + + const module_hash = computeModuleHash(self._platform, module); + if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| { + // Write the MLIR to a file. All errors are discarded, since this is for debugging only. + if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| { + const name = self._name; + const file_name = std.fmt.allocPrint(arena, "{s}_{x}.mlir", .{ name, module_hash }) catch name; + if (dir.createFile(file_name, .{ .truncate = true })) |file| { + module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false }); + log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name }); + } else |_| { + log.warn("Failed to open {s}", .{file_name}); + } + } else |_| { + log.warn("Folder not found {s}", .{xla_dump_to}); + } + } + + const tracer = Tracer.init("ai.zml.compilation"); + const compile_frame = tracer.frameStart("pjrt cached compilation"); + defer tracer.frameEnd(compile_frame, "pjrt cached compilation"); + + const loaded_executable: *pjrt.LoadedExecutable = blk: { + const cache_location = try absoluteCacheFileZ(arena, self._platform.compilation_options.cache_location, module_hash); + if (cache_location) |cache_file| { + if (loadPjrtExecutable(arena, self._platform, cache_file)) |exe| { + break :blk exe; + } else |_| {} + } + + const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module) catch |err| { + log.err( + "pjrt-{s} failed to compile the following valid MLIR:\n{}\n{}", + .{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err }, + ); + return err; + }; + + if (cache_location) |cache_file| { + storePjrtExecutable(self._platform, loaded_executable, cache_file) catch |err| { + log.debug("Failed to store module: {}", .{err}); + }; + } + break :blk loaded_executable; + }; + + log.debug("******** ZML generated MLIR ********", .{}); + log.debug("{}", .{module.op().mlirFormatter(.{})}); + + if (timer) |*t| { + const time_ms = @divFloor(t.lap(), std.time.ns_per_ms); + if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloat(f32, time_ms, 1000)}); + } + + return BaseExe.init( + allocator, + self._platform, + loaded_executable, + .{ + .n_in = f.num_args, + .result_shapes = f.res_shapes, + .n_devices = sharding.num_replicas * sharding.num_partitions, + }, + ); + } + + fn currentBlock(self: *const CompilationContext) ?Block { return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null; } @@ -234,10 +332,9 @@ pub fn 
generateBytecode( self: *CompilationContext, allocator: std.mem.Allocator, - fn_name: [:0]const u8, + fn_name: []const u8, comptime func: anytype, - model: *const ModuleSignature(func).ModelT, - args: *const ModuleSignature(func).ArgsT, + args: *const stdx.meta.FnArgs(func), ) error{OutOfMemory}!MlirFn { const frame = self._tracer.frameStart("generateBytecode.emit"); errdefer self._tracer.frameEnd(frame, "generateBytecode.emit"); @@ -248,10 +345,7 @@ defer arena_state.deinit(); const arena = arena_state.allocator(); - const model_tensor_count = countTensors(model); - const args_tensor_count = countTensors(args); - - const tensor_count = model_tensor_count + args_tensor_count; + const tensor_count = countTensors(args); const mlir_ctx = self.mlirCtx(); const loc = mlir_ctx.location(@src()); @@ -260,8 +354,6 @@ @memset(locations, mlir.Location.unknown(mlir_ctx)); var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count); - meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable; - stdx.debug.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{}); meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); @@ -270,7 +362,8 @@ // Note: this isn't strictly necessary. We call `countTensors` on `fn_res`. // But it forces the user to write simpler functions. - const out_tensor_count = comptime ops.staticCountTensors(ModuleSignature(func).ReturnT) orelse @compileError("Can't use " ++ @typeName(ModuleSignature(func).ReturnT) ++ " in an MLIR function, because it has a variable number of tensors"); + const ReturnT = stdx.meta.FnResult(func); + const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors"); // Those are returned to the caller so we don't put them in the arena. 
+ const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count); + const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count); @@ -285,15 +378,13 @@ // defer self._buffer_to_arg.shrinkRetainingCapacity(n); try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count)); - const assigned_model_count = self.mapBlockArguments(model, fn_body.block(), 0); - const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), assigned_model_count); - assert(assigned_model_count == model_tensor_count); - assert(assigned_args_count == tensor_count); + const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); + std.debug.assert(assigned_args_count == tensor_count); const fn_res = forward: { self.activate(); defer self.deactivate(); - break :forward @call(.auto, func, .{model.*} ++ args.*); + break :forward @call(.auto, func, args.*); }; var fn_res_values: [out_tensor_count]mlir.Value = undefined; @@ -339,8 +430,7 @@ return .{ .mlir_fn = mlir_fn, .name = fn_name, - .n_model = @intCast(model_tensor_count), - .n_args = @intCast(args_tensor_count), + .num_args = @intCast(tensor_count), .res_types = fn_res_types, .res_shapes = fn_res_shapes, .res_donations = fn_res_donations, @@ -400,16 +490,15 @@ var comp = try zml.module.CompilationContext.init(allocator, "test", platform); defer comp.deinit(); - var tensor_args = .{Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } }}; - const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args); + var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } } }; + const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &tensor_args); var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{}; try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})}); // Check that the `x` input argument gives its buffer to the result tensor. // `%arg0` is the bias of the model, `%arg1` is `x`. - try std.testing.expectEqual(1, f.n_model); - try std.testing.expectEqual(1, f.n_args); + try std.testing.expectEqual(2, f.num_args); std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| { log.warn("Didn't produce the expected IR:\n{s}", .{mlir_bytecode.items}); return err; }; @@ -511,9 +600,8 @@ self: *CompilationContext, func_name: [:0]const u8, comptime func: anytype, - model: *const ModuleSignature(func).ModelT, - args: *ModuleSignature(func).ArgsT, - ) ModuleSignature(func).ReturnT { + args: stdx.meta.FnArgs(func), + ) stdx.meta.FnResult(func) { var arena_state = std.heap.ArenaAllocator.init(self._allocator); defer arena_state.deinit(); const arena = arena_state.allocator(); @@ -524,7 +612,6 @@ arena, func_name, func, - model, args, ) catch unreachable; // TODO: do we like unreachable? 
const bytecode_hash = hashArgs(dummy_result.bytecode_tmp); @@ -534,7 +621,7 @@ const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name)) arena.dupeZ(u8, func_name) catch unreachable else - std.fmt.allocPrintZ(arena, "{s}.{s}_{x}", .{ @typeName(ModuleSignature(func).ModelT), func_name, key.input_hash }) catch unreachable; + std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable; log.info("addFuncToModule {any} {s}", .{ key, full_name }); @@ -542,7 +629,6 @@ arena, full_name, func, - model, args, ) catch unreachable; @@ -555,7 +641,6 @@ const loc = self.mlirCtx().location(@src()); - const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable; - self.extractValues(&model, values[0..function.n_model]); - self.extractValues(&args, values[function.n_model..]); + const values = arena.alloc(mlir.Value, function.num_args) catch unreachable; + self.extractValues(&args, values); const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); @@ -614,7 +699,7 @@ /// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found. pub fn extractValuesAndTypes(self: *const CompilationContext, v: anytype, values: []mlir.Value, types: []mlir.Type, shapes: []Shape, donations: []Tensor._Donation) void { - assert(values.len == types.len); + std.debug.assert(values.len == types.len); const LocalContext = struct { self: *const CompilationContext, index: usize = 0, @@ -634,7 +719,7 @@ ctx.index += 1; } }).cb, &context, v); - assert(context.index == values.len); + std.debug.assert(context.index == values.len); } pub fn getValueAndDonation(self: *const CompilationContext, tensor: Tensor) struct { mlir.Value, Tensor._Donation } { @@ -658,485 +743,6 @@ } }; -/// Visit the given struct and recursively counts the number of tensors found. -pub fn countTensors(v: anytype) usize { - const LocalContext = struct { - count: usize = 0, - }; - var context = LocalContext{}; - meta.visit((struct { - fn cb(inner_context: *LocalContext, _: *const Tensor) void { - inner_context.count += 1; - } - }).cb, &context, v); - return context.count; -} - -/// Visit the given struct and recursively fill the `types` slice with the mlir.Type associated with encountered Tensor. -pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) void { - const LocalContext = struct { - index: usize = 0, - mlir_ctx: mlir.Context, - types: []mlir.Type, - }; - var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types }; - meta.visit((struct { - fn cb(inner_context: *LocalContext, tensor: *const Tensor) void { - inner_context.types[inner_context.index] = mlir.ext.mlirType(inner_context.mlir_ctx, tensor.shape()); - inner_context.index += 1; - } - }).cb, &context, v); - assert(context.index == types.len); -} - -/// Visit the given struct and recursively associate the `block` arguments with the `value` field of each encountered Tensor. -/// -/// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors -/// stored in the Module. 
-fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize { - const LocalContext = struct { index: usize, block: mlir.Block }; - var context = LocalContext{ .block = block, .index = start }; - meta.visit((struct { - fn cb(ctx: *LocalContext, tensor: *Tensor) void { - tensor._id = .{ .mlir = ctx.block.argument(ctx.index) }; - tensor._donation = .{ .arg = @intCast(ctx.index) }; - ctx.index += 1; - } - }).cb, &context, v); - return context.index; -} - -/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor. -fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u32) void { - const LocalContext = struct { - index: u32, - buffers: []const [*]*pjrt.Buffer, - }; - var context: LocalContext = .{ - .index = start, - .buffers = buffers, - }; - meta.visit((struct { - fn cb(ctx: *LocalContext, buffer: *const Buffer) void { - // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); - const model_sharding = ctx.buffers.len; - stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); - for (buffer._shards.constSlice(), 0..) |shard, d| { - ctx.buffers[d][ctx.index] = shard; - } - ctx.index += 1; - } - }).cb, &context, v); - assert(context.index == start + len); -} - -/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers. -pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape, expected_count: u32) void { - const LocalContext = struct { - index: u32, - platform: Platform, - buffers: []const [*]*pjrt.Buffer, - expected_count: u32, - buffer_shapes: []Shape, - }; - var local_ctx: LocalContext = .{ - .index = 0, - .platform = platform, - .buffers = buffers, - .expected_count = expected_count, - .buffer_shapes = buffer_shapes, - }; - meta.visit((struct { - fn cb(ctx: *LocalContext, buffer: *Buffer) void { - const i = ctx.index; - ctx.index += 1; - if (i >= ctx.expected_count) return; - - var shards: Buffer.Shards = .{}; - for (ctx.buffers) |buff| { - shards.appendAssumeCapacity(buff[i]); - } - buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice()); - } - }).cb, &local_ctx, v); - stdx.debug.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); -} - -/// Visit the given struct and assign op results to each tensor found. -fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void { - const LocalContext = struct { - index: usize, - op: mlir.Operation, - shapes: ?[]Shape, - }; - var context = LocalContext{ .index = 0, .op = op, .shapes = shapes }; - meta.visit((struct { - fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void { - var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index)); - if (inner_ctx.shapes) |sh| { - new._shape = sh[inner_ctx.index]; - } else { - new._shape._tags = tensor._shape._tags; - } - tensor.* = new; - inner_ctx.index += 1; - } - }).cb, &context, v); - assert(context.index == op.numResults()); -} - -/// Represents an MLIR module compiled into a PJRT executable. 
-/// The BaseExe is a plain old struct and doesn't have information -/// about Zig types. -const BaseExe = struct { - /// The platform for which this module was compiled. - platform: Platform, - /// The PJRT executable representing the compiled module. - exe: *pjrt.LoadedExecutable, - /// Number of buffers in the model. - model_buffer_count: u32, - /// Number of buffers in the arguments. - args_buffer_count: u32, - /// Number of buffers in result. - result_buffer_count: u32, - /// Shapes of buffers in result. - result_buffer_shapes: []Shape, - /// Num devices used (>1 for sharded executable) - num_devices: u8, - /// Allocator backing result_buffer_shapes and deinit by ExeWithWeights - _allocator: std.heap.ArenaAllocator, - - pub fn serialize(self: BaseExe, writer: anytype) !void { - var executable = try self.exe.getExecutable(self.pjrt_api); - var serialize_result = try executable.serialize(self.platform.pjrt_api); - defer serialize_result.deinit(); - try writer.writeAll(serialize_result.bytes); - } -}; - -/// Represents a ZML model, compiled into a PJRT executable. -/// -/// It's not directly callable, as it doesn't have associated model weights. -/// use `prepare` to assign weights and pre allocate memory needed to call. -pub fn Exe(comptime func: anytype) type { - const Signature = ModuleSignature(func); - return struct { - const Self = @This(); - - /// The raw untyped compiled module. - inner: BaseExe, - - /// Packages the given model weight with an `Exe` to produce an `ExeWithWeights` that can be called. - pub fn prepare(self: Self, allocator: std.mem.Allocator, model: Bufferized(Signature.ModelT)) !ExeWithWeights(func) { - return ExeWithWeights(func).initFromModel(allocator, self.inner, model); - } - - pub fn serialize(self: Self, writer: anytype) !void { - return try self.inner.serialize(writer); - } - - // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self { - // const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size); - // defer allocator.free(bytes); - // return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes); - // } - }; -} - -/// Represents a ZML model, compiled into a PJRT executable, and ready to call. -/// The buffers for the model weights are saved inside the struct and will be used in `call`. -/// You only need to pass the remaining arguments. -pub fn ExeWithWeights(comptime func: anytype) type { - const Signature = ModuleSignature(func); - return struct { - const Self = @This(); - - /// The raw untyped compiled module. - inner: BaseExe, - - /// Pre-allocated slice of buffers to use as inputs when the module is called. - input_per_device: []const [*]*pjrt.Buffer, - - /// Pre-allocated slice of buffers to use as outputs when the module is called. - output_per_device: []const [*]*pjrt.Buffer, - - /// Internal memory slice used. - _all_buffers: []*pjrt.Buffer, - _all_per_device: [][*]*pjrt.Buffer, - - /// And the allocator backing _data_buffer. - _allocator: std.mem.Allocator, - - pub fn initFromModel(allocator: std.mem.Allocator, inner: BaseExe, model: Bufferized(Signature.ModelT)) !Self { - const n_input_buffers = inner.model_buffer_count + inner.args_buffer_count; - const n_output_buffers = inner.result_buffer_count; - const n_devices = inner.num_devices; - - // Allocate once for all the *pjrt.Buffer we need to store ... 
- const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_input_buffers + n_output_buffers) * n_devices); - errdefer allocator.free(all_buffers); - const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_input_buffers * n_devices, n_output_buffers * n_devices }); - - // ... and once for all the [*]*pjrt.Buffer. - const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices); - errdefer allocator.free(all_per_device); - const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices }); - - for (0..n_devices) |i| { - input_per_device[i] = all_input_buffers[i * n_input_buffers ..].ptr; - output_per_device[i] = all_output_buffers[i * n_output_buffers ..].ptr; - } - - fillBuffers(&model, input_per_device, 0, inner.model_buffer_count); - // Note: all_output_buffers is left undefined, it will be written to in `call`. - - return .{ - .inner = inner, - .input_per_device = input_per_device, - .output_per_device = output_per_device, - ._all_buffers = all_buffers, - ._all_per_device = all_per_device, - ._allocator = allocator, - }; - } - - pub fn deinit(self: Self) void { - // Free in reverse order of allocation. - self._allocator.free(self._all_per_device); - self._allocator.free(self._all_buffers); - self.inner._allocator.deinit(); - } - - pub fn platform(self: Self) Platform { - return self.inner.platform; - } - - pub fn getOutputBuffer(self: Self, i: usize) Buffer { - var shards: Buffer.Shards = .{}; - for (self.output_per_device) |dev_out| { - shards.appendAssumeCapacity(dev_out[i]); - } - - const out_shape = self.inner.result_buffer_shapes[i]; - return Buffer.fromPjrtBuffers(self.platform(), out_shape, shards.constSlice()); - } - - pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) { - fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count); - var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES; - const sharding = self.platform().sharding(); - - self.inner.exe.execute(self.inner.platform.pjrt_api, .{ - .arguments = self.input_per_device, - .num_args = self.inner.args_buffer_count + self.inner.model_buffer_count, - .results = self.output_per_device, - .events = events[0..sharding.num_partitions], - // TODO: this allows to tell a specific buffer shouldn't be donated, - // even if it has been marked as "can be donated" during compilation. - .non_donatable_input_indices = &.{}, - }) catch unreachable; - - for (events[0..sharding.num_partitions]) |e| { - if (e) |ev| { - ev.await_(self.inner.platform.pjrt_api) catch unreachable; - } - } - - var result: Bufferized(Signature.ReturnT) = undefined; - assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_shapes, self.inner.result_buffer_count); - return result; - } - }; -} - -/// Compiles the given module with the given arguments. -/// The `model` (first fn argument), is treated differently from the other args. -/// This helps to have two separate lifetimes for the model buffers, -/// and for the arguments buffer. 
-fn compileInternal( - allocator: std.mem.Allocator, - context: *CompilationContext, - comptime func: anytype, - model: ModuleSignature(func).ModelT, - args: ShapeOf(ModuleSignature(func).ArgsT), -) !BaseExe { - var arena_state = std.heap.ArenaAllocator.init(allocator); - defer arena_state.deinit(); - const arena = arena_state.allocator(); - - var timer = std.time.Timer.start() catch null; - const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); - // Run in a dedicated thread because compilation relies on `threadlocal`. - const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); - context._module.getBody().appendOperation(f.mlir_fn); - - const sharding = context._platform.sharding(); - const mlir_ctx = context._mlir_ctx; - context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); - context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); - - if (context._platform.compilation_options.xla_dump_to) |xla_dump_to| { - // Write the mlir to a file. All errors are discarded, since this is for debugging only. - if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| { - const name_attr = context._module.op().getAttributeByName("sym_name").?.as(mlir.StringAttribute).?; - const file_name = std.fmt.allocPrint(arena, "{s}.mlir", .{name_attr.value()}) catch name_attr.value(); - if (dir.createFile(file_name, .{ .truncate = true })) |file| { - context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false }); - log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name }); - } else |_| { - log.warn("Failed to open {s}", .{file_name}); - } - } else |_| { - log.warn("Folder not found {s}", .{xla_dump_to}); - } - } - - const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| { - log.err( - "pjrt-{s} failed to compile following valid MLIR:\n{}\n{}", - .{ @tagName(context._platform.target), context._module.op().mlirFormatter(.{}), err }, - ); - return err; - }; - - log.debug("******** ZML generated MLIR ********", .{}); - log.debug("{}", .{context._module.op().mlirFormatter(.{})}); - - if (timer) |*t| { - const time_ms = @divFloor(t.lap(), std.time.ns_per_ms); - if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloat(f32, time_ms, 1000)}); - } - - var arena_state_exe = std.heap.ArenaAllocator.init(allocator); - const arena_exe = arena_state_exe.allocator(); - - return .{ - .platform = context._platform, - .exe = loaded_executable, - .model_buffer_count = f.n_model, - .args_buffer_count = f.n_args, - .result_buffer_count = @intCast(f.res_types.len), - .result_buffer_shapes = arena_exe.dupe(Shape, f.res_shapes) catch unreachable, - .num_devices = sharding.num_replicas * sharding.num_partitions, - ._allocator = arena_state_exe, - }; -} - -pub fn load( - allocator: std.mem.Allocator, - comptime Model: type, - init_args: anytype, - comptime func: @TypeOf(.literal), - args_shapes: ShapeOf(ModuleSignature(@field(Model, @tagName(func))).ArgsT), - buffer_store: aio.BufferStore, - platform: Platform, -) !Exe(@field(Model, @tagName(func))) { - var arena_state = std.heap.ArenaAllocator.init(allocator); - defer arena_state.deinit(); - const arena = arena_state.allocator(); - var model = try aio.populateModel(Model, arena, buffer_store); - - // If the Model 
has a "init" function, call it with the given parameters. - if (@hasDecl(Model, "init")) { - // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void. - @call(.auto, Model.init, .{@as(*Model, &model)} ++ init_args); - } - - return compileModel(allocator, model, func, args_shapes, platform); -} - -/// Compiles a Model struct with the given configuration and shapes, for the given platform. -/// The steps are: -/// * lookup at tensors available in the store and create a `model: Model` struct with them -/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config -/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments -pub fn compile( - allocator: std.mem.Allocator, - comptime func: anytype, - init_args: anytype, - args_shapes: ShapeOf(ModuleSignature(func).ArgsT), - buffer_store: aio.BufferStore, - platform: Platform, -) !Exe(func) { - const ModelT = ModuleSignature(func).ModelT; - - var arena_state = std.heap.ArenaAllocator.init(allocator); - defer arena_state.deinit(); - const arena = arena_state.allocator(); - var model = try aio.populateModel(ModelT, arena, buffer_store); - - // If the Model has a "init" function, call it with the given parameters. - if (@hasDecl(ModelT, "init")) { - // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void. - @call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args); - } - - return compileModel(allocator, func, model, args_shapes, platform); -} - -/// Compiles a Model struct with the given configuration and shapes, for the given platform. -/// Generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments -pub fn compileModel( - allocator: std.mem.Allocator, - comptime func: anytype, - model: ModuleSignature(func).ModelT, - args_shapes: ShapeOf(ModuleSignature(func).ArgsT), - platform: Platform, -) !Exe(func) { - const ModelT = ModuleSignature(func).ModelT; - const name = @typeName(ModelT) ++ ".forward"; - log.info("Compiling {s} with {}", .{ name, args_shapes }); - - var context = try CompilationContext.init(allocator, name, platform); - defer context.deinit(); - - const raw_module = try compileInternal(allocator, &context, func, model, args_shapes); - - return .{ .inner = raw_module }; -} - -/// Compiles a function with the given configuration and shapes, for the given platform. -/// Generate MLIR by calling the given function with tensor of the given shapes. -pub fn compileFn( - allocator: std.mem.Allocator, - comptime func: anytype, - args: ShapeOf(stdx.meta.FnArgs(func)), - platform: Platform, -) !FnExe(func) { - const name = @typeName(@TypeOf(func)); - var context = try CompilationContext.init(allocator, name, platform); - defer context.deinit(); - - const Local = struct { - // This is the function we will actually compile. - pub fn forward(_: void, inner_args: stdx.meta.FnArgs(func)) stdx.meta.FnResult(func) { - return @call(.auto, func, inner_args); - } - }; - - const void_model: void = {}; - const raw_module = try compileInternal(allocator, &context, Local.forward, void_model, .{args}); - // But we set the signature so that you can call the module as you would call the function. 
- return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model); -} - -pub fn FnExe(comptime func: anytype) type { - return ExeWithWeights(FnWithVoidArg(func)); -} - -fn FnWithVoidArg(comptime func: anytype) type { - const fn_info = @typeInfo(@TypeOf(func)).Fn; - const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void }; - stdx.debug.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)}); - return @Type(.{ .Fn = .{ - .calling_convention = fn_info.calling_convention, - .is_generic = false, - .is_var_args = fn_info.is_var_args, - .return_type = fn_info.return_type, - .params = [1]std.builtin.Type.Fn.Param{void_param} ++ fn_info.params, - } }); -} - fn computeModuleHash(platform: Platform, module: mlir.Module) u64 { var hasher = std.hash.XxHash64.init(0); var hasher_writer = xxHash64Writer(&hasher); @@ -1154,29 +760,40 @@ fn computeModuleHash(platform: Platform, module: mlir.Module) u64 { return hasher.final(); } -const max_pjrt_executable_size = 400 * 1024 * 1024; - -fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module_hash: u64, compilation_cache_location: []const u8) !*pjrt.LoadedExecutable { - const resolved_path = try std.fs.cwd().realpathAlloc(arena, compilation_cache_location); - const compilation_cache_dir = try std.fs.openDirAbsolute(resolved_path, .{}); - var buf: [16]u8 = undefined; - const filename = try std.fmt.bufPrint(&buf, "{x}", .{module_hash}); - const loaded_executable_file = try compilation_cache_dir.openFile(filename, .{}); - defer loaded_executable_file.close(); - - const bytes = try loaded_executable_file.readToEndAlloc(arena, max_pjrt_executable_size); - - return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes); -} - -fn storePjrtExecutable(arena: std.mem.Allocator, platform: Platform, loaded_executable: *pjrt.LoadedExecutable, module_hash: u64, compilation_cache_location: []const u8) !void { - const resolved_path = try std.fs.cwd().realpathAlloc(arena, compilation_cache_location); - const compilation_cache_dir = std.fs.openDirAbsolute(resolved_path, .{}) catch blk: { - try std.fs.makeDirAbsolute(resolved_path); - break :blk try std.fs.openDirAbsolute(resolved_path, .{}); +fn absoluteCacheFileZ(arena: std.mem.Allocator, cache_path: ?[]const u8, module_hash: u64) !?[:0]const u8 { + if (cache_path == null) return null; + const resolved_path = try std.fs.cwd().realpathAlloc(arena, cache_path.?); + std.fs.makeDirAbsolute(resolved_path) catch |err| switch (err) { + error.PathAlreadyExists => {}, + else => return err, }; - const loaded_executable_file = try compilation_cache_dir.createFile(try std.fmt.allocPrint(arena, "{x}", .{module_hash}), .{}); + var buf: [24]u8 = undefined; + const module_name = std.fmt.bufPrint(&buf, "{x}.pjrt", .{module_hash}) catch unreachable; + return try std.fs.path.joinZ(arena, &.{ resolved_path, module_name }); +} + +const max_pjrt_executable_size = 400 * 1024 * 1024; + +fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable { + const loaded_executable_file = try std.fs.openFileAbsoluteZ(absolute_file, .{}); + defer loaded_executable_file.close(); + + const exe_size = if (loaded_executable_file.stat()) |stat| stat.size else |_| max_pjrt_executable_size; + const bytes = try arena.alloc(u8, exe_size); + defer arena.free(bytes); + + const size = try loaded_executable_file.readAll(bytes); + + log.info("Loading module from 
{s}", .{absolute_file}); + return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]) catch |err| { + log.warn("Failed to load module: {}", .{err}); + return err; + }; +} + +fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void { + const loaded_executable_file = try std.fs.createFileAbsoluteZ(absolute_file, .{}); defer loaded_executable_file.close(); var executable = try loaded_executable.getExecutable(platform.pjrt_api); @@ -1186,30 +803,10 @@ fn storePjrtExecutable(arena: std.mem.Allocator, platform: Platform, loaded_exec defer serialize_result.deinit(); try loaded_executable_file.writeAll(serialize_result.bytes); + log.info("Stored module to {s}", .{absolute_file}); } -fn loadOrCompilePjrtExecutable( - arena: std.mem.Allocator, - platform: Platform, - module: mlir.Module, -) !*pjrt.LoadedExecutable { - const tracer = Tracer.init("ai.zml.compilation"); - const compile_frame = tracer.frameStart("pjrt cached compilation"); - defer tracer.frameEnd(compile_frame, "pjrt cached compilation"); - const module_hash = computeModuleHash(platform, module); - - if (platform.compilation_options.cache_location) |compilation_cache_location| { - log.debug("Loading module from {s}", .{compilation_cache_location}); - return loadPjrtExecutable(arena, platform, module_hash, compilation_cache_location) catch |err| { - log.debug("Failed to load module: {}", .{err}); - return compileModuleToPjrtExecutable(arena, platform, module, module_hash); - }; - } else { - return compileModuleToPjrtExecutable(arena, platform, module, module_hash); - } -} - -fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, module_hash: u64) !*pjrt.LoadedExecutable { +fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module) !*pjrt.LoadedExecutable { const sharding = platform.sharding(); // NOTE(Corendos): Hack needed because Protobuf struct are not public. @@ -1303,16 +900,80 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes); errdefer loaded_executable.deinit(); - if (platform.compilation_options.cache_location) |compilation_cache_location| { - log.debug("Storing module to {s}", .{compilation_cache_location}); - storePjrtExecutable(arena, platform, loaded_executable, module_hash, compilation_cache_location) catch |err| { - log.debug("Failed to store module: {}", .{err}); - }; - } - return loaded_executable; } +/// Visit the given struct and recursively counts the number of tensors found. +pub fn countTensors(v: anytype) usize { + const LocalContext = struct { + count: usize = 0, + }; + var context = LocalContext{}; + meta.visit((struct { + fn cb(inner_context: *LocalContext, _: *const Tensor) void { + inner_context.count += 1; + } + }).cb, &context, v); + return context.count; +} + +/// Visit the given struct and recursively fill the `types` slice with the mlir.Type associated with encountered Tensor. 
+
+/// Visit the given struct and recursively fill the `types` slice with the mlir.Type associated with each encountered Tensor.
+pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) void {
+    const LocalContext = struct {
+        index: usize = 0,
+        mlir_ctx: mlir.Context,
+        types: []mlir.Type,
+    };
+    var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types };
+    meta.visit((struct {
+        fn cb(inner_context: *LocalContext, tensor: *const Tensor) void {
+            inner_context.types[inner_context.index] = mlir.ext.mlirType(inner_context.mlir_ctx, tensor.shape());
+            inner_context.index += 1;
+        }
+    }).cb, &context, v);
+    std.debug.assert(context.index == types.len);
+}
+
+/// Visit the given struct and recursively associate the `block` arguments with the `value` field of each encountered Tensor.
+///
+/// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors
+/// stored in the Module.
+fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
+    const LocalContext = struct { index: usize, block: mlir.Block };
+    var context = LocalContext{ .block = block, .index = start };
+    meta.visit((struct {
+        fn cb(ctx: *LocalContext, tensor: *Tensor) void {
+            tensor._id = .{ .mlir = ctx.block.argument(ctx.index) };
+            tensor._donation = .{ .arg = @intCast(ctx.index) };
+            ctx.index += 1;
+        }
+    }).cb, &context, v);
+    return context.index;
+}
+
+/// Visit the given struct and assign op results to each tensor found.
+/// `shapes` is optional: when null, the result tensors keep the tags of the tensors they replace.
+fn assignResults(op: mlir.Operation, v: anytype, shapes: ?[]Shape) void {
+    const LocalContext = struct {
+        index: usize,
+        op: mlir.Operation,
+        shapes: ?[]Shape,
+    };
+    var context = LocalContext{ .index = 0, .op = op, .shapes = shapes };
+    meta.visit((struct {
+        fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
+            var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
+            if (inner_ctx.shapes) |sh| {
+                new._shape = sh[inner_ctx.index];
+            } else {
+                new._shape._tags = tensor._shape._tags;
+            }
+            tensor.* = new;
+            inner_ctx.index += 1;
+        }
+    }).cb, &context, v);
+    std.debug.assert(context.index == op.numResults());
+}
+
 pub const XxHash64Writer = struct {
     hasher: *std.hash.XxHash64,
 
@@ -1333,101 +994,6 @@ pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
     return .{ .hasher = hasher };
 }
 
-pub fn hasTensors(comptime T: type) bool {
-    if (T == Tensor) return true;
-
-    return switch (@typeInfo(T)) {
-        inline .Array, .Pointer, .Optional => |info| hasTensors(info.child),
-        inline .Struct, .Union => |info| {
-            inline for (info.fields) |field| {
-                if (hasTensors(field.type)) return true;
-            }
-            return false;
-        },
-        else => false,
-    };
-}
-
-test "hasTensors" {
-    comptime {
-        try std.testing.expect(hasTensors(?Tensor));
-        try std.testing.expect(hasTensors(struct { u8, ?Tensor }));
-        try std.testing.expect(!hasTensors(struct { u8, usize }));
-    }
-}
-
-pub fn hasConstTensors(comptime T: type, comptime self_const: bool) bool {
-    if (T == Tensor) return self_const;
-
-    return switch (@typeInfo(T)) {
-        inline .Array, .Optional => |info| hasTensors(info.child) and self_const,
-        .Pointer => |ptr_info| hasConstTensors(ptr_info.child, ptr_info.is_const),
-        inline .Struct, .Union => |info| {
-            inline for (info.fields) |field| {
-                if (hasConstTensors(field.type, self_const)) return true;
-            }
-            return false;
-        },
-        else => false,
-    };
-}
-
-test "hasConstTensors" {
-    try std.testing.expect(!hasConstTensors(?Tensor, false));
-    try std.testing.expect(hasConstTensors(struct { u8, ?Tensor }, true));
-    try std.testing.expect(!hasConstTensors(struct { u8, *Tensor }, true));
-    try std.testing.expect(hasConstTensors(struct { u8, *const Tensor }, false));
-    try std.testing.expect(!hasConstTensors(struct { *Tensor }, false));
-    try std.testing.expect(!hasConstTensors(std.meta.Tuple(&[_]type{*Tensor}), false));
-    try std.testing.expect(!hasConstTensors(struct { u8, usize }, false));
-    try std.testing.expect(hasConstTensors(struct { [5]Tensor, usize }, true));
-    try std.testing.expect(!hasConstTensors(struct { [5]Tensor, usize }, false));
-}
-
-// making this a struct force all fields to be evaluted on creation,
-// which gives a better error stacktrace
-// than delaying the error to when the object fields are read.
-const Sign = struct {
-    FuncT: type,
-    ModelT: type,
-    ArgsT: type,
-    ReturnT: type,
-};
-
-pub fn ModuleSignature(comptime func: anytype) Sign {
-    const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
-    return .{
-        .FuncT = FuncT,
-        .ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create ModuleSignature for function with an 'anytype' parameter"),
-        .ArgsT = blk: {
-            const function_info = @typeInfo(FuncT);
-            if (function_info.Fn.params.len < 2) {
-                break :blk @TypeOf(.{});
-            }
-
-            var argument_field_list: [function_info.Fn.params.len - 1]type = undefined;
-            for (function_info.Fn.params[1..], 0..) |arg, i| {
-                const T = arg.type orelse @compileError("cannot create ModuleSignature for function with an 'anytype' parameter");
-                argument_field_list[i] = T;
-            }
-
-            break :blk std.meta.Tuple(&argument_field_list);
-        },
-        .ReturnT = @typeInfo(FuncT).Fn.return_type.?,
-    };
-}
-
-pub const MlirFn = struct {
-    name: [:0]const u8,
-    n_model: u32,
-    n_args: u32,
-    res_types: []mlir.Type,
-    res_shapes: []Shape,
-    res_donations: []Tensor._Donation,
-    mlir_fn: mlir.Operation,
-};
-
-// TODO(Corentin): Remove that
 pub const FnCache = struct {
     pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 };
@@ -1543,10 +1109,14 @@ pub fn hashArgs(mod: anytype) u64 {
     return hasher.final();
 }
 
-pub fn hashTensor(hasher: *std.hash.Wyhash, tensor: Tensor) void {
+pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
     // Note: if we enforced 0-init dims then we could hash dims instead.
-    hashArray(hasher, tensor.dims(), .Shallow);
-    hash(hasher, tensor.dtype(), .Shallow);
+    hashArray(hasher, shape.dims(), .Shallow);
+    hash(hasher, shape._dtype, .Shallow);
+    hash(hasher, shape._sharding_info, .Shallow);
+    for (shape.tags()) |tag| {
+        hash(hasher, @intFromPtr(tag), .Shallow);
+    }
 }
 
 const HashStrategy = std.hash.Strategy;
@@ -1556,7 +1126,8 @@ const tensorAwareHash = hash; // alias for when "hash" is ambiguous
 /// Strategy is provided to determine if pointers should be followed or not.
 pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy) void {
     const Key = @TypeOf(key);
-    if (Key == Tensor) return hashTensor(hasher, key);
+    if (Key == Tensor) return hashShape(hasher, key.shape());
+    if (Key == Shape) return hashShape(hasher, key);
 
     if (strat == .Shallow and std.meta.hasUniqueRepresentation(Key)) {
         hasher.update(std.mem.asBytes(&key));
@@ -1675,14 +1246,3 @@ fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void {
         hash(hasher, element, strat);
     }
 }
-
-fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
-    var res: [lengths.len][]T = undefined;
-    var i: usize = 0;
-    inline for (&res, lengths) |*r, len| {
-        r.* = buffer[i .. i + len];
-        i += len;
-    }
-    std.debug.assert(i == buffer.len);
-    return res;
-}
diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig
index 456b52d..4c840cb 100644
--- a/zml/nn/cuda.zig
+++ b/zml/nn/cuda.zig
@@ -15,7 +15,7 @@ const CompilationContext = module.CompilationContext;
 pub fn canUseCudnnSdpa(q_shape: Shape) bool {
     const ctx = CompilationContext.current();
     // TODO(Corendos): Check cuda version, cudnn version, device compatibility.
-    if (!ctx.targetIs(.cuda)) return false;
+    if (ctx.target() != .cuda) return false;
 
     if (q_shape.rank() != 4) return false;
 
diff --git a/zml/testing.zig b/zml/testing.zig
index d4a2bef..38688f8 100644
--- a/zml/testing.zig
+++ b/zml/testing.zig
@@ -187,7 +187,7 @@ pub fn testLayerOut(
     log.info("Testing {s}", .{name});
 
     const fwd = @TypeOf(layer).forward;
-    const FwdSign = zml.module.ModuleSignature(fwd);
+    const FwdSign = zml.ModuleSignature(fwd);
     const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in");
     const input_shapes = try shapesOf(input_tensors, alloc);
 
@@ -204,7 +204,7 @@
     if (exe.inner.result_buffer_count != n_out_exp) {
         log.warn("Reference model produces {d} outputs, but implementation produces {d}", .{ n_out_exp, exe.inner.result_buffer_count });
     }
-    const mod = try exe.prepare(alloc, layer_weights);
+    const mod = exe.prepare(layer_weights);
 
     const FetchCtx = struct {
         store: zml.aio.BufferStore,
diff --git a/zml/zml.zig b/zml/zml.zig
index 484e414..2f5b799 100644
--- a/zml/zml.zig
+++ b/zml/zml.zig
@@ -18,6 +18,7 @@ pub const Tensor = @import("tensor.zig").Tensor;
 
 // Namespaces
 pub const context = @import("context.zig");
+pub const exe = @import("exe.zig");
 pub const floats = @import("floats.zig");
 pub const helpers = @import("helpers.zig");
 pub const nn = @import("nn.zig");
@@ -30,9 +31,11 @@ pub const torch = @import("torch.zig");
 pub const tokenizer = @import("tokenizer.zig");
 
 pub const call = ops.call;
-pub const compile = module.compile;
-pub const compileModel = module.compileModel;
-pub const compileFn = module.compileFn;
+pub const compile = exe.compile;
+pub const compileFn = exe.compileFn;
+pub const compileModel = exe.compileModel;
+pub const FnExe = exe.FnExe;
+pub const ModuleExe = exe.ModuleExe;
 
 pub const ops = @import("ops.zig");
 
 pub const tools = struct {