zml: revamp scatterSlices

The main issue with the current `scatter` implementation is that it uses
the batching dims of `stablehlo.scatter`.
While nice in theory, the optimizer doesn't handle them well, and they
are often unrolled into a while loop.
Here I convert the batching dims into extra iota indices.
This commit is contained in:
Tarry Singh 2024-01-08 17:55:20 +00:00
parent 83b5e1ec48
commit 68dbc290e9
6 changed files with 560 additions and 206 deletions

View File

@ -392,6 +392,26 @@ test visit {
} }
} }
/// Counts how many values of type `T` are reachable inside `value` via `visit`.
pub fn count(T: type, value: anytype) u32 {
    var total: u32 = 0;
    const Counter = struct {
        pub fn cb(acc: *u32, _: *const T) void {
            acc.* += 1;
        }
    };
    visit(Counter.cb, &total, value);
    return total;
}
/// Returns the first value of type `T` encountered while visiting `value`.
/// Panics (via `.?`) if `value` contains no `T`.
pub fn first(T: type, value: anytype) T {
    var found: ?T = null;
    const Picker = struct {
        pub fn cb(slot: *?T, item: *const T) void {
            // Only keep the first hit; later visits are ignored.
            if (slot.* == null) slot.* = item.*;
        }
    };
    visit(Picker.cb, &found, &value);
    return found.?;
}
/// Given a `fn([]const T, Args) T` and a slice of values, /// Given a `fn([]const T, Args) T` and a slice of values,
/// will combine all values in one value. /// will combine all values in one value.
/// Only T elements of values will be looked at. /// Only T elements of values will be looked at.

View File

@ -368,7 +368,7 @@ pub const CompilationContext = struct {
defer arena_state.deinit(); defer arena_state.deinit();
const arena = arena_state.allocator(); const arena = arena_state.allocator();
const tensor_count = countTensors(args); const tensor_count = meta.count(Tensor, args);
const mlir_ctx = self.mlirCtx(); const mlir_ctx = self.mlirCtx();
const loc = mlir_ctx.location(@src()); const loc = mlir_ctx.location(@src());
@ -505,13 +505,13 @@ pub const CompilationContext = struct {
const Local = struct { const Local = struct {
bias: Tensor, bias: Tensor,
pub fn forward(self: @This(), x: Tensor, y: Tensor) [2]Tensor { pub fn _fwd(self: @This(), x: Tensor, y: Tensor) [2]Tensor {
const x1 = zml.ops.call(self, .inner, .{x}); const x1 = zml.ops.call(self, ._inner, .{x});
const x2 = zml.ops.call(self, .inner, .{x1}); const x2 = zml.ops.call(self, ._inner, .{x1});
return .{ x1.reuseBuffer(y), x2 }; return .{ x1.reuseBuffer(y), x2 };
} }
pub fn inner(self: @This(), x: Tensor) Tensor { pub fn _inner(self: @This(), x: Tensor) Tensor {
const y = x.add(self.bias); const y = x.add(self.bias);
return y.reuseBuffer(x); return y.reuseBuffer(x);
} }
@ -524,7 +524,7 @@ pub const CompilationContext = struct {
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
defer comp.deinit(); defer comp.deinit();
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
const f = try comp.emitMlir(Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator); var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator);
defer mlir_bytecode.deinit(); defer mlir_bytecode.deinit();
@ -798,7 +798,7 @@ pub const CompilationContext = struct {
}; };
} }
fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value { pub fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value {
return self.getValueAndDonation(tensor)[0]; return self.getValueAndDonation(tensor)[0];
} }
@ -911,6 +911,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
// setFlag(&options, "xla_gpu_enable_custom_fusions", true); // setFlag(&options, "xla_gpu_enable_custom_fusions", true);
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true);
// setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true); // setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse { var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
@ -1068,7 +1069,7 @@ test FnCache {
w: Tensor, w: Tensor,
b: Tensor, b: Tensor,
pub fn forward(self: Layer_, x: Tensor) Tensor { pub fn _fwd(self: Layer_, x: Tensor) Tensor {
const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{}); const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
return wx.add(self.b.broad(wx.shape())).relu(); return wx.add(self.b.broad(wx.shape())).relu();
} }
@ -1078,18 +1079,18 @@ test FnCache {
const NN_ = @This(); const NN_ = @This();
layers: [3]Layer, layers: [3]Layer,
pub fn forward(self: NN_, x0: Tensor) Tensor { pub fn _fwd(self: NN_, x0: Tensor) Tensor {
var x = x0; var x = x0;
for (self.layers) |layer| { for (self.layers) |layer| {
x = ops.call(layer, .forward, .{x}); x = ops.call(layer, ._fwd, .{x});
} }
return x; return x;
} }
pub fn forwardRefImpl(self: NN_, x0: Tensor) Tensor { pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
var x = x0; var x = x0;
for (self.layers) |layer| { for (self.layers) |layer| {
x = layer.forward(x); x = layer._fwd(x);
} }
return x; return x;
} }
@ -1113,8 +1114,8 @@ test FnCache {
}, },
}, },
}; };
const res = try zml.testing.compileAndCall(platform, NN.forward, .{ nn, x }); const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
const expected = try zml.testing.compileAndCall(platform, NN.forwardRefImpl, .{ nn, x }); const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
try zml.testing.expectClose(expected, res, 1e-4); try zml.testing.expectClose(expected, res, 1e-4);
} }

View File

@ -16,6 +16,7 @@ const EnumLiteral = @TypeOf(.enum_literal);
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const _collectAxes = @import("tensor.zig")._collectAxes;
const dialect = struct { const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo; const stablehlo = @import("mlir/dialects").stablehlo;
@ -73,23 +74,23 @@ test "simple while" {
end: Tensor, end: Tensor,
const CountInts = @This(); const CountInts = @This();
pub fn hasNext(self: CountInts, i: Tensor, sum: Tensor) Tensor { pub fn _hasNext(self: CountInts, i: Tensor, sum: Tensor) Tensor {
_ = sum; _ = sum;
return i.cmp(.LT, self.end); return i.cmp(.LT, self.end);
} }
pub fn next(self: CountInts, i: Tensor, sum: Tensor) [2]Tensor { pub fn _next(self: CountInts, i: Tensor, sum: Tensor) [2]Tensor {
const r1 = i.add(self.step); const r1 = i.add(self.step);
const r2 = sum.add(i); const r2 = sum.add(i);
return .{ r1, r2 }; return .{ r1, r2 };
} }
pub fn forward(self: CountInts, init_i: Tensor, init_sum: Tensor) [2]Tensor { pub fn _fwd(self: CountInts, init_i: Tensor, init_sum: Tensor) [2]Tensor {
const x = init_i.scale(2); const x = init_i.scale(2);
return while_(CountInts.hasNext, CountInts.next, self, .{ x, init_sum }); return while_(CountInts._hasNext, CountInts._next, self, .{ x, init_sum });
} }
pub fn zigForward(step: i64, end: i64, init_i: i64, init_sum: i64) [2]i64 { pub fn _zigForward(step: i64, end: i64, init_i: i64, init_sum: i64) [2]i64 {
const x = init_i * 2; const x = init_i * 2;
var i = x; var i = x;
var sum = init_sum; var sum = init_sum;
@ -110,14 +111,14 @@ test "simple while" {
.step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}), .step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}),
.end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}), .end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}),
}; };
const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts.forward, .{ counter, init_i, init_sum }); const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts._fwd, .{ counter, init_i, init_sum });
const last_i = try res0.getValue(i64); const last_i = try res0.getValue(i64);
const sum = try res1.getValue(i64); const sum = try res1.getValue(i64);
try std.testing.expectEqual(10, last_i); try std.testing.expectEqual(10, last_i);
try std.testing.expectEqual(45, sum); try std.testing.expectEqual(45, sum);
try std.testing.expectEqual(.{ 10, 45 }, CountInts.zigForward(1, 10, 0, 0)); try std.testing.expectEqual(.{ 10, 45 }, CountInts._zigForward(1, 10, 0, 0));
} }
pub fn reduce( pub fn reduce(
@ -361,7 +362,7 @@ test for_ {
return f.mul(f); return f.mul(f);
} }
pub fn forward(num_steps: u63) Tensor { pub fn _fwd(num_steps: u63) Tensor {
return for_(Squares.sq, .{}, .{num_steps}); return for_(Squares.sq, .{}, .{num_steps});
} }
}; };
@ -370,19 +371,19 @@ test for_ {
// Just one baby step // Just one baby step
{ {
const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{1}); const squares = try zml.testing.compileAndCall(platform, Squares._fwd, .{1});
try zml.testing.expectEqualShapes(Shape.init(.{1}, .f32), squares.shape()); try zml.testing.expectEqualShapes(Shape.init(.{1}, .f32), squares.shape());
try std.testing.expectEqual(0, squares.getValue(f32)); try std.testing.expectEqual(0, squares.getValue(f32));
} }
// Wow 4 in rows ! // Wow 4 in rows !
{ {
const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{4}); const squares = try zml.testing.compileAndCall(platform, Squares._fwd, .{4});
try zml.testing.expectEqualShapes(Shape.init(.{4}, .f32), squares.shape()); try zml.testing.expectEqualShapes(Shape.init(.{4}, .f32), squares.shape());
try std.testing.expectEqual([_]f32{ 0, 1, 4, 9 }, try squares.getValue([4]f32)); try std.testing.expectEqual([_]f32{ 0, 1, 4, 9 }, try squares.getValue([4]f32));
} }
// AGI is coming, computing 10 squares as it's nothing. // AGI is coming, computing 10 squares as it's nothing.
{ {
const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{10}); const squares = try zml.testing.compileAndCall(platform, Squares._fwd, .{10});
try zml.testing.expectEqualShapes(Shape.init(.{10}, .f32), squares.shape()); try zml.testing.expectEqualShapes(Shape.init(.{10}, .f32), squares.shape());
try std.testing.expectEqual( try std.testing.expectEqual(
[_]f32{ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81 }, [_]f32{ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81 },
@ -398,7 +399,7 @@ test "nested for" {
x: Tensor, x: Tensor,
x_row: Tensor, x_row: Tensor,
pub fn forward(x: Tensor) Tensor { pub fn _fwd(x: Tensor) Tensor {
return for_(OuterProd.scanRow, x, .{x.dim(0)}); return for_(OuterProd.scanRow, x, .{x.dim(0)});
} }
@ -418,7 +419,7 @@ test "nested for" {
// 5 to prevent inlining // 5 to prevent inlining
const x = try zml.Buffer.fromArray(platform, [5]f32{ 0, 1.0, -1.0, 2.0, -2.0 }); const x = try zml.Buffer.fromArray(platform, [5]f32{ 0, 1.0, -1.0, 2.0, -2.0 });
const outer_prod = try zml.testing.compileAndCall(platform, OuterProd.forward, .{x}); const outer_prod = try zml.testing.compileAndCall(platform, OuterProd._fwd, .{x});
const expected: [5][5]f32 = .{ const expected: [5][5]f32 = .{
.{ 0, 0, 0, 0, 0 }, .{ 0, 0, 0, 0, 0 },
.{ 0, 1.0, -1.0, 2.0, -2.0 }, .{ 0, 1.0, -1.0, 2.0, -2.0 },
@ -468,7 +469,7 @@ test "if" {
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const IfMod = struct { const IfMod = struct {
pub fn forward(pred: Tensor, a: Tensor, b: Tensor) Tensor { pub fn _fwd(pred: Tensor, a: Tensor, b: Tensor) Tensor {
const result = if_(pred.convert(.bool), condTrue, condFalse, .{ a, b }); const result = if_(pred.convert(.bool), condTrue, condFalse, .{ a, b });
return result; return result;
} }
@ -486,7 +487,7 @@ test "if" {
const pred = Shape.init(.{}, .i32); const pred = Shape.init(.{}, .i32);
const a = Shape.init(.{ 4, 4 }, .f32); const a = Shape.init(.{ 4, 4 }, .f32);
const b = Shape.init(.{ 4, 4 }, .f32); const b = Shape.init(.{ 4, 4 }, .f32);
const mod = try zml.compileFn(allocator, IfMod.forward, .{ pred, a, b }, platform); const mod = try zml.compileFn(allocator, IfMod._fwd, .{ pred, a, b }, platform);
defer mod.deinit(); defer mod.deinit();
} }
} }
@ -753,3 +754,317 @@ pub fn addHostCallback(
); );
return Tensor._result(input.shape(), op.result(0)); return Tensor._result(input.shape(), op.result(0));
} }
/// Generalized version of scatter to many inputs.
/// See `zml.Tensor.scatterSlices` for documentation on scatter.
///
/// This allows using the same indices to update several tensors at once,
/// where the update function is allowed to look at elements from the different
/// tensors to compute the final values.
///
/// This sounds nice but in practice XLA doesn't support this well on GPU,
/// and will generate slow code. In practice stick with `zml.Tensor.scatterSlices`.
pub fn scatter(
    comptime T: type,
    comptime BlkCtx: type,
    comptime update_fn: fn (BlkCtx, T, T) T,
    inputs: T,
    blkctx: BlkCtx,
    index_tensors: anytype,
    updates: T,
    opts: Tensor.ScatterOpts,
) T {
    const loc = @src();
    const ctx = CompilationContext.current();
    // `inputs` and `updates` must contain the same number of tensors:
    // the scatter op produces one result per (input, update) pair.
    const n_inputs = meta.count(Tensor, &inputs);
    const n_updates = meta.count(Tensor, &updates);
    stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {} and {}", .{ n_inputs, n_updates });
    // Note: I was a bit lazy here, and I only look at tags on the first tensor.
    // we probably should check all of them.
    const self = meta.first(Tensor, inputs);
    const update = meta.first(Tensor, updates);
    var indices_per_axis, var indices_axes = Shape.parseStruct(Tensor, index_tensors);
    // No index tensors at all -> nothing to scatter, return the inputs unchanged.
    if (indices_per_axis.len == 0) return inputs;
    // validate coord axes: all coord_axes should exist inside self
    for (indices_axes.constSlice()) |t| {
        stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={any}", .{ self, indices_axes });
    }
    // Handle scalar indices by broadcasting them to the indices with the highest rank.
    const indices_shape = blk: {
        var higher_rank = indices_per_axis.get(0).shape();
        for (indices_per_axis.constSlice()[1..]) |indices| {
            if (indices.rank() > higher_rank.rank()) {
                higher_rank = indices.shape();
            }
        }
        break :blk higher_rank;
    };
    for (indices_per_axis.slice()) |*idx| {
        stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {_}", .{indices_per_axis.slice()});
        stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {_}", .{indices_per_axis.slice()});
        idx.* = idx.broad(indices_shape);
    }
    // rewrite simple scatters to dynamicUpdateSlice.
    if (T == Tensor and indices_shape.rank() == 0) {
        return self.dynamicUpdateSlice(index_tensors, updates);
    }
    // TODO: ideally we should catch all possible scatter errors and provide nice error messages.
    // Classify axes, then rewrite batching dims into explicit iota indices
    // (see scatterPrepareIndices); `indices` is the stacked index tensor.
    var config = scatterConfig(self.shape(), update.shape(), indices_per_axis, indices_axes);
    const indices = scatterPrepareIndices(&config, self.shape(), update.shape(), &indices_per_axis, &indices_axes);
    // const n_indices_axes = update.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
    // stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({_}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={_}", .{ self, index_tensors, indices_axes.constSlice(), update });
    const mlir_ctx = ctx.mlirCtx();
    // Build a scalar prototype of `inputs`: the update block is element-wise,
    // so its block arguments are rank-0 tensors with the inputs' dtypes.
    var _scalar: T = inputs;
    meta.visit(struct {
        pub fn cb(_: void, x: *Tensor) void {
            x.* = .{ ._shape = Shape.init(.{}, x.dtype()), ._id = undefined };
        }
    }.cb, {}, &_scalar);
    const UpdateS = BlockSign(update_fn);
    const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
    // Collect the mlir values backing every input and update tensor.
    var input_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM");
    defer input_values.deinit();
    meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable;
    var updates_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM");
    defer updates_values.deinit();
    meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable;
    // Emit the stablehlo.scatter op; dimension numbers come straight from `config`.
    const op = dialect.stablehlo.scatter(
        mlir_ctx,
        input_values.items,
        &.{indices.value()},
        updates_values.items,
        update_block,
        .{
            .update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(),
            .inserted_window_dims = _collectAxes(AxisKind, config.op_kind, .inserted_window).constSlice(),
            .input_batching_dims = _collectAxes(AxisKind, config.op_kind, .batching).constSlice(),
            .scatter_indices_batching_dims = config.indices_batch_axes.constSlice(),
            .scatter_dims_to_operand_dims = config.scatter_to_operand_axes.constSlice(),
            // The stacked coordinate axis is the last axis of `indices`.
            .index_vector_dim = indices.rank() - 1,
            .indices_are_sorted = opts.indices_are_sorted,
            .unique_indices = opts.indices_are_unique,
        },
        mlir_ctx.location(loc),
    );
    // Rewrap the op results into the same structure as `inputs`,
    // assigning the i-th result to the i-th visited tensor.
    var res: T = inputs;
    const LocalContext = struct {
        op: mlir.Operation,
        index: usize = 0,
    };
    var local_context = LocalContext{ .op = op };
    meta.visit((struct {
        fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
            const val = inner_ctx.op.result(inner_ctx.index);
            tensor.* = Tensor._result(tensor.shape(), val);
            inner_ctx.index += 1;
        }
    }).cb, &local_context, &res);
    assert(local_context.index == op.numResults());
    return res;
}
/// Pre-computed axis classification for emitting a `stablehlo.scatter` op.
/// Produced by `scatterConfig`, then rewritten by `scatterPrepareIndices`.
const ScatterConfig = struct {
    /// Role of each axis of the scattered-into operand.
    op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
    /// Role of each axis of the updates tensor.
    up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
    /// Axes of the indices tensors that act as scatter batching dims.
    indices_batch_axes: Shape.DimsArray = .{},
    /// For each index tensor, the operand axis it addresses.
    scatter_to_operand_axes: Shape.DimsArray = .{},
    /// Permutation of the updates axes; computed but not applied in the
    /// visible call sites (see "TODO: transpose updates if needed" in
    /// `scatterPrepareIndices`).
    updates_transpose: Shape.AxesArray = .{},
};
const AxisKind = enum { batching, update_window, inserted_window, window_id };
/// Classifies every axis of `op` (the operand) and `update` for
/// `stablehlo.scatter`, and computes the mapping from index tensors to
/// operand axes. Supports both the tagged API (named index axes) and the
/// untagged positional API.
fn scatterConfig(
    op: Shape,
    update: Shape,
    indices_per_axis: std.BoundedArray(Tensor, Tensor.MAX_RANK),
    indices_axes: Shape.TagsArray,
) ScatterConfig {
    var op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
    var up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
    var indices_batch_axes: Shape.DimsArray = .{};
    var scatter_to_operand_axes: Shape.DimsArray = .{};
    var updates_transpose: Shape.AxesArray = .{};
    // Tagged API: index tensors were passed as a struct keyed by axis name.
    const tagged_api = indices_axes.len > 0;
    // Shape shared by all index tensors (they are broadcast beforehand).
    const indices = indices_per_axis.get(0).shape();
    if (tagged_api) {
        // Each named index axis addresses the matching axis of `op`.
        for (indices_axes.constSlice()) |t| {
            scatter_to_operand_axes.appendAssumeCapacity(op.axis(t));
        }
        for (indices.tags()) |t| {
            stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got updates={} and indices={s}", .{ update, indices_axes.constSlice() });
            updates_transpose.appendAssumeCapacity(update.axis(t));
        }
        // Classify each operand axis by where its tag also appears.
        for (op.tags()) |t| {
            if (update.hasTag(t)) |up_ax| {
                updates_transpose.appendAssumeCapacity(up_ax);
                if (indices.hasTag(t)) |id_ax| {
                    if (std.mem.indexOfScalar(Shape.Tag, indices_axes.constSlice(), t) != null) {
                        // tag is in indices AND in coords -> it's a batching dim that has been rewritten to a regular insertion dim
                        op_kind.appendAssumeCapacity(.inserted_window);
                    } else {
                        // tag is in op, indices and updates -> it's a batching dim
                        op_kind.appendAssumeCapacity(.batching);
                        indices_batch_axes.appendAssumeCapacity(@intCast(id_ax));
                    }
                } else {
                    // tag only shared with updates -> part of the copied window.
                    op_kind.appendAssumeCapacity(.update_window);
                }
            } else {
                // tag absent from updates -> this axis is fully indexed.
                op_kind.appendAssumeCapacity(.inserted_window);
            }
        }
        // Classify each updates axis symmetrically.
        for (update.tags(), 0..) |t, up_ax| {
            // Handle batch axes right away.
            if (op.hasTag(t)) |self_ax| {
                if (op_kind.get(self_ax) == .batching) {
                    up_kind.appendAssumeCapacity(.batching);
                    continue;
                }
            }
            if (indices.hasTag(t) != null) {
                // Axis enumerates the individual updates.
                up_kind.appendAssumeCapacity(.window_id);
            } else if (op.hasTag(t)) |self_ax| {
                stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={_}, updates={_}.", .{ t, op, update });
                up_kind.appendAssumeCapacity(.update_window);
            } else {
                // TODO: consider accepting untagged update here.
                std.debug.panic("scatter expects 'updates' to be made of axes from op={_} and from indices={s}, got unknown tag {s} in {_}", .{ op, indices_axes.constSlice(), t, update });
            }
        }
    } else {
        // Untagged API: by convention the first `indices_per_axis.len` axes of
        // `op` are indexed, the remaining axes form the copied window.
        for (0..indices_per_axis.len) |i| {
            op_kind.appendAssumeCapacity(.inserted_window);
            scatter_to_operand_axes.appendAssumeCapacity(@intCast(i));
            up_kind.appendAssumeCapacity(.window_id);
        }
        for (indices_per_axis.len..op.rank()) |_| {
            op_kind.appendAssumeCapacity(.update_window);
        }
        for (indices_per_axis.len..update.rank()) |_| {
            up_kind.appendAssumeCapacity(.update_window);
        }
        // No reordering needed: identity permutation.
        for (0..update.rank()) |i| {
            updates_transpose.appendAssumeCapacity(@intCast(i));
        }
    }
    return .{
        .op_kind = op_kind,
        .up_kind = up_kind,
        .indices_batch_axes = indices_batch_axes,
        .scatter_to_operand_axes = scatter_to_operand_axes,
        .updates_transpose = updates_transpose,
    };
}
test scatterConfig {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    // An active CompilationContext is required to create the index tensors.
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();
    const Local = struct {
        // Helper producing a dummy i32 index tensor of the given shape.
        pub fn _idx(idx_shape: anytype) Tensor {
            return Tensor.constant(idx_shape, .{ .i32 = 0 });
        }
    };
    const idx = Local._idx;
    const op = Shape.init(.{ .a = 10, .b = 20 }, .f32);
    // Use .a as a batching axis with .a=10 x .n=8 updates of 2 elements of .b
    {
        const indices, const coords_tags = Shape.parseStruct(Tensor, .{ .b = idx(.{ .a = 10, .n = 8 }) });
        const update = Shape.init(.{ .a = 10, .n = 8, .b = 2 }, .f32);
        const cfg = scatterConfig(op, update, indices, coords_tags);
        try std.testing.expectEqualSlices(AxisKind, &.{ .batching, .update_window }, cfg.op_kind.constSlice());
        try std.testing.expectEqualSlices(AxisKind, &.{ .batching, .window_id, .update_window }, cfg.up_kind.constSlice());
    }
    // similar, but use the normalized form where .a is no longer an explicit batching axis.
    {
        const indices, const coords_tags = Shape.parseStruct(Tensor, .{ .a = idx(.{ .a = 10, .n = 8 }), .b = idx(.{ .a = 10, .n = 8 }) });
        const update = Shape.init(.{ .a = 10, .n = 8, .b = 2 }, .f32);
        const cfg = scatterConfig(op, update, indices, coords_tags);
        try std.testing.expectEqualSlices(AxisKind, &.{ .inserted_window, .update_window }, cfg.op_kind.constSlice());
        try std.testing.expectEqualSlices(AxisKind, &.{ .window_id, .window_id, .update_window }, cfg.up_kind.constSlice());
    }
}
/// Concatenates all index tensors into one tensor (stacked along a new
/// minor `.coord` axis, which becomes the op's `index_vector_dim`).
///
/// Is allowed to reorder indices to simplify the job of the backend,
/// and rewrites the batching dims into explicit iota indices.
fn scatterPrepareIndices(
    cfg: *ScatterConfig,
    op: Shape,
    update: Shape,
    indices_per_axis: *std.BoundedArray(Tensor, Tensor.MAX_RANK),
    indices_axes: *Shape.TagsArray,
) Tensor {
    var old_scatter_to_op_axes = cfg.scatter_to_operand_axes;
    // Rewrite each batching dim into an explicit iota index tensor. This is
    // the core of the revamp: per the commit message, XLA handles plain
    // indices better than scatter batching dims.
    const batching = _collectAxes(AxisKind, cfg.op_kind, .batching);
    for (batching.constSlice()) |batch_ax| {
        const id_shape = indices_per_axis.get(0).shape();
        // batching requires tagging, so we're sure to have a tag here.
        const batch_tag = op.tag(batch_ax);
        indices_axes.appendAssumeCapacity(batch_tag);
        const batch_id = Tensor.iota(id_shape, batch_tag).convert(id_shape.dtype());
        indices_per_axis.appendAssumeCapacity(batch_id);
        // Reclassify: the former batching axis is now a regular indexed axis.
        cfg.op_kind.buffer[@intCast(batch_ax)] = .inserted_window;
        cfg.up_kind.buffer[update.axis(batch_tag)] = .window_id;
        old_scatter_to_op_axes.appendAssumeCapacity(batch_ax);
    }
    // All batching dims were rewritten above: none left for the mlir op.
    cfg.indices_batch_axes = .{};
    // Reorder the axes so that in indices_per_axis is ordered like in op if possible.
    // TODO: transpose updates if needed
    var indices: std.BoundedArray(Tensor, Tensor.MAX_RANK) = .{};
    var scatter_to_op_axes: Shape.DimsArray = .{};
    while (old_scatter_to_op_axes.len > 0) {
        // Selection sort: repeatedly extract the index tensor targeting the
        // smallest remaining operand axis.
        const scatter_ax = std.sort.argMin(i64, old_scatter_to_op_axes.constSlice(), {}, std.sort.asc(i64)).?;
        const op_ax = old_scatter_to_op_axes.orderedRemove(scatter_ax);
        const scatter_idx = indices_per_axis.orderedRemove(scatter_ax);
        scatter_to_op_axes.appendAssumeCapacity(op_ax);
        indices.appendAssumeCapacity(scatter_idx);
    }
    cfg.scatter_to_operand_axes = scatter_to_op_axes;
    // Warn when the indexed axes aren't the leading (major) axes of the
    // operand: such scatters generate a slow while loop.
    for (scatter_to_op_axes.constSlice(), 0..) |sc_ax, i| {
        if (i != sc_ax) {
            log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({_}, {any}, {_}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update });
            break;
        }
    }
    return Tensor.stack(indices.constSlice(), .last, .coord);
}
/// Widens a slice/array of integers into a `[]i64` slice.
/// NOTE(review): the returned slice points into the local buffer `res`.
/// This relies on the `inline` call making `res` live in the caller's frame;
/// the slice must not escape the caller's scope — consider taking an output
/// buffer instead. TODO: confirm all callers consume the result immediately.
inline fn toI64(values: anytype) []i64 {
    var res: [Tensor.MAX_RANK]i64 = undefined;
    for (values, 0..) |val, i| res[i] = @intCast(val);
    return res[0..values.len];
}

View File

@ -410,6 +410,28 @@ pub const Shape = struct {
_ = try writer.write(if (bare_fmt) "}" else "})"); _ = try writer.write(if (bare_fmt) "}" else "})");
} }
/// Returns true when a tensor of shape `self` could be broadcasted to `other`.
/// Mirrors the broadcasting rules of `Tensor.broad` without emitting anything.
pub fn canBroadcastTo(self: Shape, other: Shape) bool {
    // Identical dims: trivially broadcastable.
    if (std.mem.eql(i64, self.dims(), other.dims())) return true;
    // Non ambiguous broadcasting: scalar, or same rank compared axis by axis.
    // TODO: broad is error prone because of this:
    // it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 }
    if (self.rank() == 0 or self.rank() == other.rank()) {
        for (0..self.rank()) |ax| {
            const d = self.dim(ax);
            if (d != 1 and d != other.dim(ax)) return false;
        }
        return true;
    }
    // Different ranks: every axis of `self` must map, by tag, to an axis of
    // `other`, and be either 1-sized or match the target size.
    for (self.dims(), self.tags()) |d, t| {
        const target_ax = other.hasTag(t) orelse return false;
        if (d != 1 and d != other.dim(target_ax)) return false;
    }
    return true;
}
pub fn reshape(self: Shape, new_shape_: anytype) Shape { pub fn reshape(self: Shape, new_shape_: anytype) Shape {
var new_shape: Shape = .{ ._dtype = self.dtype() }; var new_shape: Shape = .{ ._dtype = self.dtype() };
new_shape._dims, new_shape._tags = parseDimensions(new_shape_); new_shape._dims, new_shape._tags = parseDimensions(new_shape_);

View File

@ -1535,21 +1535,21 @@ pub const Tensor = struct {
// Wrap slice1d to hide the anytype in the signature. // Wrap slice1d to hide the anytype in the signature.
const Local = struct { const Local = struct {
pub fn slice1dAxis(input: Tensor, ax: i8, slice_: Tensor.Slice) Tensor { pub fn _slice1dAxis(input: Tensor, ax: i8, slice_: Tensor.Slice) Tensor {
return input.slice1d(ax, slice_); return input.slice1d(ax, slice_);
} }
}; };
{ {
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } }); const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32)); try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
} }
{ {
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } }); const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32)); try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
} }
{ {
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } }); const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32)); try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
} }
} }
@ -1565,6 +1565,7 @@ pub const Tensor = struct {
/// Concatenates the input Tensors along the given axis. /// Concatenates the input Tensors along the given axis.
pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor { pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
if (tensors.len == 1) return tensors[0];
stdx.debug.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len}); stdx.debug.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
var buffer: [32]mlir.Value = undefined; var buffer: [32]mlir.Value = undefined;
std.debug.assert(tensors.len <= buffer.len); std.debug.assert(tensors.len <= buffer.len);
@ -1883,12 +1884,18 @@ pub const Tensor = struct {
/// you will lose the tags. /// you will lose the tags.
/// To avoid use favorise `.broad(shape)` when working with tagged tensors. /// To avoid use favorise `.broad(shape)` when working with tagged tensors.
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor { pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
const res_shape = output_shape.withDtype(self.dtype());
stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ }); stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ });
for (0.., axes_) |self_ax, other_ax| { for (0.., axes_) |self_ax, other_ax| {
const d = self.dim(self_ax); const d = self.dim(self_ax);
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax }); stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax });
} }
const res_shape = output_shape.withDtype(self.dtype());
if (std.mem.eql(i64, self.dims(), output_shape.dims())) {
// No broadcast needed. We don't emit a new stablehlo value
// but we propagate output_shape tags.
return _result(res_shape, self.value());
}
const ctx = self.getContext(); const ctx = self.getContext();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?; const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?;
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
@ -1922,21 +1929,25 @@ pub const Tensor = struct {
/// Broadcasts a Tensor to the given shape, extending dimensions if needed. /// Broadcasts a Tensor to the given shape, extending dimensions if needed.
pub fn broad(self: Tensor, other: Shape) Tensor { pub fn broad(self: Tensor, other: Shape) Tensor {
// TODO: broad is too restrictive because sometime you only want to specify one specific axis
// Note: if you code below, make sure to update Shape.canBroadcastTo.
stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {} to {}", .{ self, other });
// Already the right shape
if (std.mem.eql(i64, self.dims(), other.dims())) return self;
// Non ambiguous broadcasting // Non ambiguous broadcasting
// TODO: broad is error prone because of this:
// it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 }
if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) { if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) {
return self.broadcast(other, Shape.range(self._shape.rank(), .bool).dims()); const all_axes = [MAX_RANK]i64{ 0, 1, 2, 3, 4, 5, 6, 7 };
return self.broadcast(other, all_axes[0..self.rank()]);
} }
// check that each axis of self maps to an axis of other // check that each axis of self maps to an axis of other
var axes_: std.BoundedArray(i64, MAX_RANK) = .{}; var axes_: std.BoundedArray(i64, MAX_RANK) = .{};
for (self._shape.tags()) |t| { for (self._shape.tags()) |t| {
if (t != Shape.TagUnknown) { axes_.appendAssumeCapacity(@intCast(other.axis(t)));
if (other.hasTag(t)) |ax| {
axes_.appendAssumeCapacity(@intCast(other.axis(ax)));
} else {
std.debug.panic("Can't broadcast {} to {}", .{ self, other });
}
}
} }
return self.broadcast(other, axes_.constSlice()); return self.broadcast(other, axes_.constSlice());
} }
@ -2392,12 +2403,13 @@ pub const Tensor = struct {
} }
pub const ScatterOpts = struct { pub const ScatterOpts = struct {
/// Promise scatter that all coordinates in `indices` are sorted, wrt to the final in memory offset. /// Promise scatter that all coordinates in `indices` are sorted, wrt to the final offset in `self`
/// Result is undefined if the promise is violated. /// Result is undefined if the promise is violated.
indices_are_sorted: bool = false, indices_are_sorted: bool = false,
/// Promise scatter that slices don't overlap. /// Promise scatter that slices don't overlap.
/// Result is undefined if the promise is violated. /// Result is undefined if the promise is violated.
/// This allows for better code generation, because it means that updates can be applied in parallel.
indices_are_unique: bool = false, indices_are_unique: bool = false,
/// Function used to update previous value in `self` with values from `updates`. /// Function used to update previous value in `self` with values from `updates`.
@ -2405,132 +2417,102 @@ pub const Tensor = struct {
/// then you should make sure the slices don't overlap, /// then you should make sure the slices don't overlap,
/// otherwise the result will depend on the runtime scheduling /// otherwise the result will depend on the runtime scheduling
/// of the operator which is backend specific. /// of the operator which is backend specific.
update_fn: *const fn (*const anyopaque, Tensor, Tensor) Tensor = increment, update_fn: *const fn (Tensor, Tensor) Tensor = increment,
/// Extra data that may be needed for a custom update function. pub fn increment(old_value: Tensor, new_value: Tensor) Tensor {
/// `override` and `increment` don't need it, leaving it to undefined works. return old_value.add(new_value.convert(old_value.dtype()));
update_fn_ctx: *const anyopaque = undefined,
pub fn increment(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor {
return old_value.add(new_value);
} }
pub fn override(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor { pub fn override(old_value: Tensor, new_value: Tensor) Tensor {
_ = old_value; return new_value.convert(old_value.dtype());
return new_value;
} }
}; };
/// Update the given tensors, by copying `values` into self slices. /// Update the given tensor, by copying `values` into slice by slice into `self`.
/// The slices are chosen at runtime by interpreting indices as coordinates into `self`. /// The slices are chosen at runtime by interpreting indices as coordinates into `self`.
/// * `indices` represents a set of coordinates into `self`. /// This is a generalized version of `dynamicUpdateSlice` where more than one offset can be specified at a time.
/// For the sake of simplifying the creation of `indices` tensor,
/// it's allowed to not mention a specific axis if the coordinate for this axis is always `0`.
/// Similarly to `gatherValues`, the coordinates are read from the `.coord` axis, or last axis if `.coord` is not found.
/// The coordinates represent the "top-left" corner of the slice to extract.
/// `indices.dim(.coord)` must match `coord_axes.len`.
/// Other axes identify one "slice" and they must be found inside `updates`.
/// ///
/// * the output tensor starts with axes from `indices`. /// ### Arguments
/// * if the input tensor has tagged axes, matching `indices` axes,
/// they will be considered "batching" axes.
/// ///
/// Sample input/output shapes: /// - Return a tensor with same shape than `self`, with updated content.
/// * scatterSlices([A, B, C, D], .{b, c}, [N, 2], [N, B', C']) -> [A, B, C, D] /// - `indices` is a set of Tensor (typically rank 1), representing coordinates into `self`.
/// * scatterSlices(x(a,b,c,d), g(n,m), y[n,b,c]) [A,B,C,D] { /// all indices must have the same shape, but scalars are accepted.
/// var z = x; /// - each `indices` entry contains offset along an axes into `self`.
/// for (0..N) |n| { z[a,g[n,0]+b',g[n,1]+c',d] = y[n,a,b',c',d]; } /// Typically axes are identified by their tags, but in the absence of tags on `indices`,
/// The entry in indices will be assigned to axes of `self` from major to minor axis.
/// It is recommended to have indices referencing only major axes of `self` for better performance.
/// - `values` shape is obtained by concatenating the shape of `indices` with the shape of the slices to be extracted.
/// - `opts`: `zml.Tensor.ScatterOpts` des
///
/// ### Sample input/output shapes with corresponding pseudo-code.
///
/// Basic `scatterSlices` with the first two axes (.a, .b) being indexed, and full (.c, .d) slice copies:
///
/// ```
/// fn scatterSlices(x[A, B, C, D], .{.a=off_a[N], .b=off_b[N]}, y[N, C, D]) [A, B, C, D] {
/// var z = x;
/// for (0..N) |n| {
/// for (0..C) |c| for (0..D) |d| {{
/// z[off_a[n],off_b[n],c,d] += y[n, c, d];
/// }}
/// }
/// return z;
/// } /// }
/// ```
/// ///
/// **Warning**: if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound. /// `scatterSlices` with the first three axes (.a, .b, .c) being indexed, and a partial copy of (.c, .d).
/// Note that .c axis is present both in the indices and updates, and `updates.dim(.c) < self.dim(.c)`.
///
/// ```
/// fn scatterSlices(x[A, B, C, D], .{.a=off_a[N], .b=off_b[N], .c=off_c[N]}, y[N, C', D]) [A, B, C, D] {
/// var z = x;
/// for (0..N) |n| {
/// for (0..C') |c| for (0..D) |d| {{
/// z[off_a[n],off_b[n],off_c[n]+c,d] += y[n, c, d];
/// }}
/// }
/// return z;
/// }
/// ```
///
/// `scatterSlices` with the first axis .a being indexed, and where .b is used as a batching axis.
/// Note that here .b axis is present in `self`, `off_a`, and `updates`,
/// and is not mentionned in the axes of indices.
///
/// ```
/// fn scatterSlices(x[A, B, C, D], .{.a=off_a[B,N]}, y[N, B, C, D]) [A, B, C, D] {
/// var z = x;
/// for (0..B) |b| {
/// for (0..N) |n| {
/// for (0..C) |c| for (0..D) |d| {{
/// z[off_a[b,n],b,c,d] += y[n, b, c, d];
/// }}
/// }
/// }
/// return z;
/// }
/// ```
///
/// ### Warnings
///
/// - if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound.
/// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`, /// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`,
/// then the result will depend on the execution order that you don't control. /// then the result will depend on the execution order that you don't control.
pub fn scatterSlices(self: Tensor, coord_axes: anytype, indices: Tensor, updates: Tensor, opts: ScatterOpts) Tensor { /// - `scatterSlices` is a very expressive operator, and can lead to complicated code generation
const loc = @src(); /// that requires host<->device synchronization.
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates }); /// ZML tries to generate the easiest to optimize IR, and will warn you if it generates known problematic IR.
pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor {
scoped_log.debug("scatterSlices({}, {any}, {})", .{ self, indices, updates });
stdx.debug.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() }); const UpdateType = @TypeOf(ScatterOpts.increment);
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); const Custom = struct {
const AxisKind = enum { batching, update_window, inserted_window, window_id }; pub fn inc(custom: *const UpdateType, old_value: Tensor, new_value: Tensor) Tensor {
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; return @call(.auto, custom, .{ old_value, new_value });
var indices_batch_axes: Shape.DimsArray = .{};
for (self._shape.tags()) |t| {
if (updates._shape.hasTag(t)) |_| {
if (indices._shape.hasTag(t)) |id_ax| {
// tag is in self, indices and updates -> it's a batching dim
self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax);
} else {
self_kind.appendAssumeCapacity(.update_window);
}
} else {
self_kind.appendAssumeCapacity(.inserted_window);
} }
}
// scoped_log.warn(" self_kind -> {any}", .{self_kind.constSlice()});
const index_coord_axis = if (single_coord)
indices.rank()
else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
break :blk ax;
}; };
if (indices.count() == 1 and !single_coord) {
return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{}));
}
var up_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; return ops.scatter(Tensor, *const UpdateType, Custom.inc, self, opts.update_fn, indices, updates, opts);
// Note: we assume the scatter_dims appear in the same order inside indices and inside self.
for (updates._shape.tags(), 0..) |t, up_ax| {
if (self._shape.hasTag(t)) |self_ax| {
if (self_kind.get(self_ax) == .batching) {
up_kind.appendAssumeCapacity(.batching);
} else {
stdx.debug.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
up_kind.appendAssumeCapacity(.update_window);
}
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
up_kind.appendAssumeCapacity(.window_id);
} else {
std.debug.panic("scatterSlices expects 'updates' to be made of axes from 'self={}' and from 'indices={}', got unknown tag {s} in {}", .{ self, indices, t, updates });
}
}
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
if (single_coord) {
stdx.debug.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
} else {
stdx.debug.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
}
const ctx = self.getContext();
const mlir_ctx = ctx.mlirCtx();
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
const UpdateS = ops.BlockSign(ScatterOpts.increment);
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
const op = dialect.stablehlo.scatter(
mlir_ctx,
&.{self.value()},
&.{indices.value()},
&.{updates.value()},
update_block,
.{
.update_window_dims = _collectAxes(AxisKind, up_kind, .update_window).constSlice(),
.inserted_window_dims = _collectAxes(AxisKind, self_kind, .inserted_window).constSlice(),
.input_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
.scatter_indices_batching_dims = indices_batch_axes.constSlice(),
.scatter_dims_to_operand_dims = toI64(coord_axes_.constSlice()),
.index_vector_dim = index_coord_axis,
.indices_are_sorted = opts.indices_are_sorted,
.unique_indices = opts.indices_are_unique,
},
mlir_ctx.location(loc),
);
return _result(self._shape, op.result(0));
} }
test scatterSlices { test scatterSlices {
@ -2538,14 +2520,25 @@ pub const Tensor = struct {
const platform = zml.testing.env(); const platform = zml.testing.env();
const Local = struct { const Local = struct {
pub fn scatter(self: Tensor, coord_axes: Shape.AxesArray, indices: Tensor, updates: Tensor) Tensor { pub fn _scatter(self: Tensor, indices: []const Tensor, updates: Tensor) Tensor {
return self.scatterSlices( return self.scatterSlices(
coord_axes.constSlice(),
indices, indices,
updates, updates,
.{ .update_fn = ScatterOpts.increment }, .{ .update_fn = ScatterOpts.increment },
); );
} }
pub fn _scatterCB(self: Tensor, coords: Tensor, updates: Tensor) Tensor {
return self.scatterSlices(
.{ .c = coords.choose1d(.coord, 0), .b = coords.choose1d(.coord, 1) },
updates,
.{ .update_fn = ScatterOpts.increment },
);
}
pub fn _idx(idx_shape: anytype) Tensor {
return Tensor.constant(idx_shape, .{ .i32 = 0 });
}
}; };
{ {
@ -2554,23 +2547,27 @@ pub const Tensor = struct {
defer comp.deinit(); defer comp.deinit();
comp.activate(); comp.activate();
defer comp.deactivate(); defer comp.deactivate();
const idx = Local._idx;
inline for (.{ inline for (.{
.{ .{ .a = 10 }, .a, .{}, .{ .a = 3 } }, // This is equivalent to a dynamic update slice, update 3 values at given offset of axis .a:
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8, .b = 2 } }, .{ .{ .a = 10 }, .{ .a = idx(.{}) }, .{ .a = 3 } },
// I'm not sure I like this variant, cause `b` is not mentionned in updates. // Use .a as a batching axis with .a=10 x .n=8 updates of 2 elements of .b
// So 'stablehlo.scatter' is implicitly broadcasting the updates along `b` axis. .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .n = 8, .b = 2 } },
// OTOH asking the user to do the broadcasting isn't trivial cause they will need to do shape wrangling and that's annoying. // Same but with update transposed
.{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .a = 2 } }, .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .b = 2, .n = 8 } },
.{ .{ .a = 10, .b = 20 }, .{ .b, .a }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 3, .b = 2 } }, // similar, but use the normalized form where a is no longer an explicit batching axis.
.{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .a = 3, .n = 8, .b = 2 } }, .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .a2 = 10, .n = 8 }), .b = idx(.{ .a2 = 10, .n = 8 }) }, .{ .a2 = 10, .n = 8, .b = 2 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .a = 10, .n = 8 }), .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .n = 8, .b = 2 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }) }, .{ .n = 8, .a = 2 } },
.{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .n = 8 }), .a = idx(.{ .n = 8 }) }, .{ .n = 8, .a = 3, .b = 2 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }), .b = idx(.{ .n = 8 }) }, .{ .a = 3, .n = 8, .b = 2 } },
}) |testcase| { }) |testcase| {
const x_shape, const axes_, const idx_shape, const updates_shapes = testcase; const x_shape, const indices, const updates_shapes = testcase;
const x = Tensor.constant(x_shape, .{ .f16 = 0 }); const x = Tensor.constant(x_shape, .{ .f16 = 0 });
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 }); const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 });
const y = scatterSlices(x, axes_, idx, updates, .{}); const y = scatterSlices(x, indices, updates, .{});
// Shape doesn't change with scatterSlices // Shape doesn't change with scatterSlices
try zml.testing.expectEqualShapes(x.shape(), y.shape()); try zml.testing.expectEqualShapes(x.shape(), y.shape());
try std.testing.expect(y.value().owner().verify()); try std.testing.expect(y.value().owner().verify());
@ -2583,14 +2580,13 @@ pub const Tensor = struct {
defer a.deinit(); defer a.deinit();
a_host.deinit(std.testing.allocator); a_host.deinit(std.testing.allocator);
const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{0}, .{2} }); const scatter_indices = try zml.Buffer.fromArray(platform, [2]i32{ 0, 2 });
const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } }); const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } });
const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } }; const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } };
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ const result = try zml.testing.compileAndCall(platform, Local._scatter, .{
a, a,
a.shape().axes(.{.a}), &.{scatter_indices.withTags(.{.n})},
scatter_indices.withTags(.{ .n, .coord }),
updates.withTags(.{ .n, .b }), updates.withTags(.{ .n, .b }),
}); });
try std.testing.expect(a.shape().eql(result.shape())); try std.testing.expect(a.shape().eql(result.shape()));
@ -2603,14 +2599,13 @@ pub const Tensor = struct {
defer a.deinit(); defer a.deinit();
a_host.deinit(std.testing.allocator); a_host.deinit(std.testing.allocator);
const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{2}, .{7} }); const scatter_indices = try zml.Buffer.fromArray(platform, [2]i32{ 2, 7 });
const updates = try zml.Buffer.fromArray(platform, [2]i32{ 20, 70 }); const updates = try zml.Buffer.fromArray(platform, [2]i32{ 20, 70 });
const expected = [9]i32{ 0, 1, 22, 3, 4, 5, 6, 77, 8 }; const expected = [9]i32{ 0, 1, 22, 3, 4, 5, 6, 77, 8 };
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ const result = try zml.testing.compileAndCall(platform, Local._scatter, .{
a, a,
a.shape().axes(.{0}), &.{scatter_indices.withTags(.{.n})},
scatter_indices.withTags(.{ .n, .coord }),
updates.withTags(.{.n}), updates.withTags(.{.n}),
}); });
try std.testing.expect(a.shape().eql(result.shape())); try std.testing.expect(a.shape().eql(result.shape()));
@ -2642,7 +2637,7 @@ pub const Tensor = struct {
); );
defer values.deinit(); defer values.deinit();
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values }); const result = try zml.testing.compileAndCall(platform, Local._scatterCB, .{ operand, start_indices, values });
const expected = [2][3][4][2]u16{ const expected = [2][3][4][2]u16{
.{ .{
@ -2740,12 +2735,12 @@ pub const Tensor = struct {
const platform = zml.testing.env(); const platform = zml.testing.env();
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const ArgMaxTest = struct { const ArgMaxTest = struct {
pub fn forward(x: Tensor) Tensor.ArgMaxRes { pub fn _fwd(x: Tensor) Tensor.ArgMaxRes {
return x.argMax(1); return x.argMax(1);
} }
}; };
const argmax = try zml.compileFn(allocator, ArgMaxTest.forward, .{Shape.init(.{ 1, 5 }, .f32)}, platform); const argmax = try zml.compileFn(allocator, ArgMaxTest._fwd, .{Shape.init(.{ 1, 5 }, .f32)}, platform);
defer argmax.deinit(); defer argmax.deinit();
// Test with tie // Test with tie
{ {
@ -3290,10 +3285,10 @@ pub const Tensor = struct {
const res = try zml.testing.compileAndCall( const res = try zml.testing.compileAndCall(
platform, platform,
struct { struct {
pub fn forward(x_: Tensor, idx_: struct { a: Tensor }, y_: Tensor) Tensor { pub fn _fwd(x_: Tensor, idx_: struct { a: Tensor }, y_: Tensor) Tensor {
return x_.dynamicUpdateSlice(idx_, y_); return x_.dynamicUpdateSlice(idx_, y_);
} }
}.forward, }._fwd,
.{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) }, .{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) },
); );
try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32)); try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
@ -3308,10 +3303,10 @@ pub const Tensor = struct {
const res = try zml.testing.compileAndCall( const res = try zml.testing.compileAndCall(
platform, platform,
struct { struct {
pub fn forward(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor { pub fn _fwd(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor {
return x_.dynamicUpdateSlice(.{ .b = idx_ }, y_); return x_.dynamicUpdateSlice(.{ .b = idx_ }, y_);
} }
}.forward, }._fwd,
.{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) }, .{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) },
); );
try testing.expectEqualDeep( try testing.expectEqualDeep(
@ -3328,10 +3323,10 @@ pub const Tensor = struct {
const res = try zml.testing.compileAndCall( const res = try zml.testing.compileAndCall(
platform, platform,
struct { struct {
pub fn forward(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor { pub fn _fwd(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor {
return x_.dynamicUpdateSlice(.{ zml.Tensor.scalar(0, .i32), idx_ }, y_); return x_.dynamicUpdateSlice(.{ zml.Tensor.scalar(0, .i32), idx_ }, y_);
} }
}.forward, }._fwd,
.{ x, idx, y }, .{ x, idx, y },
); );
try testing.expectEqualDeep( try testing.expectEqualDeep(
@ -3349,10 +3344,10 @@ pub const Tensor = struct {
const res = try zml.testing.compileAndCall( const res = try zml.testing.compileAndCall(
platform, platform,
struct { struct {
pub fn forward(x_: Tensor, idx_: struct { a: Tensor, b: Tensor }, y_: Tensor) Tensor { pub fn _fwd(x_: Tensor, idx_: struct { a: Tensor, b: Tensor }, y_: Tensor) Tensor {
return x_.dynamicUpdateSlice(idx_, y_); return x_.dynamicUpdateSlice(idx_, y_);
} }
}.forward, }._fwd,
.{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) }, .{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) },
); );
try testing.expectEqualDeep( try testing.expectEqualDeep(
@ -3368,11 +3363,11 @@ pub const Tensor = struct {
const idx_a = try zml.Buffer.scalar(platform, 1, .i32); const idx_a = try zml.Buffer.scalar(platform, 1, .i32);
const idx_b = try zml.Buffer.scalar(platform, 3, .i32); const idx_b = try zml.Buffer.scalar(platform, 3, .i32);
const A = struct { const A = struct {
pub fn forward(x_: Tensor, idx_: [2]Tensor, y_: Tensor) Tensor { pub fn _fwd(x_: Tensor, idx_: [2]Tensor, y_: Tensor) Tensor {
return x_.dynamicUpdateSlice(&idx_, y_); return x_.dynamicUpdateSlice(&idx_, y_);
} }
}; };
const res = try zml.testing.compileAndCall(platform, A.forward, .{ x, .{ idx_a, idx_b }, y }); const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y });
try testing.expectEqualDeep( try testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } }, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32), res.getValue([2][5]f32),
@ -3589,15 +3584,16 @@ pub const Tensor = struct {
/// Given a set of N vectors of lengths A, B, C, D, /// Given a set of N vectors of lengths A, B, C, D,
/// returns N tensors of rank N, and shape (A, B, C, D). /// returns N tensors of rank N, and shape (A, B, C, D).
/// For any coordinate (a, b, c, d), /// For any coordinate (a, b, c, d), we have:
/// we have: ///
/// - res[0][a, b, c, d] == A[a] /// - res[0][a, b, c, d] == A[a]
/// - res[1][a, b, c, d] == B[b] /// - res[1][a, b, c, d] == B[b]
/// - res[2][a, b, c, d] == C[c] /// - res[2][a, b, c, d] == C[c]
/// - res[3][a, b, c, d] == D[d] /// - res[3][a, b, c, d] == D[d]
///
/// This is implemented with broadcasting, so typically it won't copy. /// This is implemented with broadcasting, so typically it won't copy.
/// In Pytorch/Numpy this is know as `meshgrid` with "ij" mode. /// In Pytorch/Numpy this is know as `meshgrid` with "ij" mode.
/// See torch.meshgrid for the "xy" mode. /// See `zml.torch.meshgrid` for the "xy" mode.
pub fn cartesianProduct(comptime N: u3, vectors: [N]Tensor) [N]Tensor { pub fn cartesianProduct(comptime N: u3, vectors: [N]Tensor) [N]Tensor {
var out: @TypeOf(vectors) = undefined; var out: @TypeOf(vectors) = undefined;
_cartesianProduct(&vectors, &out); _cartesianProduct(&vectors, &out);
@ -3634,13 +3630,13 @@ pub const Tensor = struct {
const y = try zml.Buffer.fromSlice(client, .{4}, &[_]i32{ 0, 1, 2, 3 }); const y = try zml.Buffer.fromSlice(client, .{4}, &[_]i32{ 0, 1, 2, 3 });
const Local = struct { const Local = struct {
pub fn cartesianProduct2(a: Tensor, b: Tensor) [2]Tensor { pub fn _cartesianProduct2(a: Tensor, b: Tensor) [2]Tensor {
return cartesianProduct(2, .{ a, b }); return cartesianProduct(2, .{ a, b });
} }
}; };
{ {
const xs, const ys = try zml.testing.compileAndCall(client, Local.cartesianProduct2, .{ x, y }); const xs, const ys = try zml.testing.compileAndCall(client, Local._cartesianProduct2, .{ x, y });
try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, xs.shape().dims()); try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, xs.shape().dims());
try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, ys.shape().dims()); try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, ys.shape().dims());
try std.testing.expectEqualDeep( try std.testing.expectEqualDeep(
@ -3670,8 +3666,8 @@ pub const Tensor = struct {
/// Given a set of N vectors of lengths A, B, C, D, /// Given a set of N vectors of lengths A, B, C, D,
/// returns 1 tensors of rank N+1, and shape (A, B, C, D, N). /// returns 1 tensors of rank N+1, and shape (A, B, C, D, N).
/// For any coordinate (a, b, c, d), /// For any coordinate (a, b, c, d), we have:
/// we have: ///
/// - res[a, b, c, d] == (A[a], B[b], C[c], D[d]) /// - res[a, b, c, d] == (A[a], B[b], C[c], D[d])
pub fn cartesianProductStacked(vectors: []const Tensor) Tensor { pub fn cartesianProductStacked(vectors: []const Tensor) Tensor {
var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable; var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable;
@ -3687,12 +3683,12 @@ pub const Tensor = struct {
const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 }); const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 });
const Local = struct { const Local = struct {
pub fn cartesianProduct2(a: Tensor, b: Tensor) Tensor { pub fn _fwd(a: Tensor, b: Tensor) Tensor {
return cartesianProductStacked(&.{ a, b }); return cartesianProductStacked(&.{ a, b });
} }
}; };
const z = try zml.testing.compileAndCall(platform, Local.cartesianProduct2, .{ x, y }); const z = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, y });
try std.testing.expectEqualDeep( try std.testing.expectEqualDeep(
[6][4][2]i32{ [6][4][2]i32{
.{ .{ 0, 0 }, .{ 0, 1 }, .{ 0, 2 }, .{ 0, 3 } }, .{ .{ 0, 0 }, .{ 0, 1 }, .{ 0, 2 }, .{ 0, 3 } },
@ -3795,7 +3791,7 @@ test "Tensor.maxPool1d" {
const platform = zml.testing.env(); const platform = zml.testing.env();
const MaxPool = struct { const MaxPool = struct {
pub fn forward(x: zml.Tensor) Tensor.ArgMaxRes { pub fn _fwd(x: zml.Tensor) Tensor.ArgMaxRes {
return x.maxPool1d(.{ return x.maxPool1d(.{
.window_dimensions = 3, .window_dimensions = 3,
.window_strides = 2, .window_strides = 2,
@ -3807,7 +3803,7 @@ test "Tensor.maxPool1d" {
for (&data, 0..) |*v, i| v.* = @floatFromInt(i); for (&data, 0..) |*v, i| v.* = @floatFromInt(i);
const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5 }, &data); const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5 }, &data);
const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x}); const result = try zml.testing.compileAndCall(platform, MaxPool._fwd, .{x});
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .f32), result.values.shape()); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .f32), result.values.shape());
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .i32), result.indices.shape()); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .i32), result.indices.shape());
const buffer = result.values.getValue([2][2][2]f32); const buffer = result.values.getValue([2][2][2]f32);
@ -3831,7 +3827,7 @@ test "Tensor.maxPool2d" {
const platform = zml.testing.env(); const platform = zml.testing.env();
const MaxPool = struct { const MaxPool = struct {
pub fn forward(x: Tensor) Tensor.ArgMaxRes { pub fn _fwd(x: Tensor) Tensor.ArgMaxRes {
return x.maxPool2d(.{ return x.maxPool2d(.{
.window_dimensions = .{ 3, 2 }, .window_dimensions = .{ 3, 2 },
.window_strides = .{ 2, 1 }, .window_strides = .{ 2, 1 },
@ -3843,7 +3839,7 @@ test "Tensor.maxPool2d" {
for (&data, 0..) |*v, i| v.* = @floatFromInt(i); for (&data, 0..) |*v, i| v.* = @floatFromInt(i);
const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5, 5 }, &data); const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5, 5 }, &data);
const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x}); const result = try zml.testing.compileAndCall(platform, MaxPool._fwd, .{x});
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .f32), result.values.shape()); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .f32), result.values.shape());
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .i32), result.indices.shape()); try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .i32), result.indices.shape());
var buffer: [2][2][2][4]f32 = undefined; var buffer: [2][2][2][4]f32 = undefined;
@ -3962,7 +3958,7 @@ test shapesOf {
} }
} }
fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) { pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) {
var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
for (bounded_array.constSlice(), 0..) |v, ax| { for (bounded_array.constSlice(), 0..) |v, ax| {
if (v == value) { if (v == value) {
@ -4046,13 +4042,13 @@ test "unused tensor" {
const platform = zml.testing.env(); const platform = zml.testing.env();
const Local = struct { const Local = struct {
pub fn forward(x: Tensor) Tensor { pub fn _fwd(x: Tensor) Tensor {
const y = x.addConstant(1); const y = x.addConstant(1);
_ = y; _ = y;
return x; return x;
} }
}; };
const mod = try zml.compileFn(std.testing.allocator, Local.forward, .{Shape.init(.{10}, .f32)}, platform); const mod = try zml.compileFn(std.testing.allocator, Local._fwd, .{Shape.init(.{10}, .f32)}, platform);
defer mod.deinit(); defer mod.deinit();
} }

View File

@ -105,8 +105,8 @@ pub fn unsqueeze(
} }
test unsqueeze { test unsqueeze {
const UnsqueezeTest = struct { const Local = struct {
pub fn forward(x: Tensor) Tensor { pub fn _fwd(x: Tensor) Tensor {
var y = x; var y = x;
y = unsqueeze(y, 0); y = unsqueeze(y, 0);
y = unsqueeze(y, -1); y = unsqueeze(y, -1);
@ -117,7 +117,7 @@ test unsqueeze {
const platform = zml.testing.env(); const platform = zml.testing.env();
const x = try zml.Buffer.fromArray(platform, @as([8]f16, undefined)); const x = try zml.Buffer.fromArray(platform, @as([8]f16, undefined));
const res = try zml.testing.compileAndCall(platform, UnsqueezeTest.forward, .{x}); const res = try zml.testing.compileAndCall(platform, Local._fwd, .{x});
try zml.testing.expectEqualShapes(zml.Shape.init(.{ 1, 8, 1, 1 }, .f16), res.shape()); try zml.testing.expectEqualShapes(zml.Shape.init(.{ 1, 8, 1, 1 }, .f16), res.shape());
} }
@ -247,7 +247,7 @@ test meshgrid {
const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 }); const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 });
const Local = struct { const Local = struct {
pub fn meshgrid2(a: Tensor, b: Tensor, indexing: MeshgridIndexing) [2]Tensor { pub fn _meshgrid2(a: Tensor, b: Tensor, indexing: MeshgridIndexing) [2]Tensor {
return meshgrid(2, .{ a, b }, indexing); return meshgrid(2, .{ a, b }, indexing);
} }
}; };
@ -255,7 +255,7 @@ test meshgrid {
// Only test .xy mode, sinc .ij is just calling cartesianProduct which // Only test .xy mode, sinc .ij is just calling cartesianProduct which
// got its own tests. // got its own tests.
{ {
const xs, const ys = try zml.testing.compileAndCall(platform, Local.meshgrid2, .{ x, y, .xy }); const xs, const ys = try zml.testing.compileAndCall(platform, Local._meshgrid2, .{ x, y, .xy });
try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, xs.dims()); try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, xs.dims());
try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, ys.dims()); try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, ys.dims());
try std.testing.expectEqualDeep( try std.testing.expectEqualDeep(