From 24a7c984766cfc9a3cc25ec735480f1207306d1f Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 14 Feb 2023 13:52:49 +0000 Subject: [PATCH] Implement scatterSlices functionality. --- mlir/dialects/stablehlo.zig | 141 +++++++---------- mlir/mlir.zig | 4 + zml/context.zig | 12 +- zml/dtype.zig | 8 +- zml/meta.zig | 8 +- zml/mlir.zig | 2 +- zml/module.zig | 5 +- zml/ops.zig | 19 +-- zml/shape.zig | 14 +- zml/tensor.zig | 305 +++++++++++++++++++++++++++++++++--- 10 files changed, 380 insertions(+), 138 deletions(-) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index af95f10..55e1915 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -273,21 +273,68 @@ pub fn gather( ); } -pub const ScatterArgs = struct { - update_window_dims: []const i64, - inserted_window_dims: []const i64, - scatter_dims_to_operand_dims: []const i64, - index_vector_dim: i64, - indices_are_sorted: bool = false, - unique_indices: bool = false, -}; - fn elementTypeOrSelf(typ: mlir.Type) mlir.Type { return if (typ.as(mlir.ShapedType)) |shaped| { return shaped.elementType(); } else typ; } +pub const ScatterArgs = struct { + update_window_dims: []const i64, + inserted_window_dims: []const i64, + input_batching_dims: []const i64, + scatter_indices_batching_dims: []const i64, + scatter_dims_to_operand_dims: []const i64, + index_vector_dim: i64, + indices_are_sorted: bool = false, + unique_indices: bool = false, + + pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute { + return mlir.Attribute.wrap( + c.stablehloScatterDimensionNumbersGet( + ctx.inner(), + @intCast(self.update_window_dims.len), + self.update_window_dims.ptr, + @intCast(self.inserted_window_dims.len), + self.inserted_window_dims.ptr, + @intCast(self.input_batching_dims.len), + self.input_batching_dims.ptr, + @intCast(self.scatter_indices_batching_dims.len), + self.scatter_indices_batching_dims.ptr, + @intCast(self.scatter_dims_to_operand_dims.len), + 
self.scatter_dims_to_operand_dims.ptr, + self.index_vector_dim, + ), + ); + } +}; + +pub fn scatter( + ctx: mlir.Context, + inputs: []const mlir.Value, + scatter_indices: []const mlir.Value, + updates: []const mlir.Value, + update_block: mlir.Block, + args: ScatterArgs, + location: mlir.Location, +) mlir.Operation { + return mlir.Operation.make( + ctx, + "stablehlo.scatter", + .{ + .variadic_operands = &.{ inputs, scatter_indices, updates }, + .blocks = &.{update_block}, + .attributes = &.{ + .{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) }, + .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? }, + .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? }, + }, + .result_type_inference = true, + .location = location, + }, + ); +} + pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "stablehlo.iota", .{ .operands = &.{}, @@ -915,82 +962,6 @@ pub const OutputOperandAliasAttribute = struct { } }; -pub const ScatterDimensionNumbersAttribute = struct { - _inner: c.MlirAttribute, - - pub usingnamespace mlir.MlirHelpers(ScatterDimensionNumbersAttribute, .{ - .is_a_fn = c.stablehloAttributeIsAScatterDimensionNumbers, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); - const Self = ScatterDimensionNumbersAttribute; - - pub fn init( - ctx: mlir.Context, - update_window_dims: []const i64, - inserted_window_dims: []const i64, - input_batching_dims: []const i64, - scatter_indices_batching_dims: []const i64, - scatter_dims_to_operand_dims: []const i64, - index_vector_dim: i64, - ) Self { - return Self.wrap( - c.stablehloScatterDimensionNumbersGet( - ctx.inner(), - @intCast(update_window_dims.len), - update_window_dims.ptr, - @intCast(inserted_window_dims.len), - inserted_window_dims.ptr, - 
@intCast(input_batching_dims.len), - input_batching_dims.ptr, - @intCast(scatter_indices_batching_dims.len), - scatter_indices_batching_dims.ptr, - @intCast(scatter_dims_to_operand_dims.len), - scatter_dims_to_operand_dims.ptr, - index_vector_dim, - ), - ); - } - - pub fn getUpdateWindowDimsSize(self: Self) usize { - return @intCast(c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(self.inner())); - } - - pub fn getUpdateWindowDimsElem(self: Self, pos: usize) i64 { - return c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(self.inner(), @intCast(pos)); - } - - pub fn getInsertedWindowDimsSize(self: Self) usize { - return @intCast(c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(self.inner())); - } - - pub fn getInsertedWindowDimsElem(self: Self, pos: usize) i64 { - return c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(self.inner(), @intCast(pos)); - } - - pub fn getInputBatchingDimsSize(self: Self) usize { - return @intCast(c.stablehloScatterDimensionNumbersGetInputBatchingDimsSize(self.inner())); - } - - pub fn getInputBatchingDimsElem(self: Self, pos: usize) i64 { - return c.stablehloScatterDimensionNumbersGetInputBatchingDimsElem(self.inner(), @intCast(pos)); - } - - pub fn getScatterIndicesBatchingDimsSize(self: Self) usize { - return @intCast(c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsSize(self.inner())); - } - - pub fn getScatterIndicesBatchingDimsElem(self: Self, pos: usize) i64 { - return c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsElem(self.inner(), @intCast(pos)); - } - - pub fn getIndexVectorDim(self: Self) i64 { - // There really is "Scatter" missing in the function name - return c.stablehloDimensionNumbersGetIndexVectorDim(self.inner()); - } -}; - pub const PrecisionAttribute = struct { _inner: c.MlirAttribute, diff --git a/mlir/mlir.zig b/mlir/mlir.zig index f3ccc70..50a66ec 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -416,6 +416,10 @@ pub const BoolAttribute = struct 
{ pub fn value(self: Self) bool { return c.mlirBoolAttrGetValue(self.inner()); } + + pub fn asAttr(self: Self) Attribute { + return self.as(Attribute).?; + } }; pub const TypeAttribute = struct { diff --git a/zml/context.zig b/zml/context.zig index a7f26ae..4139972 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -49,18 +49,24 @@ pub const Context = struct { Context.mlir_once.call(); var platforms = PlatformsMap.initFill(null); + var num_platforms: u8 = 0; var it = Context.apis.iterator(); while (it.next()) |entry| { if (entry.value.*) |api| { const target = entry.key; - const p = Platform.init(target, api) catch continue; + const p = Platform.init(target, api) catch |err| { + log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err }); + continue; + }; if (p.getDevices().len == 0) { log.err("No device found for platform {} !", .{target}); continue; } platforms.set(target, p); + num_platforms += 1; } } + if (num_platforms == 0) return error.NotFound; return .{ .platforms = platforms, }; @@ -121,13 +127,13 @@ pub const Context = struct { pub fn autoPlatform(self: *Context) Platform { // the last platform is the one that with the high enum number, so considered // to be the "best" one - var platform_: Platform = undefined; + var platform_: ?Platform = null; var iterator = self.platforms.iterator(); while (iterator.next()) |entry| { if (entry.value.*) |p| { platform_ = p; } } - return platform_; + return platform_ orelse @panic("No platform found !"); } }; diff --git a/zml/dtype.zig b/zml/dtype.zig index 2230463..cfc13cf 100644 --- a/zml/dtype.zig +++ b/zml/dtype.zig @@ -234,11 +234,11 @@ pub const Data = union(DataType) { /// If the `dtype` and `@TypeOf(value)` are incompatible /// or a cast from `value` to `FieldType(dtype)` would /// be lossy, a panic occurs. 
- pub fn init(dtype: DataType, value: anytype) Data { + pub fn init(dtype_: DataType, value: anytype) Data { const T = @TypeOf(value); const Ti = @typeInfo(T); - return switch (dtype) { + return switch (dtype_) { .bool => switch (Ti) { .Bool => .{ .bool = value }, .ComptimeInt, .Int, .ComptimeFloat, .Float => .{ .bool = value != 0 }, @@ -302,7 +302,7 @@ pub const Data = union(DataType) { try std.testing.expectEqual(C128.init(1, 2), Data.init(.c128, C64.init(1, 2)).c128); } - pub fn dataType(self: Data) DataType { + pub fn dtype(self: Data) DataType { return std.meta.activeTag(self); } @@ -327,7 +327,7 @@ pub const Data = union(DataType) { }, else => {}, } - std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dataType(), @typeName(T) }); + std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dtype(), @typeName(T) }); } }; diff --git a/zml/meta.zig b/zml/meta.zig index 0f593f1..4f35c1a 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -269,7 +269,10 @@ pub fn MapType(From: type, To: type) type { []const map(ptr_info.child) else []map(ptr_info.child), - .One => *map(ptr_info.child), + .One => if (ptr_info.is_const) + *const map(ptr_info.child) + else + *map(ptr_info.child), else => T, }, .Optional => |opt_info| ?map(opt_info.child), @@ -446,8 +449,9 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { const Callback = @TypeOf(cb); @compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T)); } - const ptr_info = type_info_v.Pointer; + if (@typeInfo(ptr_info.child) == .Fn) return; + if (ptr_info.child == anyopaque) return; // This is important, because with trivial types like void, // Zig sometimes decide to call `visit` at comptime, but can't do // the pointer wrangling logic at comptime. 
diff --git a/zml/mlir.zig b/zml/mlir.zig index 538deeb..b90763d 100644 --- a/zml/mlir.zig +++ b/zml/mlir.zig @@ -155,7 +155,7 @@ pub const ext = struct { pub const DenseIntOrFPElementsAttribute = struct { pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute { - return switch (data.dataType()) { + return switch (data.dtype()) { .bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?, diff --git a/zml/module.zig b/zml/module.zig index 63c80f9..00504fc 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -140,8 +140,8 @@ pub const CompilationContext = struct { /// `blkctx` represents values from outside the block that can be accessed inside the block. pub fn makeBlock( self: *CompilationContext, - comptime func: anytype, comptime S: ops.BlockSignature, + func: *const S.Fn, blkctx: S.BlkCtx, args: S.Args, ) mlir.Block { @@ -996,7 +996,8 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m else => {}, } - const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, try options.encode(arena)); + const options_bytes = try options.encode(arena); + const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes); errdefer unreachable; // errdefer loaded_executable.deinit(); if (platform.compilation_options.cache_location) |compilation_cache_location| { diff --git a/zml/ops.zig b/zml/ops.zig index 9335ff7..35765ce 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -44,8 +44,8 @@ pub fn while_( @compileError("cond_fn and body_fn signatures don't match ! 
" ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn))); } const ctx = CompilationContext.current(); - const cond_block = ctx.makeBlock(cond_fn, CondS, blkctx, inputs); - const body_block = ctx.makeBlock(body_fn, BodyS, blkctx, inputs); + const cond_block = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs); + const body_block = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs); var input_values: [BodyS.nIn]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -136,7 +136,7 @@ pub fn reduce( var init_values: [N]mlir.Value = undefined; ctx.extractValues(&inits, &init_values); - const body_block = ctx.makeBlock(body_fn, BodyS, {}, .{ inits, inits }); + const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const loc = ctx.mlirCtx().location(@src()); @@ -226,7 +226,7 @@ pub fn reduceWindow( if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); } const ctx = CompilationContext.current(); - const body_block = ctx.makeBlock(body_fn, BodyS, {}, .{ inits, inits }); + const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const N = comptime @divExact(BodyS.nIn, 2); var input_values: [N]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -398,8 +398,8 @@ pub fn if_( @compileError("true_branch_fn and false_branch_fn return types don't match ! 
" ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); } const ctx = CompilationContext.current(); - const true_branch_block = ctx.makeBlock(true_branch_fn, TrueBlockSignature, blkctx, {}); - const false_branch_block = ctx.makeBlock(false_branch_fn, TrueBlockSignature, blkctx, {}); + const true_branch_block = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {}); + const false_branch_block = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {}); const loc = ctx.mlirCtx().location(@src()); const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ @@ -461,7 +461,7 @@ pub fn sort( inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer }; } const ctx = CompilationContext.current(); - const block = ctx.makeBlock(comp_fn, BodyS, blkctx, inits); + const block = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits); var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -486,6 +486,7 @@ pub fn sort( } pub const BlockSignature = struct { + Fn: type, BlkCtx: type, Args: type, FullArgs: type, @@ -560,7 +561,8 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature { .no_args => void, }; - const xx = .{ + return .{ + .Fn = @TypeOf(func), .BlkCtx = BlkCtx, .Args = Args, .FullArgs = FullArgs, @@ -568,7 +570,6 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature { .nIn = n_tensors, .nOut = staticCountTensors(fn_info.return_type.?) orelse @compileError("Can't use " ++ @typeName(fn_info.return_type.?) 
++ " in an MLIR function, because it has a variable number of tensors"), }; - return xx; } pub fn staticIsOnlyTensors(comptime T: type) bool { diff --git a/zml/shape.zig b/zml/shape.zig index 1c6fce8..2cfc292 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -290,14 +290,14 @@ pub const Shape = struct { } fn axisFromInt(self: Shape, d: isize) u3 { - const rank_: i8 = self.rank(); - if (d < 0) { - return @intCast(d + rank_); + const rk: i8 = self.rank(); + if (d < -rk or d > rk) { + meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d }); } - if (d > rank_) { - meta.panic("Tensor doesn't have dimension: {d}", .{d}); - } - return @intCast(d); + return if (d < 0) + @intCast(d + rk) + else + @intCast(d); } fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 { diff --git a/zml/tensor.zig b/zml/tensor.zig index 80382a9..4cde57e 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1682,12 +1682,12 @@ pub const Tensor = struct { /// Returns a constant Tensor with the given value. pub fn constant(dimz: anytype, val: Data) Tensor { - const sh = Shape.init(dimz, val.dataType()); - const singleton_sh = Shape.init(.{}, val.dataType()); + const sh = Shape.init(dimz, val.dtype()); + const singleton_sh = Shape.init(.{}, val.dtype()); const ctx = CompilationContext.current().mlirCtx(); const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val }); const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh); - const elem_type = mlir.ext.denseElementAttrType(val.dataType()); + const elem_type = mlir.ext.denseElementAttrType(val.dtype()); var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc); if (sh.rank() > 0) { constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc); @@ -1925,20 +1925,14 @@ pub const Tensor = struct { /// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], 
ind[a, n][1], d] /// /// It is possible to use gatherValues without tags, but batching won't be available. - pub fn gatherValues(self: Tensor, axes_: anytype, indices: Tensor, opts: GatherOpts) Tensor { - // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, axes_, indices }); - const AxesT = @TypeOf(axes_); - const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int; + pub fn gatherValues(self: Tensor, coord_axes: anytype, indices: Tensor, opts: GatherOpts) Tensor { + // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices }); + const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); - const val_axes = if (axes_is_scalar) - std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable - else - self.axes(axes_); - - meta.assert(val_axes.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); - for (val_axes.constSlice(), 0..) |a, i| { + meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); + for (coord_axes_.constSlice(), 0..) |a, i| { if (i > 0) { - meta.assert(a == val_axes.get(i - 1) + 1, "gatherValues expects 'axes_' too be sequential. But {any} aren't sequential in {}", .{ axes_, self }); + meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. 
But {any} aren't sequential in {}", .{ coord_axes, self }); } } @@ -1946,14 +1940,14 @@ pub const Tensor = struct { var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; var indices_batch_axes: Shape.DimsArray = .{}; for (self._shape.tags(), 0..self.rank()) |t, self_ax| { - const maybe_val_ax = std.mem.indexOfScalar(u3, val_axes.constSlice(), @intCast(self_ax)); + const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax)); if (indices._shape.hasTag(t)) |id_ax| { // tag is both in self and indices -> it's a batching dim // Note: tags are required for batching. self_kind.appendAssumeCapacity(.batching); indices_batch_axes.appendAssumeCapacity(id_ax); - meta.assert(maybe_val_ax == null, "gatherValues expects axes to be either batches or slices axes. Axis {s} has been found both in `axes={any}` and `indices={}`", .{ t, axes_, indices }); - } else if (maybe_val_ax) |_| { + meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); + } else if (maybe_coord_ax) |_| { // for gatherValues we collapsed all gathered axes // (contrary to gatherSlices where we collapse none) self_kind.appendAssumeCapacity(.collapsed); @@ -1962,14 +1956,14 @@ pub const Tensor = struct { } } - // When we receive several axes_ we need an extra dimension to store + // When we receive several coord_axes we need an extra dimension to store // one index per axis, which makes the coordinates of one value. // Otherwi se stablehlo uses the "indices.rank()" default value. 
- const index_coord_axis = if (axes_is_scalar) + const index_coord_axis = if (single_coord) indices.rank() else blk: { const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - meta.assert(indices.dim(ax) == val_axes.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ axes_, val_axes.len, indices }); + meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices }); break :blk ax; }; @@ -1978,7 +1972,7 @@ pub const Tensor = struct { var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; for (self_kind.constSlice(), 0..) |kind, ax_usize| { const ax: u3 = @intCast(ax_usize); - if (ax == val_axes.get(0)) { + if (ax == coord_axes_.get(0)) { // The first val_ax is special cause this is the place where we insert indices axes. for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| { if (id_ax == index_coord_axis) continue; @@ -2004,7 +1998,7 @@ pub const Tensor = struct { // Sometimes the backend recognize this pattern, but not always. // So let us handle that. if (indices.count() == 1) { - return self.dynamicSlice1d(val_axes.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape); + return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape); } var slice_dims: Shape.DimsArray = .{}; @@ -2247,6 +2241,256 @@ pub const Tensor = struct { try zml.testing.expectClose(expected, result, 0); } + pub const ScatterOpts = struct { + /// Promise scatter that all coordinates in `indices` are sorted, wrt to the final in memory offset. + /// Result is undefined if the promise is violated. + indices_are_sorted: bool = false, + + /// Promise scatter that slices don't overlap. + /// Result is undefined if the promise is violated. + indices_are_unique: bool = false, + + /// Function used to update previous value in `self` with values from `updates`. 
+ /// If `update_fn` is not associative (i.e. the order of execution matters), + /// then you should make sure the slices don't overlap, + /// otherwise the result will depend on the runtime scheduling + /// of the operator which is backend specific. + update_fn: *const fn (*const anyopaque, Tensor, Tensor) Tensor = increment, + + /// Extra data that may be needed for a custom update function. + /// `override` and `increment` don't need it, leaving it to undefined works. + update_fn_ctx: *const anyopaque = undefined, + + pub fn increment(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor { + return old_value.add(new_value); + } + + pub fn override(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor { + _ = old_value; + return new_value; + } + }; + + /// Update slices of `self` using the values from `updates`. + /// The slices are chosen at runtime by interpreting indices as coordinates into `self`. + /// * `indices` represents a set of coordinates into `self`. + /// For the sake of simplifying the creation of `indices` tensor, + /// it's allowed to not mention a specific axis if the coordinate for this axis is always `0`. + /// Similarly to `gatherValues`, the coordinates are read from the `.coord` axis, or last axis if `.coord` is not found. + /// The coordinates represent the "top-left" corner of the slice to update. + /// `indices.dim(.coord)` must match `coord_axes.len`. + /// Other axes identify one "slice" and they must be found inside `updates`. + /// + /// * the output tensor has the same shape as `self`. + /// * if the input tensor has tagged axes, matching `indices` axes, + /// they will be considered "batching" axes. 
+ /// + /// Sample input/output shapes: + /// * scatterSlices([A, B, C, D], .{b, c}, [N, 2], [N, B', C']) -> [A, B, C, D] + /// * scatterSlices(x(a,b,c,d), g(n,m), y[n,b,c]) [A,B,C,D] { + /// var z = x; + /// for (0..N) |n| { z[a,g[n,0]+b',g[n,1]+c',d] = y[n,a,b',c',d]; } + /// } + /// + /// **Warning**: if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound. + /// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`, + /// then the result will depend on the execution order that you don't control. + pub fn scatterSlices(self: Tensor, coord_axes: anytype, indices: Tensor, updates: Tensor, opts: ScatterOpts) Tensor { + const loc = @src(); + // scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates }); + + meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() }); + + const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); + const AxisKind = enum { batching, update_window, inserted_window, window_id }; + var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; + var indices_batch_axes: Shape.DimsArray = .{}; + for (self._shape.tags()) |t| { + if (updates._shape.hasTag(t)) |_| { + if (indices._shape.hasTag(t)) |id_ax| { + // tag is in self, indices and updates -> it's a batching dim + self_kind.appendAssumeCapacity(.batching); + indices_batch_axes.appendAssumeCapacity(id_ax); + } else { + self_kind.appendAssumeCapacity(.update_window); + } + } else { + self_kind.appendAssumeCapacity(.inserted_window); + } + } + // scoped_log.warn(" self_kind -> {any}", .{self_kind.constSlice()}); + + const index_coord_axis = if (single_coord) + indices.rank() + else blk: { + const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); + meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) 
expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices }); + + break :blk ax; + }; + if (indices.count() == 1) { + return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{})); + } + + var up_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; + // Note: we assume the scatter_dims appear in the same order inside indices and inside self. + for (updates._shape.tags(), 0..) |t, up_ax| { + if (self._shape.hasTag(t)) |self_ax| { + if (self_kind.get(self_ax) == .batching) { + up_kind.appendAssumeCapacity(.batching); + } else { + meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates }); + up_kind.appendAssumeCapacity(.update_window); + } + } else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) { + up_kind.appendAssumeCapacity(.window_id); + } else { + std.debug.panic("scatterSlices expects 'updates' to be made of axes from 'self={}' and from 'indices={}', got unknown tag {s} in {}", .{ self, indices, t, updates }); + } + } + const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len; + if (single_coord) { + meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); + } else { + meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); + } + + const ctx = self.getContext(); + const mlir_ctx = ctx.mlirCtx(); + + const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined }; + const UpdateS = ops.BlockSign(ScatterOpts.increment); + const update_block = ctx.makeBlock(UpdateS, opts.update_fn, 
opts.update_fn_ctx, .{ _scalar, _scalar }); + + const op = dialect.stablehlo.scatter( + mlir_ctx, + &.{self.value()}, + &.{indices.value()}, + &.{updates.value()}, + update_block, + .{ + .update_window_dims = _collectAxes(AxisKind, up_kind, .update_window).constSlice(), + .inserted_window_dims = _collectAxes(AxisKind, self_kind, .inserted_window).constSlice(), + .input_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(), + .scatter_indices_batching_dims = indices_batch_axes.constSlice(), + .scatter_dims_to_operand_dims = toI64(coord_axes_.constSlice()), + .index_vector_dim = index_coord_axis, + .indices_are_sorted = opts.indices_are_sorted, + .unique_indices = opts.indices_are_unique, + }, + mlir_ctx.location(loc), + ); + return _result(self._shape, op.result(0)); + } + + test scatterSlices { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const Local = struct { + pub fn scatter(self: Tensor, coord_axes: Shape.AxesArray, indices: Tensor, updates: Tensor) Tensor { + return self.scatterSlices( + coord_axes.constSlice(), + indices, + updates, + .{ .update_fn = ScatterOpts.increment }, + ); + } + }; + + { + // Only test shapes + var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + inline for (.{ + .{ .{ .a = 10 }, .a, .{}, .{ .a = 3 } }, + .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8, .b = 2 } }, + // I'm not sure I like this variant, cause `b` is not mentioned in updates. + // So 'stablehlo.scatter' is implicitly broadcasting the updates along `b` axis. + // OTOH asking the user to do the broadcasting isn't trivial cause they will need to do shape wrangling and that's annoying. 
+ .{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .a = 2 } }, + .{ .{ .a = 10, .b = 20 }, .{ .b, .a }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 3, .b = 2 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .a = 3, .n = 8, .b = 2 } }, + }) |testcase| { + const x_shape, const axes_, const idx_shape, const updates_shapes = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const idx = Tensor.constant(idx_shape, .{ .i32 = 0 }); + const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 }); + + const y = scatterSlices(x, axes_, idx, updates, .{}); + // Shape doesn't change with scatterSlices + try zml.testing.expectEqualShapes(x.shape(), y.shape()); + try std.testing.expect(y.value().owner().verify()); + } + } + // Test with actual values, no batching. + { + const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32); + const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b }); + defer a.deinit(); + a_host.deinit(std.testing.allocator); + + const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{0}, .{2} }); + const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } }); + + const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } }; + const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ + a, + a.shape().axes(.{.a}), + scatter_indices.withTags(.{ .n, .coord }), + updates.withTags(.{ .n, .b }), + }); + try std.testing.expect(a.shape().eql(result.shape())); + try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected))); + } + { + // Test with actual values and batching along axis .a + const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0); + defer operand.deinit(); + const start_indices = (try zml.Buffer.fromArray( + platform, + [2][2][3][2]i32{ + .{ + .{ .{ 0, 0 }, .{ 1, 0 }, .{ 2, 1 } }, + .{ .{ 0, 1 }, .{ 1, 1 }, .{ 0, 9 
} }, + }, + .{ + .{ .{ 0, 0 }, .{ 2, 1 }, .{ 2, 2 } }, + .{ .{ 1, 2 }, .{ 0, 1 }, .{ 1, 0 } }, + }, + }, + )).withTags(.{ .n, .a, .m, .coord }); + defer start_indices.deinit(); + + const values = try zml.Buffer.constant( + platform, + Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16), + 1, + ); + defer values.deinit(); + + const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values }); + + const expected = [2][3][4][2]u16{ + .{ + .{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } }, + .{ .{ 0, 0 }, .{ 0, 0 }, .{ 2, 2 }, .{ 2, 2 } }, + .{ .{ 0, 0 }, .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 } }, + }, + .{ + .{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } }, + .{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } }, + .{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } }, + }, + }; + try std.testing.expect(operand.shape().eql(result.shape())); + try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected))); + } + } + /// Returns a Tensor containing the maximum over a given axis. pub fn max(self: Tensor, axis_: anytype) Tensor { const a = self.axis(axis_); @@ -3546,6 +3790,17 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va /// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. pub fn Bufferized(comptime T: type) type { - const M = meta.MapType(Tensor, Buffer); - return M.map(T); + return meta.MapType(Tensor, Buffer).map(T); +} + +fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } { + const AxesT = @TypeOf(axes_); + const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int; + + const coord_axes = if (axes_is_scalar) + std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable + else + self.axes(axes_); + + return .{ axes_is_scalar, coord_axes }; }