From 68dbc290e91ec4eb313730d917d492d934d55c4a Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 8 Jan 2024 17:55:20 +0000 Subject: [PATCH] zml: revamp scatterSlices The main issue with the current `scatter` implementation is that it uses the broadcasting dims of `stablehlo.scatter`. While nice in theory, the optimizer doesn't handle them well and they are often unrolled into while loops. Here I convert the batching dims into extra iota indices. --- zml/meta.zig | 20 +++ zml/module.zig | 29 +++-- zml/ops.zig | 345 ++++++++++++++++++++++++++++++++++++++++++++++--- zml/shape.zig | 22 ++++ zml/tensor.zig | 340 ++++++++++++++++++++++++------------------------ zml/torch.zig | 10 +- 6 files changed, 560 insertions(+), 206 deletions(-) diff --git a/zml/meta.zig b/zml/meta.zig index 2e62605..82b25d7 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -392,6 +392,26 @@ test visit { } } +pub fn count(T: type, value: anytype) u32 { + var counter: u32 = 0; + visit(struct { + pub fn cb(res: *u32, _: *const T) void { + res.* += 1; + } + }.cb, &counter, value); + return counter; +} + +pub fn first(T: type, value: anytype) T { + var res: ?T = null; + visit(struct { + pub fn cb(res_ptr: *?T, x: *const T) void { + if (res_ptr.* == null) res_ptr.* = x.*; + } + }.cb, &res, &value); + return res.?; +} + /// Given a `fn([]const T, Args) T` and a slice of values, /// will combine all values in one value. /// Only T elements of values will be looked at. 
diff --git a/zml/module.zig b/zml/module.zig index d2dcacf..3bcdc9e 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -368,7 +368,7 @@ pub const CompilationContext = struct { defer arena_state.deinit(); const arena = arena_state.allocator(); - const tensor_count = countTensors(args); + const tensor_count = meta.count(Tensor, args); const mlir_ctx = self.mlirCtx(); const loc = mlir_ctx.location(@src()); @@ -505,13 +505,13 @@ pub const CompilationContext = struct { const Local = struct { bias: Tensor, - pub fn forward(self: @This(), x: Tensor, y: Tensor) [2]Tensor { - const x1 = zml.ops.call(self, .inner, .{x}); - const x2 = zml.ops.call(self, .inner, .{x1}); + pub fn _fwd(self: @This(), x: Tensor, y: Tensor) [2]Tensor { + const x1 = zml.ops.call(self, ._inner, .{x}); + const x2 = zml.ops.call(self, ._inner, .{x1}); return .{ x1.reuseBuffer(y), x2 }; } - pub fn inner(self: @This(), x: Tensor) Tensor { + pub fn _inner(self: @This(), x: Tensor) Tensor { const y = x.add(self.bias); return y.reuseBuffer(x); } @@ -524,7 +524,7 @@ pub const CompilationContext = struct { var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); defer comp.deinit(); var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; - const f = try comp.emitMlir(Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); + const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator); defer mlir_bytecode.deinit(); @@ -798,7 +798,7 @@ pub const CompilationContext = struct { }; } - fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value { + pub fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value { return self.getValueAndDonation(tensor)[0]; } @@ -911,6 +911,7 @@ fn compileModuleToPjrtExecutable(arena: 
std.mem.Allocator, platform: Platform, m // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); // setFlag(&options, "xla_gpu_enable_custom_fusions", true); // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); + // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true); // setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse { @@ -1068,7 +1069,7 @@ test FnCache { w: Tensor, b: Tensor, - pub fn forward(self: Layer_, x: Tensor) Tensor { + pub fn _fwd(self: Layer_, x: Tensor) Tensor { const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{}); return wx.add(self.b.broad(wx.shape())).relu(); } @@ -1078,18 +1079,18 @@ test FnCache { const NN_ = @This(); layers: [3]Layer, - pub fn forward(self: NN_, x0: Tensor) Tensor { + pub fn _fwd(self: NN_, x0: Tensor) Tensor { var x = x0; for (self.layers) |layer| { - x = ops.call(layer, .forward, .{x}); + x = ops.call(layer, ._fwd, .{x}); } return x; } - pub fn forwardRefImpl(self: NN_, x0: Tensor) Tensor { + pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor { var x = x0; for (self.layers) |layer| { - x = layer.forward(x); + x = layer._fwd(x); } return x; } @@ -1113,8 +1114,8 @@ test FnCache { }, }, }; - const res = try zml.testing.compileAndCall(platform, NN.forward, .{ nn, x }); - const expected = try zml.testing.compileAndCall(platform, NN.forwardRefImpl, .{ nn, x }); + const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x }); + const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x }); try zml.testing.expectClose(expected, res, 1e-4); } diff --git a/zml/ops.zig b/zml/ops.zig index 7809143..21bcb9f 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -16,6 +16,7 @@ const EnumLiteral = @TypeOf(.enum_literal); const HostBuffer = @import("hostbuffer.zig").HostBuffer; const Shape = @import("shape.zig").Shape; const 
Tensor = @import("tensor.zig").Tensor; +const _collectAxes = @import("tensor.zig")._collectAxes; const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; @@ -73,23 +74,23 @@ test "simple while" { end: Tensor, const CountInts = @This(); - pub fn hasNext(self: CountInts, i: Tensor, sum: Tensor) Tensor { + pub fn _hasNext(self: CountInts, i: Tensor, sum: Tensor) Tensor { _ = sum; return i.cmp(.LT, self.end); } - pub fn next(self: CountInts, i: Tensor, sum: Tensor) [2]Tensor { + pub fn _next(self: CountInts, i: Tensor, sum: Tensor) [2]Tensor { const r1 = i.add(self.step); const r2 = sum.add(i); return .{ r1, r2 }; } - pub fn forward(self: CountInts, init_i: Tensor, init_sum: Tensor) [2]Tensor { + pub fn _fwd(self: CountInts, init_i: Tensor, init_sum: Tensor) [2]Tensor { const x = init_i.scale(2); - return while_(CountInts.hasNext, CountInts.next, self, .{ x, init_sum }); + return while_(CountInts._hasNext, CountInts._next, self, .{ x, init_sum }); } - pub fn zigForward(step: i64, end: i64, init_i: i64, init_sum: i64) [2]i64 { + pub fn _zigForward(step: i64, end: i64, init_i: i64, init_sum: i64) [2]i64 { const x = init_i * 2; var i = x; var sum = init_sum; @@ -110,14 +111,14 @@ test "simple while" { .step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}), .end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}), }; - const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts.forward, .{ counter, init_i, init_sum }); + const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts._fwd, .{ counter, init_i, init_sum }); const last_i = try res0.getValue(i64); const sum = try res1.getValue(i64); try std.testing.expectEqual(10, last_i); try std.testing.expectEqual(45, sum); - try std.testing.expectEqual(.{ 10, 45 }, CountInts.zigForward(1, 10, 0, 0)); + try std.testing.expectEqual(.{ 10, 45 }, CountInts._zigForward(1, 10, 0, 0)); } pub fn reduce( @@ -361,7 +362,7 @@ test for_ { return f.mul(f); } - pub fn 
forward(num_steps: u63) Tensor { + pub fn _fwd(num_steps: u63) Tensor { return for_(Squares.sq, .{}, .{num_steps}); } }; @@ -370,19 +371,19 @@ test for_ { // Just one baby step { - const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{1}); + const squares = try zml.testing.compileAndCall(platform, Squares._fwd, .{1}); try zml.testing.expectEqualShapes(Shape.init(.{1}, .f32), squares.shape()); try std.testing.expectEqual(0, squares.getValue(f32)); } // Wow 4 in rows ! { - const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{4}); + const squares = try zml.testing.compileAndCall(platform, Squares._fwd, .{4}); try zml.testing.expectEqualShapes(Shape.init(.{4}, .f32), squares.shape()); try std.testing.expectEqual([_]f32{ 0, 1, 4, 9 }, try squares.getValue([4]f32)); } // AGI is coming, computing 10 squares as it's nothing. { - const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{10}); + const squares = try zml.testing.compileAndCall(platform, Squares._fwd, .{10}); try zml.testing.expectEqualShapes(Shape.init(.{10}, .f32), squares.shape()); try std.testing.expectEqual( [_]f32{ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81 }, @@ -398,7 +399,7 @@ test "nested for" { x: Tensor, x_row: Tensor, - pub fn forward(x: Tensor) Tensor { + pub fn _fwd(x: Tensor) Tensor { return for_(OuterProd.scanRow, x, .{x.dim(0)}); } @@ -418,7 +419,7 @@ test "nested for" { // 5 to prevent inlining const x = try zml.Buffer.fromArray(platform, [5]f32{ 0, 1.0, -1.0, 2.0, -2.0 }); - const outer_prod = try zml.testing.compileAndCall(platform, OuterProd.forward, .{x}); + const outer_prod = try zml.testing.compileAndCall(platform, OuterProd._fwd, .{x}); const expected: [5][5]f32 = .{ .{ 0, 0, 0, 0, 0 }, .{ 0, 1.0, -1.0, 2.0, -2.0 }, @@ -468,7 +469,7 @@ test "if" { const allocator = std.testing.allocator; const IfMod = struct { - pub fn forward(pred: Tensor, a: Tensor, b: Tensor) Tensor { + pub fn _fwd(pred: Tensor, a: Tensor, b: Tensor) Tensor { 
const result = if_(pred.convert(.bool), condTrue, condFalse, .{ a, b }); return result; } @@ -486,7 +487,7 @@ test "if" { const pred = Shape.init(.{}, .i32); const a = Shape.init(.{ 4, 4 }, .f32); const b = Shape.init(.{ 4, 4 }, .f32); - const mod = try zml.compileFn(allocator, IfMod.forward, .{ pred, a, b }, platform); + const mod = try zml.compileFn(allocator, IfMod._fwd, .{ pred, a, b }, platform); defer mod.deinit(); } } @@ -753,3 +754,317 @@ pub fn addHostCallback( ); return Tensor._result(input.shape(), op.result(0)); } + +/// Generalized version of scatter to many inputs. +/// See `zml.Tensor.scatterSlices` for documentation on scatter. +/// +/// This allows to use the same indices to update several tensors at once, +/// and where the update function is allow to look at elements from the different tensors +/// to compute the final value. +/// +/// This sounds nice but in practice XLA doesn't support this well on GPU, +/// and will generate slow code. In practice stick with `zml.Tensor.scatterSlices`. +pub fn scatter( + comptime T: type, + comptime BlkCtx: type, + comptime update_fn: fn (BlkCtx, T, T) T, + inputs: T, + blkctx: BlkCtx, + index_tensors: anytype, + updates: T, + opts: Tensor.ScatterOpts, +) T { + const loc = @src(); + const ctx = CompilationContext.current(); + + const n_inputs = meta.count(Tensor, &inputs); + const n_updates = meta.count(Tensor, &updates); + stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {} and {}", .{ n_inputs, n_updates }); + + // Note: I was a bit lazy here, and I only look at tags on the first tensor. + // we probably should check all of them. 
+ const self = meta.first(Tensor, inputs); + const update = meta.first(Tensor, updates); + var indices_per_axis, var indices_axes = Shape.parseStruct(Tensor, index_tensors); + + if (indices_per_axis.len == 0) return inputs; + + // validate coord axes: all coord_axes should exist inside self + for (indices_axes.constSlice()) |t| { + stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={any}", .{ self, indices_axes }); + } + + // Handle scalar indices by broadcasting them to the indices with the highest rank. + const indices_shape = blk: { + var higher_rank = indices_per_axis.get(0).shape(); + for (indices_per_axis.constSlice()[1..]) |indices| { + if (indices.rank() > higher_rank.rank()) { + higher_rank = indices.shape(); + } + } + break :blk higher_rank; + }; + for (indices_per_axis.slice()) |*idx| { + stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {_}", .{indices_per_axis.slice()}); + stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {_}", .{indices_per_axis.slice()}); + idx.* = idx.broad(indices_shape); + } + + // rewrite simple scatters to dynamicUpdateSlice. + if (T == Tensor and indices_shape.rank() == 0) { + return self.dynamicUpdateSlice(index_tensors, updates); + } + + // TODO: ideally we should catch all possible scatter errors and provide nice error messages. 
+ var config = scatterConfig(self.shape(), update.shape(), indices_per_axis, indices_axes); + const indices = scatterPrepareIndices(&config, self.shape(), update.shape(), &indices_per_axis, &indices_axes); + // const n_indices_axes = update.rank() - _collectAxes(AxisKind, up_kind, .update_window).len; + // stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({_}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={_}", .{ self, index_tensors, indices_axes.constSlice(), update }); + + const mlir_ctx = ctx.mlirCtx(); + var _scalar: T = inputs; + meta.visit(struct { + pub fn cb(_: void, x: *Tensor) void { + x.* = .{ ._shape = Shape.init(.{}, x.dtype()), ._id = undefined }; + } + }.cb, {}, &_scalar); + + const UpdateS = BlockSign(update_fn); + const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar }); + + var input_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM"); + defer input_values.deinit(); + meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable; + var updates_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM"); + defer updates_values.deinit(); + meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable; + + const op = dialect.stablehlo.scatter( + mlir_ctx, + input_values.items, + &.{indices.value()}, + updates_values.items, + update_block, + .{ + .update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(), + .inserted_window_dims = _collectAxes(AxisKind, config.op_kind, .inserted_window).constSlice(), + .input_batching_dims = _collectAxes(AxisKind, config.op_kind, .batching).constSlice(), + .scatter_indices_batching_dims = config.indices_batch_axes.constSlice(), + .scatter_dims_to_operand_dims = config.scatter_to_operand_axes.constSlice(), + .index_vector_dim = indices.rank() - 1, + 
.indices_are_sorted = opts.indices_are_sorted, + .unique_indices = opts.indices_are_unique, + }, + mlir_ctx.location(loc), + ); + + var res: T = inputs; + const LocalContext = struct { + op: mlir.Operation, + index: usize = 0, + }; + var local_context = LocalContext{ .op = op }; + meta.visit((struct { + fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void { + const val = inner_ctx.op.result(inner_ctx.index); + tensor.* = Tensor._result(tensor.shape(), val); + inner_ctx.index += 1; + } + }).cb, &local_context, &res); + assert(local_context.index == op.numResults()); + return res; +} + +const ScatterConfig = struct { + op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}, + up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}, + indices_batch_axes: Shape.DimsArray = .{}, + scatter_to_operand_axes: Shape.DimsArray = .{}, + updates_transpose: Shape.AxesArray = .{}, +}; + +const AxisKind = enum { batching, update_window, inserted_window, window_id }; + +fn scatterConfig( + op: Shape, + update: Shape, + indices_per_axis: std.BoundedArray(Tensor, Tensor.MAX_RANK), + indices_axes: Shape.TagsArray, +) ScatterConfig { + var op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}; + var up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}; + var indices_batch_axes: Shape.DimsArray = .{}; + var scatter_to_operand_axes: Shape.DimsArray = .{}; + var updates_transpose: Shape.AxesArray = .{}; + + const tagged_api = indices_axes.len > 0; + const indices = indices_per_axis.get(0).shape(); + + if (tagged_api) { + for (indices_axes.constSlice()) |t| { + scatter_to_operand_axes.appendAssumeCapacity(op.axis(t)); + } + for (indices.tags()) |t| { + stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got updates={} and indices={s}", .{ update, indices_axes.constSlice() }); + updates_transpose.appendAssumeCapacity(update.axis(t)); + } + + for (op.tags()) |t| { + if (update.hasTag(t)) |up_ax| { + 
updates_transpose.appendAssumeCapacity(up_ax); + + if (indices.hasTag(t)) |id_ax| { + if (std.mem.indexOfScalar(Shape.Tag, indices_axes.constSlice(), t) != null) { + // tag is in indices AND in coords -> it's a batching dim that has been rewritten to a regular insertion dim + op_kind.appendAssumeCapacity(.inserted_window); + } else { + // tag is in op, indices and updates -> it's a batching dim + op_kind.appendAssumeCapacity(.batching); + indices_batch_axes.appendAssumeCapacity(@intCast(id_ax)); + } + } else { + op_kind.appendAssumeCapacity(.update_window); + } + } else { + op_kind.appendAssumeCapacity(.inserted_window); + } + } + + for (update.tags(), 0..) |t, up_ax| { + // Handle batch axes right away. + if (op.hasTag(t)) |self_ax| { + if (op_kind.get(self_ax) == .batching) { + up_kind.appendAssumeCapacity(.batching); + continue; + } + } + if (indices.hasTag(t) != null) { + up_kind.appendAssumeCapacity(.window_id); + } else if (op.hasTag(t)) |self_ax| { + stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={_}, updates={_}.", .{ t, op, update }); + up_kind.appendAssumeCapacity(.update_window); + } else { + // TODO: consider accepting untagged update here. 
+ std.debug.panic("scatter expects 'updates' to be made of axes from op={_} and from indices={s}, got unknown tag {s} in {_}", .{ op, indices_axes.constSlice(), t, update }); + } + } + } else { + for (0..indices_per_axis.len) |i| { + op_kind.appendAssumeCapacity(.inserted_window); + scatter_to_operand_axes.appendAssumeCapacity(@intCast(i)); + up_kind.appendAssumeCapacity(.window_id); + } + for (indices_per_axis.len..op.rank()) |_| { + op_kind.appendAssumeCapacity(.update_window); + } + for (indices_per_axis.len..update.rank()) |_| { + up_kind.appendAssumeCapacity(.update_window); + } + for (0..update.rank()) |i| { + updates_transpose.appendAssumeCapacity(@intCast(i)); + } + } + + return .{ + .op_kind = op_kind, + .up_kind = up_kind, + .indices_batch_axes = indices_batch_axes, + .scatter_to_operand_axes = scatter_to_operand_axes, + .updates_transpose = updates_transpose, + }; +} + +test scatterConfig { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + const Local = struct { + pub fn _idx(idx_shape: anytype) Tensor { + return Tensor.constant(idx_shape, .{ .i32 = 0 }); + } + }; + + const idx = Local._idx; + const op = Shape.init(.{ .a = 10, .b = 20 }, .f32); + + // Use .a as a batching axis with .a=10 x .n=8 updates of 2 elements of .b + { + const indices, const coords_tags = Shape.parseStruct(Tensor, .{ .b = idx(.{ .a = 10, .n = 8 }) }); + const update = Shape.init(.{ .a = 10, .n = 8, .b = 2 }, .f32); + + const cfg = scatterConfig(op, update, indices, coords_tags); + try std.testing.expectEqualSlices(AxisKind, &.{ .batching, .update_window }, cfg.op_kind.constSlice()); + try std.testing.expectEqualSlices(AxisKind, &.{ .batching, .window_id, .update_window }, cfg.up_kind.constSlice()); + } + + // similar, but use the normalized form where .a is no longer an explicit batching axis. 
+ { + const indices, const coords_tags = Shape.parseStruct(Tensor, .{ .a = idx(.{ .a = 10, .n = 8 }), .b = idx(.{ .a = 10, .n = 8 }) }); + const update = Shape.init(.{ .a = 10, .n = 8, .b = 2 }, .f32); + + const cfg = scatterConfig(op, update, indices, coords_tags); + try std.testing.expectEqualSlices(AxisKind, &.{ .inserted_window, .update_window }, cfg.op_kind.constSlice()); + try std.testing.expectEqualSlices(AxisKind, &.{ .window_id, .window_id, .update_window }, cfg.up_kind.constSlice()); + } +} + +/// Concatenate all indices tensor in one tensor. +/// +/// Is allowed to reorder stuff to simplify the job of the backend, +/// and to expand the batching dims. +fn scatterPrepareIndices( + cfg: *ScatterConfig, + op: Shape, + update: Shape, + indices_per_axis: *std.BoundedArray(Tensor, Tensor.MAX_RANK), + indices_axes: *Shape.TagsArray, +) Tensor { + var old_scatter_to_op_axes = cfg.scatter_to_operand_axes; + const batching = _collectAxes(AxisKind, cfg.op_kind, .batching); + for (batching.constSlice()) |batch_ax| { + const id_shape = indices_per_axis.get(0).shape(); + // batching requires tagging, so we're sure to have a tag here. + const batch_tag = op.tag(batch_ax); + indices_axes.appendAssumeCapacity(batch_tag); + const batch_id = Tensor.iota(id_shape, batch_tag).convert(id_shape.dtype()); + indices_per_axis.appendAssumeCapacity(batch_id); + cfg.op_kind.buffer[@intCast(batch_ax)] = .inserted_window; + cfg.up_kind.buffer[update.axis(batch_tag)] = .window_id; + old_scatter_to_op_axes.appendAssumeCapacity(batch_ax); + } + cfg.indices_batch_axes = .{}; + + // Reorder the axes so that in indices_per_axis is ordered like in op if possible. 
+ // TODO: transpose updates if needed + var indices: std.BoundedArray(Tensor, Tensor.MAX_RANK) = .{}; + var scatter_to_op_axes: Shape.DimsArray = .{}; + + while (old_scatter_to_op_axes.len > 0) { + const scatter_ax = std.sort.argMin(i64, old_scatter_to_op_axes.constSlice(), {}, std.sort.asc(i64)).?; + const op_ax = old_scatter_to_op_axes.orderedRemove(scatter_ax); + const scatter_idx = indices_per_axis.orderedRemove(scatter_ax); + + scatter_to_op_axes.appendAssumeCapacity(op_ax); + indices.appendAssumeCapacity(scatter_idx); + } + cfg.scatter_to_operand_axes = scatter_to_op_axes; + + for (scatter_to_op_axes.constSlice(), 0..) |sc_ax, i| { + if (i != sc_ax) { + log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({_}, {any}, {_}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update }); + break; + } + } + return Tensor.stack(indices.constSlice(), .last, .coord); +} + +inline fn toI64(values: anytype) []i64 { + var res: [Tensor.MAX_RANK]i64 = undefined; + for (values, 0..) |val, i| res[i] = @intCast(val); + return res[0..values.len]; +} diff --git a/zml/shape.zig b/zml/shape.zig index 54fa3a9..3d2b9f8 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -410,6 +410,28 @@ pub const Shape = struct { _ = try writer.write(if (bare_fmt) "}" else "})"); } + /// Broadcasts a Tensor to the given shape, extending dimensions if needed. 
+ pub fn canBroadcastTo(self: Shape, other: Shape) bool { + // Already the right shape + if (std.mem.eql(i64, self.dims(), other.dims())) return true; + + // Non ambiguous broadcasting + // TODO: broad is error prone because of this: + // it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 } + if (self.rank() == 0 or self.rank() == other.rank()) { + for (0..self.rank()) |i| { + if (self.dim(i) != 1 and self.dim(i) != other.dim(i)) return false; + } + return true; + } + + for (self.dims(), self.tags()) |d, t| { + const other_ax = other.hasTag(t) orelse return false; + if (d != 1 and d != other.dim(other_ax)) return false; + } + return true; + } + pub fn reshape(self: Shape, new_shape_: anytype) Shape { var new_shape: Shape = .{ ._dtype = self.dtype() }; new_shape._dims, new_shape._tags = parseDimensions(new_shape_); diff --git a/zml/tensor.zig b/zml/tensor.zig index 4edb06c..1acb797 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1535,21 +1535,21 @@ pub const Tensor = struct { // Wrap slice1d to hide the anytype in the signature. 
const Local = struct { - pub fn slice1dAxis(input: Tensor, ax: i8, slice_: Tensor.Slice) Tensor { + pub fn _slice1dAxis(input: Tensor, ax: i8, slice_: Tensor.Slice) Tensor { return input.slice1d(ax, slice_); } }; { - const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } }); + const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } }); try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32)); } { - const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } }); + const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } }); try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32)); } { - const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } }); + const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } }); try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32)); } } @@ -1565,6 +1565,7 @@ pub const Tensor = struct { /// Concatenates the input Tensors along the given axis. pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor { + if (tensors.len == 1) return tensors[0]; stdx.debug.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len}); var buffer: [32]mlir.Value = undefined; std.debug.assert(tensors.len <= buffer.len); @@ -1883,12 +1884,18 @@ pub const Tensor = struct { /// you will lose the tags. /// To avoid use favorise `.broad(shape)` when working with tagged tensors. 
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor { - const res_shape = output_shape.withDtype(self.dtype()); stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ }); for (0.., axes_) |self_ax, other_ax| { const d = self.dim(self_ax); stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax }); } + + const res_shape = output_shape.withDtype(self.dtype()); + if (std.mem.eql(i64, self.dims(), output_shape.dims())) { + // No broadcast needed. We don't emit a new stablehlo value + // but we propagate output_shape tags. + return _result(res_shape, self.value()); + } const ctx = self.getContext(); const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?; const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); @@ -1922,21 +1929,25 @@ pub const Tensor = struct { /// Broadcasts a Tensor to the given shape, extending dimensions if needed. pub fn broad(self: Tensor, other: Shape) Tensor { + // TODO: broad is too restrictive because sometime you only want to specify one specific axis + // Note: if you code below, make sure to update Shape.canBroadcastTo. 
+ stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {} to {}", .{ self, other }); + + // Already the right shape + if (std.mem.eql(i64, self.dims(), other.dims())) return self; + // Non ambiguous broadcasting + // TODO: broad is error prone because of this: + // it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 } if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) { - return self.broadcast(other, Shape.range(self._shape.rank(), .bool).dims()); + const all_axes = [MAX_RANK]i64{ 0, 1, 2, 3, 4, 5, 6, 7 }; + return self.broadcast(other, all_axes[0..self.rank()]); } // check that each axis of self maps to an axis of other var axes_: std.BoundedArray(i64, MAX_RANK) = .{}; for (self._shape.tags()) |t| { - if (t != Shape.TagUnknown) { - if (other.hasTag(t)) |ax| { - axes_.appendAssumeCapacity(@intCast(other.axis(ax))); - } else { - std.debug.panic("Can't broadcast {} to {}", .{ self, other }); - } - } + axes_.appendAssumeCapacity(@intCast(other.axis(t))); } return self.broadcast(other, axes_.constSlice()); } @@ -2392,12 +2403,13 @@ pub const Tensor = struct { } pub const ScatterOpts = struct { - /// Promise scatter that all coordinates in `indices` are sorted, wrt to the final in memory offset. + /// Promise scatter that all coordinates in `indices` are sorted, wrt to the final offset in `self` /// Result is undefined if the promise is violated. indices_are_sorted: bool = false, /// Promise scatter that slices don't overlap. /// Result is undefined if the promise is violated. + /// This allows for better code generation, because it means that updates can be applied in parallel. indices_are_unique: bool = false, /// Function used to update previous value in `self` with values from `updates`. @@ -2405,132 +2417,102 @@ pub const Tensor = struct { /// then you should make sure the slices don't overlap, /// otherwise the result will depend on the runtime scheduling /// of the operator which is backend specific. 
- update_fn: *const fn (*const anyopaque, Tensor, Tensor) Tensor = increment, + update_fn: *const fn (Tensor, Tensor) Tensor = increment, - /// Extra data that may be needed for a custom update function. - /// `override` and `increment` don't need it, leaving it to undefined works. - update_fn_ctx: *const anyopaque = undefined, - - pub fn increment(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor { - return old_value.add(new_value); + pub fn increment(old_value: Tensor, new_value: Tensor) Tensor { + return old_value.add(new_value.convert(old_value.dtype())); } - pub fn override(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor { - _ = old_value; - return new_value; + pub fn override(old_value: Tensor, new_value: Tensor) Tensor { + return new_value.convert(old_value.dtype()); } }; - /// Update the given tensors, by copying `values` into self slices. + /// Update the given tensor, by copying `values` into slice by slice into `self`. /// The slices are chosen at runtime by interpreting indices as coordinates into `self`. - /// * `indices` represents a set of coordinates into `self`. - /// For the sake of simplifying the creation of `indices` tensor, - /// it's allowed to not mention a specific axis if the coordinate for this axis is always `0`. - /// Similarly to `gatherValues`, the coordinates are read from the `.coord` axis, or last axis if `.coord` is not found. - /// The coordinates represent the "top-left" corner of the slice to extract. - /// `indices.dim(.coord)` must match `coord_axes.len`. - /// Other axes identify one "slice" and they must be found inside `updates`. + /// This is a generalized version of `dynamicUpdateSlice` where more than one offset can be specified at a time. /// - /// * the output tensor starts with axes from `indices`. - /// * if the input tensor has tagged axes, matching `indices` axes, - /// they will be considered "batching" axes. 
+ /// ### Arguments /// - /// Sample input/output shapes: - /// * scatterSlices([A, B, C, D], .{b, c}, [N, 2], [N, B', C']) -> [A, B, C, D] - /// * scatterSlices(x(a,b,c,d), g(n,m), y[n,b,c]) [A,B,C,D] { - /// var z = x; - /// for (0..N) |n| { z[a,g[n,0]+b',g[n,1]+c',d] = y[n,a,b',c',d]; } + /// - Return a tensor with same shape than `self`, with updated content. + /// - `indices` is a set of Tensor (typically rank 1), representing coordinates into `self`. + /// all indices must have the same shape, but scalars are accepted. + /// - each `indices` entry contains offset along an axes into `self`. + /// Typically axes are identified by their tags, but in the absence of tags on `indices`, + /// The entry in indices will be assigned to axes of `self` from major to minor axis. + /// It is recommended to have indices referencing only major axes of `self` for better performance. + /// - `values` shape is obtained by concatenating the shape of `indices` with the shape of the slices to be extracted. + /// - `opts`: `zml.Tensor.ScatterOpts` des + /// + /// ### Sample input/output shapes with corresponding pseudo-code. + /// + /// Basic `scatterSlices` with the first two axes (.a, .b) being indexed, and full (.c, .d) slice copies: + /// + /// ``` + /// fn scatterSlices(x[A, B, C, D], .{.a=off_a[N], .b=off_b[N]}, y[N, C, D]) [A, B, C, D] { + /// var z = x; + /// for (0..N) |n| { + /// for (0..C) |c| for (0..D) |d| {{ + /// z[off_a[n],off_b[n],c,d] += y[n, c, d]; + /// }} + /// } + /// return z; /// } + /// ``` /// - /// **Warning**: if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound. + /// `scatterSlices` with the first three axes (.a, .b, .c) being indexed, and a partial copy of (.c, .d). + /// Note that .c axis is present both in the indices and updates, and `updates.dim(.c) < self.dim(.c)`. 
+ /// + /// ``` + /// fn scatterSlices(x[A, B, C, D], .{.a=off_a[N], .b=off_b[N], .c=off_c[N]}, y[N, C', D]) [A, B, C, D] { + /// var z = x; + /// for (0..N) |n| { + /// for (0..C') |c| for (0..D) |d| {{ + /// z[off_a[n],off_b[n],off_c[n]+c,d] += y[n, c, d]; + /// }} + /// } + /// return z; + /// } + /// ``` + /// + /// `scatterSlices` with the first axis .a being indexed, and where .b is used as a batching axis. + /// Note that here .b axis is present in `self`, `off_a`, and `updates`, + /// and is not mentioned in the axes of indices. + /// + /// ``` + /// fn scatterSlices(x[A, B, C, D], .{.a=off_a[B,N]}, y[N, B, C, D]) [A, B, C, D] { + /// var z = x; + /// for (0..B) |b| { + /// for (0..N) |n| { + /// for (0..C) |c| for (0..D) |d| {{ + /// z[off_a[b,n],b,c,d] += y[n, b, c, d]; + /// }} + /// } + /// } + /// return z; + /// } + /// ``` + /// + /// ### Warnings + /// + /// - if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound. /// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`, /// then the result will depend on the execution order that you don't control. - pub fn scatterSlices(self: Tensor, coord_axes: anytype, indices: Tensor, updates: Tensor, opts: ScatterOpts) Tensor { - const loc = @src(); - // scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates }); + /// - `scatterSlices` is a very expressive operator, and can lead to complicated code generation + /// that requires host<->device synchronization. + /// ZML tries to generate the easiest-to-optimize IR, and will warn you if it generates known problematic IR. 
+ pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor { + scoped_log.debug("scatterSlices({}, {any}, {})", .{ self, indices, updates }); - stdx.debug.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() }); + const UpdateType = @TypeOf(ScatterOpts.increment); - const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); - const AxisKind = enum { batching, update_window, inserted_window, window_id }; - var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; - var indices_batch_axes: Shape.DimsArray = .{}; - for (self._shape.tags()) |t| { - if (updates._shape.hasTag(t)) |_| { - if (indices._shape.hasTag(t)) |id_ax| { - // tag is in self, indices and updates -> it's a batching dim - self_kind.appendAssumeCapacity(.batching); - indices_batch_axes.appendAssumeCapacity(id_ax); - } else { - self_kind.appendAssumeCapacity(.update_window); - } - } else { - self_kind.appendAssumeCapacity(.inserted_window); + const Custom = struct { + pub fn inc(custom: *const UpdateType, old_value: Tensor, new_value: Tensor) Tensor { + return @call(.auto, custom, .{ old_value, new_value }); } - } - // scoped_log.warn(" self_kind -> {any}", .{self_kind.constSlice()}); - - const index_coord_axis = if (single_coord) - indices.rank() - else blk: { - const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices }); - - break :blk ax; }; - if (indices.count() == 1 and !single_coord) { - return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{})); - } - var up_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; - // Note: we assume the scatter_dims appear in the same order inside indices 
and inside self. - for (updates._shape.tags(), 0..) |t, up_ax| { - if (self._shape.hasTag(t)) |self_ax| { - if (self_kind.get(self_ax) == .batching) { - up_kind.appendAssumeCapacity(.batching); - } else { - stdx.debug.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates }); - up_kind.appendAssumeCapacity(.update_window); - } - } else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) { - up_kind.appendAssumeCapacity(.window_id); - } else { - std.debug.panic("scatterSlices expects 'updates' to be made of axes from 'self={}' and from 'indices={}', got unknown tag {s} in {}", .{ self, indices, t, updates }); - } - } - const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len; - if (single_coord) { - stdx.debug.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); - } else { - stdx.debug.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); - } - - const ctx = self.getContext(); - const mlir_ctx = ctx.mlirCtx(); - - const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined }; - const UpdateS = ops.BlockSign(ScatterOpts.increment); - const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); - - const op = dialect.stablehlo.scatter( - mlir_ctx, - &.{self.value()}, - &.{indices.value()}, - &.{updates.value()}, - update_block, - .{ - .update_window_dims = _collectAxes(AxisKind, up_kind, .update_window).constSlice(), - .inserted_window_dims = _collectAxes(AxisKind, self_kind, .inserted_window).constSlice(), - 
.input_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(), - .scatter_indices_batching_dims = indices_batch_axes.constSlice(), - .scatter_dims_to_operand_dims = toI64(coord_axes_.constSlice()), - .index_vector_dim = index_coord_axis, - .indices_are_sorted = opts.indices_are_sorted, - .unique_indices = opts.indices_are_unique, - }, - mlir_ctx.location(loc), - ); - return _result(self._shape, op.result(0)); + return ops.scatter(Tensor, *const UpdateType, Custom.inc, self, opts.update_fn, indices, updates, opts); } test scatterSlices { @@ -2538,14 +2520,25 @@ pub const Tensor = struct { const platform = zml.testing.env(); const Local = struct { - pub fn scatter(self: Tensor, coord_axes: Shape.AxesArray, indices: Tensor, updates: Tensor) Tensor { + pub fn _scatter(self: Tensor, indices: []const Tensor, updates: Tensor) Tensor { return self.scatterSlices( - coord_axes.constSlice(), indices, updates, .{ .update_fn = ScatterOpts.increment }, ); } + + pub fn _scatterCB(self: Tensor, coords: Tensor, updates: Tensor) Tensor { + return self.scatterSlices( + .{ .c = coords.choose1d(.coord, 0), .b = coords.choose1d(.coord, 1) }, + updates, + .{ .update_fn = ScatterOpts.increment }, + ); + } + + pub fn _idx(idx_shape: anytype) Tensor { + return Tensor.constant(idx_shape, .{ .i32 = 0 }); + } }; { @@ -2554,23 +2547,27 @@ pub const Tensor = struct { defer comp.deinit(); comp.activate(); defer comp.deactivate(); + const idx = Local._idx; inline for (.{ - .{ .{ .a = 10 }, .a, .{}, .{ .a = 3 } }, - .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8, .b = 2 } }, - // I'm not sure I like this variant, cause `b` is not mentionned in updates. - // So 'stablehlo.scatter' is implicitly broadcasting the updates along `b` axis. - // OTOH asking the user to do the broadcasting isn't trivial cause they will need to do shape wrangling and that's annoying. 
- .{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .a = 2 } }, - .{ .{ .a = 10, .b = 20 }, .{ .b, .a }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 3, .b = 2 } }, - .{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .a = 3, .n = 8, .b = 2 } }, + // This is equivalent to a dynamic update slice, update 3 values at given offset of axis .a: + .{ .{ .a = 10 }, .{ .a = idx(.{}) }, .{ .a = 3 } }, + // Use .a as a batching axis with .a=10 x .n=8 updates of 2 elements of .b + .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .n = 8, .b = 2 } }, + // Same but with update transposed + .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .b = 2, .n = 8 } }, + // similar, but use the normalized form where a is no longer an explicit batching axis. + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .a2 = 10, .n = 8 }), .b = idx(.{ .a2 = 10, .n = 8 }) }, .{ .a2 = 10, .n = 8, .b = 2 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .a = 10, .n = 8 }), .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .n = 8, .b = 2 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }) }, .{ .n = 8, .a = 2 } }, + .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .n = 8 }), .a = idx(.{ .n = 8 }) }, .{ .n = 8, .a = 3, .b = 2 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }), .b = idx(.{ .n = 8 }) }, .{ .a = 3, .n = 8, .b = 2 } }, }) |testcase| { - const x_shape, const axes_, const idx_shape, const updates_shapes = testcase; + const x_shape, const indices, const updates_shapes = testcase; const x = Tensor.constant(x_shape, .{ .f16 = 0 }); - const idx = Tensor.constant(idx_shape, .{ .i32 = 0 }); const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 }); - const y = scatterSlices(x, axes_, idx, updates, .{}); + const y = scatterSlices(x, indices, updates, .{}); // Shape doesn't change with scatterSlices try zml.testing.expectEqualShapes(x.shape(), y.shape()); try std.testing.expect(y.value().owner().verify()); @@ -2583,14 +2580,13 @@ pub const Tensor = 
struct { defer a.deinit(); a_host.deinit(std.testing.allocator); - const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{0}, .{2} }); + const scatter_indices = try zml.Buffer.fromArray(platform, [2]i32{ 0, 2 }); const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } }); const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } }; - const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ + const result = try zml.testing.compileAndCall(platform, Local._scatter, .{ a, - a.shape().axes(.{.a}), - scatter_indices.withTags(.{ .n, .coord }), + &.{scatter_indices.withTags(.{.n})}, updates.withTags(.{ .n, .b }), }); try std.testing.expect(a.shape().eql(result.shape())); @@ -2603,14 +2599,13 @@ pub const Tensor = struct { defer a.deinit(); a_host.deinit(std.testing.allocator); - const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{2}, .{7} }); + const scatter_indices = try zml.Buffer.fromArray(platform, [2]i32{ 2, 7 }); const updates = try zml.Buffer.fromArray(platform, [2]i32{ 20, 70 }); const expected = [9]i32{ 0, 1, 22, 3, 4, 5, 6, 77, 8 }; - const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ + const result = try zml.testing.compileAndCall(platform, Local._scatter, .{ a, - a.shape().axes(.{0}), - scatter_indices.withTags(.{ .n, .coord }), + &.{scatter_indices.withTags(.{.n})}, updates.withTags(.{.n}), }); try std.testing.expect(a.shape().eql(result.shape())); @@ -2642,7 +2637,7 @@ pub const Tensor = struct { ); defer values.deinit(); - const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values }); + const result = try zml.testing.compileAndCall(platform, Local._scatterCB, .{ operand, start_indices, values }); const expected = [2][3][4][2]u16{ .{ @@ -2740,12 +2735,12 @@ pub const Tensor = struct { const platform = zml.testing.env(); const allocator = 
std.testing.allocator; const ArgMaxTest = struct { - pub fn forward(x: Tensor) Tensor.ArgMaxRes { + pub fn _fwd(x: Tensor) Tensor.ArgMaxRes { return x.argMax(1); } }; - const argmax = try zml.compileFn(allocator, ArgMaxTest.forward, .{Shape.init(.{ 1, 5 }, .f32)}, platform); + const argmax = try zml.compileFn(allocator, ArgMaxTest._fwd, .{Shape.init(.{ 1, 5 }, .f32)}, platform); defer argmax.deinit(); // Test with tie { @@ -3290,10 +3285,10 @@ pub const Tensor = struct { const res = try zml.testing.compileAndCall( platform, struct { - pub fn forward(x_: Tensor, idx_: struct { a: Tensor }, y_: Tensor) Tensor { + pub fn _fwd(x_: Tensor, idx_: struct { a: Tensor }, y_: Tensor) Tensor { return x_.dynamicUpdateSlice(idx_, y_); } - }.forward, + }._fwd, .{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) }, ); try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32)); @@ -3308,10 +3303,10 @@ pub const Tensor = struct { const res = try zml.testing.compileAndCall( platform, struct { - pub fn forward(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor { + pub fn _fwd(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor { return x_.dynamicUpdateSlice(.{ .b = idx_ }, y_); } - }.forward, + }._fwd, .{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) }, ); try testing.expectEqualDeep( @@ -3328,10 +3323,10 @@ pub const Tensor = struct { const res = try zml.testing.compileAndCall( platform, struct { - pub fn forward(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor { + pub fn _fwd(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor { return x_.dynamicUpdateSlice(.{ zml.Tensor.scalar(0, .i32), idx_ }, y_); } - }.forward, + }._fwd, .{ x, idx, y }, ); try testing.expectEqualDeep( @@ -3349,10 +3344,10 @@ pub const Tensor = struct { const res = try zml.testing.compileAndCall( platform, struct { - pub fn forward(x_: Tensor, idx_: struct { a: Tensor, b: Tensor }, y_: Tensor) Tensor { + pub fn _fwd(x_: Tensor, idx_: struct { a: Tensor, b: Tensor }, y_: Tensor) Tensor { return 
x_.dynamicUpdateSlice(idx_, y_); } - }.forward, + }._fwd, .{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) }, ); try testing.expectEqualDeep( @@ -3368,11 +3363,11 @@ pub const Tensor = struct { const idx_a = try zml.Buffer.scalar(platform, 1, .i32); const idx_b = try zml.Buffer.scalar(platform, 3, .i32); const A = struct { - pub fn forward(x_: Tensor, idx_: [2]Tensor, y_: Tensor) Tensor { + pub fn _fwd(x_: Tensor, idx_: [2]Tensor, y_: Tensor) Tensor { return x_.dynamicUpdateSlice(&idx_, y_); } }; - const res = try zml.testing.compileAndCall(platform, A.forward, .{ x, .{ idx_a, idx_b }, y }); + const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y }); try testing.expectEqualDeep( [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } }, res.getValue([2][5]f32), @@ -3589,15 +3584,16 @@ pub const Tensor = struct { /// Given a set of N vectors of lengths A, B, C, D, /// returns N tensors of rank N, and shape (A, B, C, D). - /// For any coordinate (a, b, c, d), - /// we have: + /// For any coordinate (a, b, c, d), we have: + /// /// - res[0][a, b, c, d] == A[a] /// - res[1][a, b, c, d] == B[b] /// - res[2][a, b, c, d] == C[c] /// - res[3][a, b, c, d] == D[d] + /// /// This is implemented with broadcasting, so typically it won't copy. /// In Pytorch/Numpy this is know as `meshgrid` with "ij" mode. - /// See torch.meshgrid for the "xy" mode. + /// See `zml.torch.meshgrid` for the "xy" mode. 
pub fn cartesianProduct(comptime N: u3, vectors: [N]Tensor) [N]Tensor { var out: @TypeOf(vectors) = undefined; _cartesianProduct(&vectors, &out); @@ -3634,13 +3630,13 @@ pub const Tensor = struct { const y = try zml.Buffer.fromSlice(client, .{4}, &[_]i32{ 0, 1, 2, 3 }); const Local = struct { - pub fn cartesianProduct2(a: Tensor, b: Tensor) [2]Tensor { + pub fn _cartesianProduct2(a: Tensor, b: Tensor) [2]Tensor { return cartesianProduct(2, .{ a, b }); } }; { - const xs, const ys = try zml.testing.compileAndCall(client, Local.cartesianProduct2, .{ x, y }); + const xs, const ys = try zml.testing.compileAndCall(client, Local._cartesianProduct2, .{ x, y }); try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, xs.shape().dims()); try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, ys.shape().dims()); try std.testing.expectEqualDeep( @@ -3670,8 +3666,8 @@ pub const Tensor = struct { /// Given a set of N vectors of lengths A, B, C, D, /// returns 1 tensors of rank N+1, and shape (A, B, C, D, N). - /// For any coordinate (a, b, c, d), - /// we have: + /// For any coordinate (a, b, c, d), we have: + /// /// - res[a, b, c, d] == (A[a], B[b], C[c], D[d]) pub fn cartesianProductStacked(vectors: []const Tensor) Tensor { var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable; @@ -3687,12 +3683,12 @@ pub const Tensor = struct { const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 }); const Local = struct { - pub fn cartesianProduct2(a: Tensor, b: Tensor) Tensor { + pub fn _fwd(a: Tensor, b: Tensor) Tensor { return cartesianProductStacked(&.{ a, b }); } }; - const z = try zml.testing.compileAndCall(platform, Local.cartesianProduct2, .{ x, y }); + const z = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, y }); try std.testing.expectEqualDeep( [6][4][2]i32{ .{ .{ 0, 0 }, .{ 0, 1 }, .{ 0, 2 }, .{ 0, 3 } }, @@ -3795,7 +3791,7 @@ test "Tensor.maxPool1d" { const platform = zml.testing.env(); const MaxPool = struct { - pub fn 
forward(x: zml.Tensor) Tensor.ArgMaxRes { + pub fn _fwd(x: zml.Tensor) Tensor.ArgMaxRes { return x.maxPool1d(.{ .window_dimensions = 3, .window_strides = 2, @@ -3807,7 +3803,7 @@ test "Tensor.maxPool1d" { for (&data, 0..) |*v, i| v.* = @floatFromInt(i); const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5 }, &data); - const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x}); + const result = try zml.testing.compileAndCall(platform, MaxPool._fwd, .{x}); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .f32), result.values.shape()); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .i32), result.indices.shape()); const buffer = result.values.getValue([2][2][2]f32); @@ -3831,7 +3827,7 @@ test "Tensor.maxPool2d" { const platform = zml.testing.env(); const MaxPool = struct { - pub fn forward(x: Tensor) Tensor.ArgMaxRes { + pub fn _fwd(x: Tensor) Tensor.ArgMaxRes { return x.maxPool2d(.{ .window_dimensions = .{ 3, 2 }, .window_strides = .{ 2, 1 }, @@ -3843,7 +3839,7 @@ test "Tensor.maxPool2d" { for (&data, 0..) |*v, i| v.* = @floatFromInt(i); const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5, 5 }, &data); - const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x}); + const result = try zml.testing.compileAndCall(platform, MaxPool._fwd, .{x}); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .f32), result.values.shape()); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .i32), result.indices.shape()); var buffer: [2][2][2][4]f32 = undefined; @@ -3962,7 +3958,7 @@ test shapesOf { } } -fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) { +pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) { var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; for (bounded_array.constSlice(), 0..) 
|v, ax| { if (v == value) { @@ -4046,13 +4042,13 @@ test "unused tensor" { const platform = zml.testing.env(); const Local = struct { - pub fn forward(x: Tensor) Tensor { + pub fn _fwd(x: Tensor) Tensor { const y = x.addConstant(1); _ = y; return x; } }; - const mod = try zml.compileFn(std.testing.allocator, Local.forward, .{Shape.init(.{10}, .f32)}, platform); + const mod = try zml.compileFn(std.testing.allocator, Local._fwd, .{Shape.init(.{10}, .f32)}, platform); defer mod.deinit(); } diff --git a/zml/torch.zig b/zml/torch.zig index 0fe69c0..1dc52f8 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -105,8 +105,8 @@ pub fn unsqueeze( } test unsqueeze { - const UnsqueezeTest = struct { - pub fn forward(x: Tensor) Tensor { + const Local = struct { + pub fn _fwd(x: Tensor) Tensor { var y = x; y = unsqueeze(y, 0); y = unsqueeze(y, -1); @@ -117,7 +117,7 @@ test unsqueeze { const platform = zml.testing.env(); const x = try zml.Buffer.fromArray(platform, @as([8]f16, undefined)); - const res = try zml.testing.compileAndCall(platform, UnsqueezeTest.forward, .{x}); + const res = try zml.testing.compileAndCall(platform, Local._fwd, .{x}); try zml.testing.expectEqualShapes(zml.Shape.init(.{ 1, 8, 1, 1 }, .f16), res.shape()); } @@ -247,7 +247,7 @@ test meshgrid { const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 }); const Local = struct { - pub fn meshgrid2(a: Tensor, b: Tensor, indexing: MeshgridIndexing) [2]Tensor { + pub fn _meshgrid2(a: Tensor, b: Tensor, indexing: MeshgridIndexing) [2]Tensor { return meshgrid(2, .{ a, b }, indexing); } }; @@ -255,7 +255,7 @@ test meshgrid { // Only test .xy mode, sinc .ij is just calling cartesianProduct which // got its own tests. 
{ - const xs, const ys = try zml.testing.compileAndCall(platform, Local.meshgrid2, .{ x, y, .xy }); + const xs, const ys = try zml.testing.compileAndCall(platform, Local._meshgrid2, .{ x, y, .xy }); try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, xs.dims()); try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, ys.dims()); try std.testing.expectEqualDeep(