diff --git a/stdx/meta.zig b/stdx/meta.zig index 40b33bf..1efe7a7 100644 --- a/stdx/meta.zig +++ b/stdx/meta.zig @@ -150,11 +150,17 @@ pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type { } pub fn FnArgs(comptime func: anytype) type { + debug.assertComptime(!@typeInfo(@TypeOf(func)).Fn.is_generic, "FnArgs expects non generic function, got: {}", .{@TypeOf(func)}); return FnSignature(func, null).ArgsT; } +pub fn FnArgsWithHint(comptime func: anytype, ArgsT: type) type { + debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.is_generic, "FnArgsWithHint expects a generic function, got: {}", .{@TypeOf(func)}); + return FnSignature(func, ArgsT).ArgsT; +} + pub fn FnResult(comptime func: anytype) type { - return FnSignature(func, null).ReturnT; + return @typeInfo(@TypeOf(func)).Fn.return_type orelse @compileError("anytype is not supported"); } pub fn Head(Tuple: type) type { diff --git a/stdx/signature.zig b/stdx/signature.zig index 6a4e06f..a385f22 100644 --- a/stdx/signature.zig +++ b/stdx/signature.zig @@ -12,7 +12,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type { return std.meta.ArgsTuple(funcT); } - const args = std.meta.fields(ArgsT orelse compileError("generic function requires an explicit ArgsTuple", .{})); + const args = std.meta.fields(ArgsT orelse @compileError("generic function requires an explicit ArgsTuple")); var tuple_fields: [params.len]std.builtin.Type.StructField = undefined; if (params.len != args.len) { compileError("function {} expected {} args, got {}", .{ funcT, params.len, args.len }); @@ -23,7 +23,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type { continue; } const T = param.type.?; - var num_buf: [32]u8 = undefined; + var num_buf: [8]u8 = undefined; tuple_fields[i] = .{ .name = blk: { const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{}); diff --git a/zml/aio.zig b/zml/aio.zig index f8fad85..0c49824 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -88,7 +88,7 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato try prefix_builder.push(allocator, prefix); defer prefix_builder.deinit(allocator); - const unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count())); + const unique_id = zml.Tensor._reserveIdRange(@intCast(store.buffers.count())); const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| { std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) }); }; diff --git a/zml/module.zig b/zml/module.zig index 8909fc5..9eaff9b 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -86,18 +86,17 @@ pub const CompilationContext = struct { _platform: Platform, _name: []const u8, + _arena: std.heap.ArenaAllocator, _mlir_ctx: mlir.Context, _mlir_registry: mlir.Registry, _mlir_canonicalizer: mlir.PassManager, _module: mlir.Module, - _blocks: std.BoundedArray(Block, 64), - _fn_cache: FnCache, - // TODO: make this an arena, that way it's fine if ops allocate inside it. - _allocator: std.mem.Allocator, + _blocks: std.BoundedArray(Block, 64) = .{}, + _fn_cache: FnCache = .{}, - _buffer_to_arg: TensorToBlockArg = .{}, + _block_args: TensorToBlockArg = .{}, _unique_id: u64 = 10000, _tracer: Tracer, @@ -107,7 +106,7 @@ pub const CompilationContext = struct { const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation }); const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3); - pub fn init(allocator: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext { + pub fn init(allocator_: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext { const mlir_registry = mlir.Registry.init() catch unreachable; inline for (.{ "func", "stablehlo" }) |d| { mlir.DialectHandle.fromString(d).insertDialect(mlir_registry); @@ -127,6 +126,10 @@ pub const CompilationContext = struct { try opm.addPipeline("canonicalize"); } + var arena = std.heap.ArenaAllocator.init(allocator_); + _ = try arena.allocator().alloc(u8, std.mem.page_size); + _ = arena.reset(.retain_capacity); + return .{ ._platform = platform, ._name = name, @@ -135,17 +138,21 @@ pub const CompilationContext = struct { ._mlir_canonicalizer = canonicalizer, ._module = module, ._blocks = .{}, - ._fn_cache = FnCache.init(allocator), - ._allocator = allocator, + ._fn_cache = .{}, + ._arena = arena, ._tracer = Tracer.init("ai.zml.compilation"), }; } pub fn deinit(self: *CompilationContext) void { - self._fn_cache.deinit(); + // No need to deinit self._fn_cache cause it uses our arena self._mlir_ctx.deinit(); self._mlir_registry.deinit(); - self._buffer_to_arg.deinit(self._allocator); + self._arena.deinit(); + } + + pub fn allocator(self: *CompilationContext) std.mem.Allocator { + return self._arena.allocator(); } pub fn activate(self: *CompilationContext) void { @@ -173,22 +180,20 @@ pub const CompilationContext = struct { /// Compiles the given function with the given arguments. /// This is the untyped API and is not meant to be use directly. /// - /// * allocator is used to allocate the + /// * allocator is used to allocate the result Exe /// * args can contain a mix of tensors and shapes, allowing to pass a "model struct" containig tensors. pub fn compileInternal( self: *CompilationContext, - allocator: std.mem.Allocator, + allocator_: std.mem.Allocator, comptime func: anytype, args: anytype, ) !BaseExe { - var arena_state = std.heap.ArenaAllocator.init(allocator); - defer arena_state.deinit(); - const arena = arena_state.allocator(); + const arena = self.allocator(); var timer = std.time.Timer.start() catch null; const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args); // Run in a dedicated thread because compilation relies on `threadlocal`. - const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, arena, func, &tensor_args, .{ .name = "main", .kind = .main } }); + const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, func, &tensor_args, .{ .name = "main", .kind = .main } }); const module = self._module; module.getBody().appendOperation(f.mlir_fn); @@ -251,7 +256,7 @@ pub const CompilationContext = struct { } return BaseExe.init( - allocator, + allocator_, self._platform, loaded_executable, .{ @@ -335,7 +340,6 @@ pub const CompilationContext = struct { /// tensors with unique tensor ids. pub fn emitMlir( self: *CompilationContext, - allocator: std.mem.Allocator, comptime func: anytype, args: *const stdx.meta.FnArgs(func), opts: struct { @@ -346,9 +350,10 @@ pub const CompilationContext = struct { const frame = self._tracer.frameStart("emitMlir.emit"); errdefer self._tracer.frameEnd(frame, "emitMlir.emit"); + const res_allocator = self.allocator(); // Note: only temp allocations are done in the arena, - // the other allocations are managed by the caller. - var arena_state = std.heap.ArenaAllocator.init(allocator); + // the other allocations are in the context allocator. + var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator); defer arena_state.deinit(); const arena = arena_state.allocator(); @@ -367,19 +372,28 @@ pub const CompilationContext = struct { const input_types = try arena.alloc(mlir.Type, tensor_count); for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh); + const og_block_args = self._block_args; + defer { + self._block_args.deinit(self.allocator()); + self._block_args = og_block_args; + } + + // Reset the buffer -> assignement + self._block_args = .{}; + // Note: this isn't stricly necessary. We call `countTensor` on `fn_res`. // But it forces user to have simpler function. const ReturnT = stdx.meta.FnResult(func); const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors"); // Those are returned to caller so we don't put them in the arena. - const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count); - const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count); - const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count); + const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count); + const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count); + const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count); var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable; { defer self.closeBlock(fn_body); - try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count)); + try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count)); const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); std.debug.assert(assigned_args_count == tensor_count); @@ -474,7 +488,6 @@ pub const CompilationContext = struct { const platform = zml.testing.env(); var arena = std.heap.ArenaAllocator.init(std.testing.allocator); defer arena.deinit(); - const allocator = arena.allocator(); const s = Shape.init(.{8}, .f16); @@ -497,13 +510,14 @@ pub const CompilationContext = struct { .bias = zml.Tensor{ ._shape = s, ._id = .{ .buffer_id = 0 } }, }; - var comp = try zml.module.CompilationContext.init(allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; - const f = try comp.emitMlir(allocator, Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); + const f = try comp.emitMlir(Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); - var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{}; - try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})}); + var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator); + defer mlir_bytecode.deinit(); + try mlir_bytecode.writer().print("{}", .{f.mlir_fn.mlirFormatter(.{})}); // Check that the `x` input argument gives its buffer to the result tensor. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. @@ -598,8 +612,8 @@ pub const CompilationContext = struct { } } - fn finalizeAttributeList(allocator: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute { - const res = try allocator.alloc(mlir.Attribute, attributes.len); + fn finalizeAttributeList(allocator_: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute { + const res = try allocator_.alloc(mlir.Attribute, attributes.len); for (res, attributes) |*r, attr| { r.* = mlir.DictionaryAttribute.init(mlir_ctx, attr.constSlice()).asAttr(); } @@ -617,7 +631,7 @@ pub const CompilationContext = struct { comptime func: anytype, args: stdx.meta.FnArgs(func), ) stdx.meta.FnResult(func) { - var arena_state = std.heap.ArenaAllocator.init(self._allocator); + var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator); defer arena_state.deinit(); const arena = arena_state.allocator(); @@ -625,21 +639,13 @@ pub const CompilationContext = struct { // the result of this will also have the correct tags of the result shapes const args_hash = hashArgs(args); const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash }; + const function = self._fn_cache.getEntry(key) orelse b: { const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name)) arena.dupeZ(u8, func_name) catch unreachable else std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable; - const og_buffer_to_arg = self._buffer_to_arg; - defer { - self._buffer_to_arg.deinit(self._allocator); - self._buffer_to_arg = og_buffer_to_arg; - } - - // Reset the buffer -> assignement - self._buffer_to_arg = .{}; - var arg_id: u16 = 0; var tensor_args: @TypeOf(args) = args; meta.mapAlloc(struct { @@ -648,14 +654,14 @@ pub const CompilationContext = struct { arg_id_.* += 1; return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } }; } - }.cb, self._allocator, &arg_id, args, &tensor_args) catch @panic("OutOfMemory"); + }.cb, arena, &arg_id, args, &tensor_args) catch @panic("OutOfMemory"); - const f = self.emitMlir(arena, func, &tensor_args, .{ + const f = self.emitMlir(func, &tensor_args, .{ .name = full_name, }) catch @panic("OOM"); self._module.getBody().appendOperation(f.mlir_fn); - break :b self._fn_cache.addEntry(key, f) catch unreachable; + break :b self._fn_cache.addEntry(self.allocator(), key, f) catch unreachable; }; const loc = self.mlirCtx().location(@src()); @@ -701,7 +707,7 @@ pub const CompilationContext = struct { /// /// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors /// stored in the Module. - /// Caller need to allocate required memory in self._buffer_to_arg. + /// Caller need to allocate required memory in self._block_args. pub fn mapBlockArguments(self: *CompilationContext, v: anytype, block: mlir.Block, start: usize) usize { const LocalContext = struct { index: usize, @@ -714,7 +720,7 @@ pub const CompilationContext = struct { const arg_value = ctx.block.argument(ctx.index); // log.debug("mapping {} to arg {}", .{ tensor._id, ctx.index }); - const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id); + const res = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id); if (res.found_existing) { stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id }); } else { @@ -728,7 +734,7 @@ pub const CompilationContext = struct { /// Create tensor from the given shapes. /// Each created tensor will receive a unique id, local to this CompilationContext. - pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) !ArgsT { + pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator_: std.mem.Allocator, args_shapes: anytype) !ArgsT { const Local = struct { fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor { defer arg_id.* += 1; @@ -740,7 +746,7 @@ pub const CompilationContext = struct { } }; var tensor_args: ArgsT = undefined; - try meta.mapAlloc(Local.tensorFromShape, allocator, &self._unique_id, args_shapes, &tensor_args); + try meta.mapAlloc(Local.tensorFromShape, allocator_, &self._unique_id, args_shapes, &tensor_args); return tensor_args; } @@ -771,7 +777,7 @@ pub const CompilationContext = struct { pub fn getValueAndDonation(self: *const CompilationContext, tensor: Tensor) struct { mlir.Value, Tensor._Donation } { return switch (tensor._id) { - .buffer_id, .arg_id => if (self._buffer_to_arg.get(tensor._id)) |res| + .buffer_id, .arg_id => if (self._block_args.get(tensor._id)) |res| .{ res[0], res[1] } else { log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id }); @@ -1021,44 +1027,28 @@ pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer { pub const FnCache = struct { pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 }; - // TODO: merge arenas - cache: std.AutoHashMapUnmanaged(Key, MlirFn), - // Arena for the cache entries - cache_arena: std.heap.ArenaAllocator, - // Arena for the cache data (name, res_type) - cache_data_arena: std.heap.ArenaAllocator, + cache: std.AutoHashMapUnmanaged(Key, MlirFn) = .{}, - pub fn init(allocator: std.mem.Allocator) FnCache { - return .{ - .cache = .{}, - .cache_arena = std.heap.ArenaAllocator.init(allocator), - .cache_data_arena = std.heap.ArenaAllocator.init(allocator), - }; - } - - pub fn deinit(self: FnCache) void { - self.cache_arena.deinit(); - self.cache_data_arena.deinit(); + pub fn deinit(self: FnCache, allocator: std.mem.Allocator) void { + self.cache.deinit(allocator); } pub fn getEntry(self: *const FnCache, key: Key) ?MlirFn { return self.cache.get(key); } - pub fn addEntry(self: *FnCache, key: Key, value: MlirFn) !MlirFn { - var cache_data_allocator = self.cache_data_arena.allocator(); + pub fn addEntry(self: *FnCache, allocator: std.mem.Allocator, key: Key, value: MlirFn) !MlirFn { + const res_types_copy = try allocator.dupe(mlir.Type, value.res_types); + errdefer allocator.free(res_types_copy); - const res_types_copy = try cache_data_allocator.dupe(mlir.Type, value.res_types); - errdefer cache_data_allocator.free(res_types_copy); + const res_shapes_copy = try allocator.dupe(Shape, value.res_shapes); + errdefer allocator.free(res_shapes_copy); - const res_shapes_copy = try cache_data_allocator.dupe(Shape, value.res_shapes); - errdefer cache_data_allocator.free(res_shapes_copy); + const res_donations_copy = try allocator.dupe(Tensor._Donation, value.res_donations); + errdefer allocator.free(res_donations_copy); - const res_donations_copy = try cache_data_allocator.dupe(Tensor._Donation, value.res_donations); - errdefer cache_data_allocator.free(res_donations_copy); - - const name_copy = try cache_data_allocator.dupeZ(u8, value.name); - errdefer cache_data_allocator.free(name_copy); + const name_copy = try allocator.dupeZ(u8, value.name); + errdefer allocator.free(name_copy); const owned_value: MlirFn = .{ .name = name_copy, @@ -1069,7 +1059,7 @@ pub const FnCache = struct { .res_donations = res_donations_copy, }; - try self.cache.putNoClobber(self.cache_arena.allocator(), key, owned_value); + try self.cache.putNoClobber(allocator, key, owned_value); return owned_value; } }; diff --git a/zml/nn.zig b/zml/nn.zig index b8d7d93..7ab1137 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -816,8 +816,8 @@ const SdpaMemEfficient = struct { const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size)); const ctx = zml.module.CompilationContext.current(); - const q_chunks = ctx._allocator.alloc(zml.Tensor, n_q_chunks) catch unreachable; - defer ctx._allocator.free(q_chunks); + const q_chunks = ctx.allocator().alloc(zml.Tensor, n_q_chunks) catch unreachable; + defer ctx.allocator().free(q_chunks); for (0..n_q_chunks) |i| { const idx: u32 = @intCast(i); const q_slice: zml.Tensor.DynSlice = .{ diff --git a/zml/ops.zig b/zml/ops.zig index b446c9b..7809143 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -316,7 +316,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: // but because of https://github.com/zml/zml/issues/97 we also reuse it to start the while_ loop. const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) }); log.debug("for_ first_step: {}", .{first_step}); - const allocator = CompilationContext.current()._allocator; + const allocator = CompilationContext.current().allocator(); // Optimize for small num reps if (num_steps == 1) { var res = first_step; diff --git a/zml/tensor.zig b/zml/tensor.zig index a44958b..00ae1d3 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -126,7 +126,7 @@ pub const Tensor = struct { /// Returns the dimension of axis 'axis_'. /// - /// 'axis_' can be a signed integer or a tag. + /// 'axis_' can be an integer or a tag. pub fn dim(self: Tensor, axis_: anytype) i64 { return self._shape.dim(axis_); } @@ -138,11 +138,18 @@ pub const Tensor = struct { /// Returns the index of axis 'axis_'. /// - /// 'axis_' can be a signed integer or a tag. + /// 'axis_' can be an integer or a tag. pub fn axis(self: Tensor, axis_: anytype) u3 { return self._shape.axis(axis_); } + /// Returns the indices of each of the given axes. + /// + /// 'axis_' can be an integer or a tag. + pub fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) { + return self._shape.axes(axes_); + } + /// Returns a Tensor tagged with the tags in 'tagz'. pub fn withTags(self: Tensor, tagz: anytype) Tensor { var res = self; @@ -227,13 +234,13 @@ pub const Tensor = struct { var _global_tensor_counter: u64 = 0; /// Internal use - pub fn reserveIdRange(len: u32) u64 { + pub fn _reserveIdRange(len: u32) u64 { return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst); } /// Internal use pub fn setUniqueId(self: *Tensor) void { - self._id = .{ .buffer_id = reserveIdRange(1) }; + self._id = .{ .buffer_id = _reserveIdRange(1) }; } /// Returns a Tensor containing the absolute value of each element of the input Tensor. @@ -357,7 +364,7 @@ pub const Tensor = struct { /// Returns the Cholesky decomposition of the input Tensor. /// - /// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true + /// 'lower' controls the form of the output Tensor. The output will be lower-triangular if 'lower' is true /// and upper-triangular otherwise. pub fn cholesky(self: Tensor, lower: bool) Tensor { stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()}); @@ -1070,7 +1077,7 @@ pub const Tensor = struct { const zml = @import("zml.zig"); const platform = zml.testing.env(); - var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); comp.activate(); @@ -1390,6 +1397,7 @@ pub const Tensor = struct { /// /// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3) pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor { + // TODO: move to torch.zig, this equivalent to `spitAxis` stdx.debug.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() }); const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1; @@ -1445,6 +1453,7 @@ pub const Tensor = struct { /// Flattens the given axis and the next one, into one new axis. pub fn flatten(self: Tensor, axis_: anytype) Tensor { + // TODO: move to torch.zig, this is equivalent to merge const old_shape = self._shape; const a = self.axis(axis_); // stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis }); @@ -1463,6 +1472,7 @@ pub const Tensor = struct { } pub inline fn flattenAll(self: Tensor) Tensor { + // TODO: rename to just flatten, once flatten is moved to torch return self.reshape(.{self.count()}); } @@ -1547,7 +1557,8 @@ pub const Tensor = struct { return if (idx < 0) self.dim(axis_) + idx else idx; } - pub fn choose1d(self: Tensor, axis_: i64, i: i64) Tensor { + pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor { + // TODO: this use case could be handled directly by slice if we added a .single field return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_); } @@ -1651,6 +1662,7 @@ pub const Tensor = struct { /// Repeats in line each value along the given axes. pub fn stutter(self: Tensor, n_reps: []const u63) Tensor { + // TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2}); stdx.debug.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); var res = self; @@ -2159,7 +2171,7 @@ pub const Tensor = struct { { // Only test shapes - var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); comp.activate(); defer comp.deactivate(); @@ -2295,7 +2307,7 @@ pub const Tensor = struct { { // Only test shapes - var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); comp.activate(); defer comp.deactivate(); @@ -2519,7 +2531,7 @@ pub const Tensor = struct { { // Only test shapes - var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); comp.activate(); defer comp.deactivate(); @@ -2759,8 +2771,10 @@ pub const Tensor = struct { return .{ .values = res[0], .indices = res[1] }; } + pub const ArgSortOpts = struct { descending: bool = false }; + /// Returns a Tensor containing the indices corresponding to the sorted values over the given axis. - pub fn argsort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) Tensor { + pub fn argsort(self: Tensor, axis_: anytype, opts: ArgSortOpts) Tensor { return self.sort(axis_, .{ .descending = opts.descending }).indices; } @@ -2768,13 +2782,19 @@ pub const Tensor = struct { const zml = @import("zml.zig"); const platform = zml.testing.env(); + const Local = struct { + pub fn _argsort(x: Tensor, axis_: u3, opts: ArgSortOpts) Tensor { + return x.argsort(axis_, opts); + } + }; + var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator); defer arena_state.deinit(); const allocator = arena_state.allocator(); // 2D Tensor - dim = 1, ascending { const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 }); - const res = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{} }); + const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} }); const res_cpu = try res.toHostAlloc(allocator); try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32)); } @@ -2787,7 +2807,7 @@ pub const Tensor = struct { 0.6626, -0.3040, -0.8726, -1.4805, -1.6943, 1.1055, -2.0078, -0.5288, 0.8813, 0.8008, 2.0527, 1.1230, 0.5430, 0.2494, -0.9434, 0.7876, 0.1818, 0.9258, -2.4902, 1.5918, }); - const res_dev = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{ .descending = true } }); + const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } }); const res = try res_dev.toHostAlloc(allocator); try testing.expectEqualSlices(i32, &.{ 4, 1, 1, 2, 0, 2, 0, 0, 3, 4, @@ -2809,7 +2829,7 @@ pub const Tensor = struct { 64, 86, 62, 88, 57, 21, 19, 12, }); - const res_dev = try zml.testing.compileAndCallWithTensors(platform, Tensor.argsort, .{ x.shape(), 3, .{} }, .{ x, 0, .{} }); + const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} }); const res = try res_dev.toHostAlloc(allocator); try testing.expectEqualSlices(i32, &.{ 2, 1, 3, 0, @@ -2921,10 +2941,6 @@ pub const Tensor = struct { ); } - pub inline fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) { - return self._shape.axes(axes_); - } - /// Chunk a given tensor into exactly n parts of equal shape. /// `self.dim(axis_)` must be divisible by n_chunks. pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor { @@ -2944,7 +2960,7 @@ pub const Tensor = struct { const platform = zml.testing.env(); // Only test shapes - var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); comp.activate(); defer comp.deactivate(); @@ -2959,7 +2975,7 @@ pub const Tensor = struct { const chunks = x.chunkExact(ax, n_chunks); const res_shape = Shape.init(res, .f16); - for (&chunks) |chk| { + for (chunks) |chk| { try zml.testing.expectEqualShapes(res_shape, chk.shape()); } } @@ -2971,13 +2987,14 @@ pub const Tensor = struct { self: Tensor, axis_: i64, n_chunks: comptime_int, - ) std.BoundedArray(Tensor, n_chunks + 1) { + ) []Tensor { const a = self.axis(axis_); const d = self.dim(a); const chunk_size: i64 = @divFloor(d, n_chunks); const tail_chunk_size: i64 = @rem(d, chunk_size); - var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{}; + const allocator = self.getContext().allocator(); + var chunks = std.ArrayListUnmanaged(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM"); for (0..n_chunks) |i| { const start: i64 = @as(i64, @intCast(i)) * chunk_size; chunks.appendAssumeCapacity( @@ -2988,7 +3005,7 @@ pub const Tensor = struct { const start: i64 = n_chunks * chunk_size; chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start })); } - return chunks; + return chunks.items; } test chunkAllowTrailing { @@ -2996,7 +3013,7 @@ pub const Tensor = struct { const platform = zml.testing.env(); // Only test shapes - var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); comp.activate(); defer comp.deactivate(); @@ -3012,35 +3029,34 @@ pub const Tensor = struct { const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks); const res_shape = Shape.init(res, .f16); - for (chunks.constSlice()[0..n_chunks]) |chk| { + for (chunks[0..n_chunks]) |chk| { try zml.testing.expectEqualShapes(res_shape, chk.shape()); } const trailing_shape = Shape.init(trailing, .f16); if (trailing_shape.rank() > 0) { try std.testing.expectEqual(n_chunks + 1, chunks.len); - try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape()); + try zml.testing.expectEqualShapes(trailing_shape, chunks[n_chunks].shape()); } else { try std.testing.expectEqual(n_chunks, chunks.len); } } } - pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor { - stdx.debug.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len}); + pub fn split(self: Tensor, axis_: anytype, split_sizes: []const i64) []Tensor { + stdx.debug.assert(split_sizes.len > 0, "split expects at least one 'split_sizes', got 0", .{}); const a = self.axis(axis_); - const length = self.dim(a); - if (split_size_or_sections.len != 1) { - var split_sum: i64 = 0; - for (split_size_or_sections) |n| split_sum += n; - stdx.debug.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length }); - } + const d = self.dim(a); + var split_sum: i64 = 0; + for (split_sizes) |n| split_sum += n; + stdx.debug.assert(split_sum == d, "split expects sum of 'split_sizes' values and axis dimension to be equal, got {} and {}", .{ split_sum, d }); - const res = try allocator.alloc(Tensor, split_size_or_sections.len); + const allocator = self.getContext().allocator(); + const res = allocator.alloc(Tensor, split_sizes.len) catch @panic("OOM"); errdefer allocator.dealloc(res); var start: i64 = 0; - for (split_size_or_sections, 0..) |n, i| { + for (split_sizes, 0..) |n, i| { res[i] = self.slice1d(a, .{ .start = start, .end = start + n }); start += n; } @@ -3528,7 +3544,7 @@ pub const Tensor = struct { } /// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis. - pub fn any(self: Tensor, axis_: i64) Tensor { + pub fn any(self: Tensor, axis_: anytype) Tensor { const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero())); const red = ops.reduce( struct { diff --git a/zml/testing.zig b/zml/testing.zig index ae09026..92e2b70 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -127,6 +127,9 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec /// Compile a function and immediatly call it with the given buffers. /// The compiled module is discarded after the call. /// Useful during testing when a module is typically called only once. +/// +/// Note: `func` needs explicit types on all parameters. +/// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types. pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) { // This simplify test API and also ensure this fn isn't used outside of tests. const allocator = std.testing.allocator;