From 98b512c4956c9b1f03e47aaf8d4ad3823bef8f3e Mon Sep 17 00:00:00 2001
From: Tarry Singh
Date: Thu, 19 Oct 2023 17:01:55 +0000
Subject: [PATCH] Implement func.call emission and function caching across MLIR dialects and ZML module/ops, propagating tags and donations.

---
 mlir/dialects/func.zig | 3 +-
 mlir/dialects/stablehlo.zig | 2 +-
 mlir/mlir.zig | 4 +-
 zml/meta.zig | 36 +++---
 zml/module.zig | 243 ++++++++++++++++++++----------
 zml/ops.zig | 7 +-
 zml/testing.zig | 2 +-
 7 files changed, 162 insertions(+), 135 deletions(-)

diff --git a/mlir/dialects/func.zig b/mlir/dialects/func.zig
index 9e7c281..d947fed 100644
--- a/mlir/dialects/func.zig
+++ b/mlir/dialects/func.zig
@@ -13,8 +13,7 @@ pub fn func(
         location: mlir.Location,
     },
 ) mlir.Operation {
-    const AttrTuple = struct { [:0]const u8, mlir.Attribute };
-    var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 4){};
+    var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
     attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
     attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
     if (args.arg_attrs.len > 0) {
diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig
index 6d2c740..a5c59cb 100644
--- a/mlir/dialects/stablehlo.zig
+++ b/mlir/dialects/stablehlo.zig
@@ -189,7 +189,7 @@ pub fn dot_general(
     },
 ) mlir.Operation {
     const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
-    const attributes = [3]mlir.Operation.AttrTuple{
+    const attributes = [3]mlir.AttrTuple{
         .{ "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
             .lhs_batching_dimensions = opts.lhs_batching_dimensions,
diff --git a/mlir/mlir.zig b/mlir/mlir.zig
index 05aef14..fd6bdb8 100644
--- a/mlir/mlir.zig
+++ b/mlir/mlir.zig
@@ -333,6 +333,8 @@ pub const Identifier = struct {
     }
 };
 
+pub const AttrTuple = struct { [:0]const u8, Attribute };
+
 pub const Attribute = struct {
     _inner: c.MlirAttribute,
     pub usingnamespace MlirHelpers(Attribute, .{
@@ -791,8 +793,6 @@ pub const Operation = struct {
         ) orelse Error.InvalidMlir;
     }
 
-    pub const AttrTuple = struct { [:0]const u8, Attribute };
-
     pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
         operands: ?[]const Value = null,
         variadic_operands: ?[]const []const Value = null,
diff --git a/zml/meta.zig b/zml/meta.zig
index 171bbb2..8a08792 100644
--- a/zml/meta.zig
+++ b/zml/meta.zig
@@ -86,12 +86,14 @@ pub fn MapType(From: type, To: type) type {
 /// For example it can convert from a comptime array to a runtime slice.
 /// `mapAlloc` can allocate new slices to write the result if the result struct requires it.
 /// The caller is owning said allocations, using an `ArenaAllocator` might help tracking them.
-// TODO: handle tuple to slice conversion
+///
+/// Note: to avoid infinite loops, `mapAlloc` doesn't look for `From` fields inside a `To` struct.
+/// Any `To` struct inside `from` will be copied over to the target.
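+///
+/// A sketch of a typical call (this mirrors zml/testing.zig below, where
+/// `Local.bufferToShape` maps a `Buffer` to its `Shape`):
+///
+///     var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
+///     try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args);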
 pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
-    // const Ctx = FnParam(cb, 0);
+    // TODO: handle tuple to slice conversion
     const From = FnParam(cb, 1);
+    const To = stdx.meta.FnResult(cb);
     const FromStruct = @TypeOf(from);
-
     const type_info_to_ptr = @typeInfo(@TypeOf(to));
     if (type_info_to_ptr != .Pointer) {
         stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: {}", .{@TypeOf(to)});
@@ -100,11 +102,12 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
     const type_info_to = @typeInfo(ToStruct);
 
     if (FromStruct == From) {
-        // Special case for converting from shape to tensor:
-        // If the target type is Shape, skip tensor conversion.
-        // A general `to.* = from` assignment causes a Zig error in this scenario.
-        // (see below)
-        if (ToStruct == @import("shape.zig").Shape and FromStruct == ToStruct) { // FromStruct) {
+        // We have an issue with `Tensor` -> `Shape` -> `Tensor` conversion when compiling ZML functions where one argument is a Shape itself.
+        // Normally we should call `cb` on every `Shape`.
+        // But the "ShapeOf" struct will have more `Shape` fields than needed in the output.
+        // So here we take a hint from the receiving object:
+        // if the target is indeed a Tensor, use the callback, but if the target is a `Shape`, just copy it over.
+        if (ToStruct != To and FromStruct == ToStruct) {
             to.* = from;
         } else {
             to.* = @call(.auto, cb, .{ ctx, from });
@@ -112,19 +115,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
         return;
     }
 
-    // This is generally due to a user error, but let this fn compile,
-    // and the user will have a Zig error.
-    if (FromStruct == ToStruct) {
+    if (FromStruct == To) {
         to.* = from;
         return;
    }
 
-    // Don't go into Shape objects because of the weird tag.
-    // TODO: we could not error on pointers to basic types like u8
-    if (FromStruct == @import("shape.zig").Shape) {
-        to.* = from;
-        return;
-    }
     switch (type_info_to) {
         .Struct => |info| inline for (info.fields) |field| {
             // if (field.is_comptime) continue;
@@ -155,10 +150,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
             .One => switch (type_info_to_ptr.Pointer.size) {
                 // pointer to array -> slice promotion
                 .Slice => {
-                    to.* = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
-                    for (from, to.*) |f, *t| {
+                    const items = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
+                    for (from, items) |f, *t| {
                         try mapAlloc(cb, allocator, ctx, f, t);
                     }
+                    to.* = items;
                 },
                 else => try mapAlloc(cb, allocator, ctx, from.*, to.*),
             },
@@ -177,7 +173,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
         } else {
             to.* = null;
         },
-        .Int, .Float => to.* = from,
+        .Int, .Float, .Enum => to.* = from,
         else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
     }
 }
diff --git a/zml/module.zig b/zml/module.zig
index c1bd7ac..8909fc5 100644
--- a/zml/module.zig
+++ b/zml/module.zig
@@ -49,7 +49,7 @@ const Block = union(BlockKind) {
             .op_result => |parent_op| self.appendOperationRecursive(parent_op),
             .block_argument => |arg| {
                 // Hermetic blocks are not allowed to use arguments from other blocks.
- std.debug.assert(self == .open or self.block().eql(arg.block())); + stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr }); }, .null => @panic("InvalidMlir"), } @@ -75,6 +75,11 @@ pub const MlirFn = struct { res_shapes: []Shape, res_donations: []Tensor._Donation, mlir_fn: mlir.Operation, + + pub const Kind = enum { + main, + private, + }; }; pub const CompilationContext = struct { @@ -151,7 +156,6 @@ pub const CompilationContext = struct { pub fn deactivate(self: *CompilationContext) void { std.debug.assert(_current != null and _current.? == self); _current = self._previous; - self._previous = null; } pub fn current() *CompilationContext { @@ -182,9 +186,9 @@ pub const CompilationContext = struct { const arena = arena_state.allocator(); var timer = std.time.Timer.start() catch null; - const tensor_args = self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args); + const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args); // Run in a dedicated thread because compilation relies on `threadlocal`. - const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ self, arena, "main", func, &tensor_args }); + const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, arena, func, &tensor_args, .{ .name = "main", .kind = .main } }); const module = self._module; module.getBody().appendOperation(f.mlir_fn); @@ -329,15 +333,18 @@ pub const CompilationContext = struct { /// Generate an MLIR function from a ZML function. /// The caller is responsible to have properly created the input /// tensors with unique tensor ids. - pub fn generateBytecode( + pub fn emitMlir( self: *CompilationContext, allocator: std.mem.Allocator, - fn_name: []const u8, comptime func: anytype, args: *const stdx.meta.FnArgs(func), + opts: struct { + name: []const u8, + kind: MlirFn.Kind = .private, + }, ) error{OutOfMemory}!MlirFn { - const frame = self._tracer.frameStart("generateBytecode.emit"); - errdefer self._tracer.frameEnd(frame, "generateBytecode.emit"); + const frame = self._tracer.frameStart("emitMlir.emit"); + errdefer self._tracer.frameEnd(frame, "emitMlir.emit"); // Note: only temp allocations are done in the arena, // the other allocations are managed by the caller. @@ -371,11 +378,6 @@ pub const CompilationContext = struct { var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable; { defer self.closeBlock(fn_body); - // Note: we could shrink self._buffer_to_arg once we called `func`. - // But for now we are only compiling one function per CompilationContext. - // So we don't need to do this since we won't reuse self._buffer_to_arg anyway. - // const n = self._buffer_to_arg.count(); - // defer self._buffer_to_arg.shrinkRetainingCapacity(n); try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count)); const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); @@ -400,14 +402,15 @@ pub const CompilationContext = struct { const res_attrs = try arena.alloc(AttributeList, out_tensor_count); @memset(res_attrs, .{}); - // Donations attributes only make sense on the main function. 
- self.addDonationsAttributes(arg_attrs, fn_res_donations); - - if (self._platform.sharding().num_partitions > 1) { - self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); + if (opts.kind == .main) { + self.addDonationsAttributes(arg_attrs, fn_res_donations); + if (self._platform.sharding().num_partitions > 1) { + self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); + } } + const mlir_fn = dialect.func.func(self.mlirCtx(), .{ - .sym_name = fn_name, + .sym_name = opts.name, .args = input_types, .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .results = fn_res_types, @@ -416,9 +419,9 @@ pub const CompilationContext = struct { .location = loc, }); - self._tracer.frameEnd(frame, "generateBytecode.emit"); - const canonicalize_frame = self._tracer.frameStart("generateBytecode.canonicalize"); - defer self._tracer.frameEnd(canonicalize_frame, "generateBytecode.canonicalize"); + self._tracer.frameEnd(frame, "emitMlir.emit"); + const canonicalize_frame = self._tracer.frameStart("emitMlir.canonicalize"); + defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize"); self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) { error.InvalidMlir => { log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})}); @@ -429,7 +432,7 @@ pub const CompilationContext = struct { return .{ .mlir_fn = mlir_fn, - .name = fn_name, + .name = opts.name, .num_args = @intCast(tensor_count), .res_types = fn_res_types, .res_shapes = fn_res_shapes, @@ -478,7 +481,13 @@ pub const CompilationContext = struct { const Local = struct { bias: Tensor, - pub fn forward(self: @This(), x: Tensor) Tensor { + pub fn forward(self: @This(), x: Tensor, y: Tensor) [2]Tensor { + const x1 = zml.ops.call(self, .inner, .{x}); + const x2 = zml.ops.call(self, .inner, .{x1}); + return .{ x1.reuseBuffer(y), x2 }; + } + + pub fn inner(self: @This(), x: Tensor) Tensor { const y = x.add(self.bias); return y.reuseBuffer(x); } @@ -490,19 +499,25 @@ pub const CompilationContext = struct { var comp = try zml.module.CompilationContext.init(allocator, "test", platform); defer comp.deinit(); - var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } } }; - const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &tensor_args); + var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; + const f = try comp.emitMlir(allocator, Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{}; try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})}); // Check that the `x` input argument gives its buffer to the result tensor. - // `%arg0` is the bias of the model, `%arg1` is `x`. - try std.testing.expectEqual(2, f.num_args); - std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| { - log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items}); - return err; - }; + // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. + try std.testing.expectEqual(3, f.num_args); + // We should have two buffers being donated. 
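+        // In the produced IR this shows up as an argument attribute, roughly
+        // (illustrative shape; the exact argument numbering is checked below):
+        //   %argK: tensor<...> {tf.aliasing_output = 0 : i32}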
+        const template = "tf.aliasing_output = {d} : i32";
+        var buf = template.*;
+        for (0..2) |i| {
+            const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
+            std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| {
+                log.warn("Didn't produce the expected IR:\n{s}", .{mlir_bytecode.items});
+                return err;
+            };
+        }
     }
 
     pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
@@ -608,45 +623,77 @@ pub const CompilationContext = struct {
 
         // first, do the "compile" and check the bytecode
         // the result of this will also have the correct tags of the result shapes
-        const dummy_result = self.generateMlirBytecodeForFunction(
-            arena,
-            func_name,
-            func,
-            args,
-        ) catch unreachable; // TODO: do we like unreachable?
-        const bytecode_hash = hashArgs(dummy_result.bytecode_tmp);
-
-        const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = bytecode_hash };
+        const args_hash = hashArgs(args);
+        const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash };
 
         const function = self._fn_cache.getEntry(key) orelse b: {
             const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
                 arena.dupeZ(u8, func_name) catch unreachable
             else
                 std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable;
-            log.info("addFuncToModule {any} {s}", .{ key, full_name });
+            const og_buffer_to_arg = self._buffer_to_arg;
+            defer {
+                self._buffer_to_arg.deinit(self._allocator);
+                self._buffer_to_arg = og_buffer_to_arg;
+            }
 
-            const value = self.addFuncToModule(
-                arena,
-                full_name,
-                func,
-                args,
-            ) catch unreachable;
+            // Reset the buffer -> argument assignment
+            self._buffer_to_arg = .{};
 
-            break :b self._fn_cache.addEntry(key, value) catch unreachable;
+            var arg_id: u16 = 0;
+            var tensor_args: @TypeOf(args) = args;
+            meta.mapAlloc(struct {
+                fn cb(arg_id_: *u16, x: Tensor) Tensor {
+                    const a = arg_id_.*;
+                    arg_id_.* += 1;
+                    return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
+                }
+            }.cb, self._allocator, &arg_id, args, &tensor_args) catch @panic("OutOfMemory");
+
+            const f = self.emitMlir(arena, func, &tensor_args, .{
+                .name = full_name,
+            }) catch @panic("OutOfMemory");
+            self._module.getBody().appendOperation(f.mlir_fn);
+
+            break :b self._fn_cache.addEntry(key, f) catch unreachable;
         };
 
-        // Note: we won't increase the size of the cache until next `call` so
-        // we can use the memory there without worrying about fragmentation.
-
         const loc = self.mlirCtx().location(@src());
-        const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable;
-        self.extractValues(&args, values[function.n_model..]);
+        const values = arena.alloc(mlir.Value, function.num_args) catch unreachable;
+        self.extractValues(&args, values);
 
-        const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
-        // TODO: tags seem to be lost by `callFunc`.
+        const donations = arena.alloc(Tensor._Donation, function.num_args) catch unreachable;
+        meta.collectBuf(struct {
+            pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
+                return ctx.getValueAndDonation(x)[1];
+            }
+        }.cb, self, &args, donations);
+
+        const op = dialect.func.call(self.mlirCtx(), @ptrCast(function.name), values, function.res_types, loc);
+        // Create the result tensor objects by combining the op results
+        // with the registered shapes and donations.
+        // Note: this assumes `res` can be stack-allocated.
+        // Maybe it'd be simpler to just call the Zig function twice to do the shape/donation propagation for us.
+        // But this is blocked on https://github.com/zml/zml/issues/97
         var res: stdx.meta.FnResult(func) = undefined;
-        assignResults(op, &res, function.res_shapes);
+        const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
+        var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
+        meta.visit((struct {
+            fn cb(ctx: *LocalContext, tensor: *Tensor) void {
+                const i = ctx.index;
+                ctx.index += 1;
+                var new = Tensor.fromMlirValue(ctx.op.result(i));
+                new._shape = ctx.function.res_shapes[i];
+                new._donation = switch (ctx.function.res_donations[i]) {
+                    .no_buffer => .no_buffer,
+                    .arg => |input_arg| ctx.donations[input_arg],
+                    .input_buffer => .no_buffer, // user escaped the sandbox
+                };
+                tensor.* = new;
+            }
+        }).cb, &context, &res);
+        std.debug.assert(context.index == op.numResults());
 
         return res;
     }
@@ -669,7 +716,7 @@ pub const CompilationContext = struct {
 
             const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
             if (res.found_existing) {
-                stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id });
+                stdx.debug.panic("Failed compilation because received two tensor arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
             } else {
                 res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
             }
@@ -681,7 +728,7 @@ pub const CompilationContext = struct {
 
     /// Create tensor from the given shapes.
     /// Each created tensor will receive a unique id, local to this CompilationContext.
-    pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) ArgsT {
+    pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) !ArgsT {
         const Local = struct {
             fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor {
                 defer arg_id.* += 1;
@@ -951,29 +998,6 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
     return context.index;
 }
 
-/// Visit the given struct and assign op results to each tensor found.
-fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void { - const LocalContext = struct { - index: usize, - op: mlir.Operation, - shapes: ?[]Shape, - }; - var context = LocalContext{ .index = 0, .op = op, .shapes = shapes }; - meta.visit((struct { - fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void { - var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index)); - if (inner_ctx.shapes) |sh| { - new._shape = sh[inner_ctx.index]; - } else { - new._shape._tags = tensor._shape._tags; - } - tensor.* = new; - inner_ctx.index += 1; - } - }).cb, &context, v); - std.debug.assert(context.index == op.numResults()); -} - pub const XxHash64Writer = struct { hasher: *std.hash.XxHash64, @@ -1039,8 +1063,7 @@ pub const FnCache = struct { const owned_value: MlirFn = .{ .name = name_copy, .mlir_fn = value.mlir_fn, - .n_model = value.n_model, - .n_args = value.n_args, + .num_args = value.num_args, .res_types = res_types_copy, .res_shapes = res_shapes_copy, .res_donations = res_donations_copy, @@ -1055,47 +1078,55 @@ test FnCache { const zml = @import("zml.zig"); const platform = zml.testing.env(); + const Layer = struct { + const Layer_ = @This(); + + w: Tensor, + b: Tensor, + + pub fn forward(self: Layer_, x: Tensor) Tensor { + const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{}); + return wx.add(self.b.broad(wx.shape())).relu(); + } + }; + const NN = struct { const NN_ = @This(); - layer_weights: [3]Tensor, - layer_biases: [3]Tensor, + layers: [3]Layer, pub fn forward(self: NN_, x0: Tensor) Tensor { var x = x0; - for (self.layer_weights, self.layer_biases) |w, b| { - // TODO use the `call` magic helper - // x = ops.callFunc(ctx, NN_, "reluLayer", .{ w, b, x }); - x = NN_.reluLayer(w, b, x); + for (self.layers) |layer| { + x = ops.call(layer, .forward, .{x}); } return x; } pub fn forwardRefImpl(self: NN_, x0: Tensor) Tensor { var x = x0; - for (self.layer_weights, self.layer_biases) |w, b| { - x = NN_.reluLayer(w, b, x); + for (self.layers) |layer| { + x = layer.forward(x); } return x; } - - pub fn reluLayer(w: Tensor, b: Tensor, x: Tensor) Tensor { - const wx = w.dotGeneral(x, &.{.{ -1, 0 }}, &.{}); - return wx.add(b.broadcastLeft(wx.shape())).relu(); - } }; const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 }); const nn: zml.Bufferized(NN) = .{ - .layer_weights = .{ - try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }), - try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }), + .layers = .{ + .{ + .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }), + .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }), + }, + .{ + .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }), + .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }), + }, // third layer is different - try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }), - }, - .layer_biases = .{ - try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }), - try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }), - try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }), + .{ + .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }), + .b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }), + }, }, }; const res = try zml.testing.compileAndCall(platform, NN.forward, .{ nn, x }); diff --git a/zml/ops.zig b/zml/ops.zig index aa17038..b446c9b 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -30,9 +30,10 @@ test { /// Generate an MLIR call to the 
given member function with the given tensors. pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) { - // TODO: this should use `self.getContext().callFunc(self, args)` - - return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args); + const ctx = CompilationContext.current(); + const name = @typeName(@TypeOf(self)) ++ "." ++ @tagName(func); + const actual_fn = @field(@TypeOf(self), @tagName(func)); + return ctx.callFunc(name, actual_fn, .{self} ++ args); } pub fn while_( diff --git a/zml/testing.zig b/zml/testing.zig index 38688f8..eebc2ce 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -138,7 +138,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu } }; var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined; - try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args); + try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args); const mod = try zml.compileFn(allocator, func, shape_args, platform); defer mod.deinit();
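
Usage sketch for the new `zml.ops.call` (adapted from the FnCache test in
zml/module.zig above; `NN_` and `layers` come from that test):

    pub fn forward(self: NN_, x0: Tensor) Tensor {
        var x = x0;
        for (self.layers) |layer| {
            // Emits a `func.call` to a cached private MLIR function for the
            // layer, instead of inlining the layer body at every call site.
            x = ops.call(layer, .forward, .{x});
        }
        return x;
    }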