Implement scatterSlices functionality.

This commit is contained in:
Tarry Singh 2023-02-14 13:52:49 +00:00
parent 934acb35a8
commit 24a7c98476
10 changed files with 380 additions and 138 deletions

View File

@ -273,21 +273,68 @@ pub fn gather(
); );
} }
/// Arguments for building a `stablehlo.scatter` op: the dimension numbers
/// plus the two boolean promises the caller can make to the backend.
/// The list fields mirror StableHLO's `scatter_dimension_numbers` attribute;
/// see the StableHLO scatter specification for exact semantics.
pub const ScatterArgs = struct {
/// Axes of `updates` that form the update window.
update_window_dims: []const i64,
/// Operand axes with no corresponding `updates` axis (implicit size-1 window).
inserted_window_dims: []const i64,
/// Maps each entry of an index vector to an operand axis.
scatter_dims_to_operand_dims: []const i64,
/// Axis of the indices tensor that stores the coordinate vectors.
index_vector_dim: i64,
/// Promise that indices are sorted; result is undefined if violated.
indices_are_sorted: bool = false,
/// Promise that scattered indices don't overlap; result is undefined if violated.
unique_indices: bool = false,
};
fn elementTypeOrSelf(typ: mlir.Type) mlir.Type { fn elementTypeOrSelf(typ: mlir.Type) mlir.Type {
return if (typ.as(mlir.ShapedType)) |shaped| { return if (typ.as(mlir.ShapedType)) |shaped| {
return shaped.elementType(); return shaped.elementType();
} else typ; } else typ;
} }
/// Arguments for building a `stablehlo.scatter` op: the dimension numbers
/// (including batching dims) plus the two boolean promises the caller can make.
/// The list fields mirror StableHLO's `scatter_dimension_numbers` attribute;
/// see the StableHLO scatter specification for exact semantics.
pub const ScatterArgs = struct {
/// Axes of `updates` that form the update window.
update_window_dims: []const i64,
/// Operand axes with no corresponding `updates` axis (implicit size-1 window).
inserted_window_dims: []const i64,
/// Operand axes batched one-to-one with `scatter_indices_batching_dims`.
input_batching_dims: []const i64,
/// Indices axes batched one-to-one with `input_batching_dims`.
scatter_indices_batching_dims: []const i64,
/// Maps each entry of an index vector to an operand axis.
scatter_dims_to_operand_dims: []const i64,
/// Axis of the indices tensor that stores the coordinate vectors.
index_vector_dim: i64,
/// Promise that indices are sorted; result is undefined if violated.
indices_are_sorted: bool = false,
/// Promise that scattered indices don't overlap; result is undefined if violated.
unique_indices: bool = false,
/// Builds the MLIR `scatter_dimension_numbers` attribute from these args
/// through the StableHLO C API.
/// NOTE(review): assumes `stablehloScatterDimensionNumbersGet` copies the
/// slice contents into the MLIR context — confirm lifetime requirements.
pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute {
return mlir.Attribute.wrap(
c.stablehloScatterDimensionNumbersGet(
ctx.inner(),
@intCast(self.update_window_dims.len),
self.update_window_dims.ptr,
@intCast(self.inserted_window_dims.len),
self.inserted_window_dims.ptr,
@intCast(self.input_batching_dims.len),
self.input_batching_dims.ptr,
@intCast(self.scatter_indices_batching_dims.len),
self.scatter_indices_batching_dims.ptr,
@intCast(self.scatter_dims_to_operand_dims.len),
self.scatter_dims_to_operand_dims.ptr,
self.index_vector_dim,
),
);
}
};
/// Builds a `stablehlo.scatter` operation.
///
/// * `inputs`, `scatter_indices` and `updates` are the three variadic operand
///   groups of the op (one entry per scattered tensor).
/// * `update_block` is the reduction region combining an existing element with
///   an update element; it becomes the op's single region.
/// * `args` supplies the dimension numbers and the sorted/unique promises,
///   which are attached as op attributes.
///
/// Result types are inferred by MLIR (`result_type_inference = true`).
pub fn scatter(
ctx: mlir.Context,
inputs: []const mlir.Value,
scatter_indices: []const mlir.Value,
updates: []const mlir.Value,
update_block: mlir.Block,
args: ScatterArgs,
location: mlir.Location,
) mlir.Operation {
return mlir.Operation.make(
ctx,
"stablehlo.scatter",
.{
.variadic_operands = &.{ inputs, scatter_indices, updates },
.blocks = &.{update_block},
.attributes = &.{
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? },
},
.result_type_inference = true,
.location = location,
},
);
}
pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.iota", .{ return mlir.Operation.make(ctx, "stablehlo.iota", .{
.operands = &.{}, .operands = &.{},
@ -915,82 +962,6 @@ pub const OutputOperandAliasAttribute = struct {
} }
}; };
/// Typed wrapper around the StableHLO C API `ScatterDimensionNumbers` MLIR
/// attribute, which configures how `stablehlo.scatter` maps indices and
/// updates onto the operand axes.
pub const ScatterDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(ScatterDimensionNumbersAttribute, .{
.is_a_fn = c.stablehloAttributeIsAScatterDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = ScatterDimensionNumbersAttribute;
/// Builds the attribute from the individual dimension lists.
/// See the StableHLO scatter specification for the meaning of each list.
/// NOTE(review): assumes the C API copies the slice contents — confirm
/// lifetime requirements.
pub fn init(
ctx: mlir.Context,
update_window_dims: []const i64,
inserted_window_dims: []const i64,
input_batching_dims: []const i64,
scatter_indices_batching_dims: []const i64,
scatter_dims_to_operand_dims: []const i64,
index_vector_dim: i64,
) Self {
return Self.wrap(
c.stablehloScatterDimensionNumbersGet(
ctx.inner(),
@intCast(update_window_dims.len),
update_window_dims.ptr,
@intCast(inserted_window_dims.len),
inserted_window_dims.ptr,
@intCast(input_batching_dims.len),
input_batching_dims.ptr,
@intCast(scatter_indices_batching_dims.len),
scatter_indices_batching_dims.ptr,
@intCast(scatter_dims_to_operand_dims.len),
scatter_dims_to_operand_dims.ptr,
index_vector_dim,
),
);
}
// Accessors below mirror the C API: each list is exposed as a
// (size, element-at-position) pair of functions.
pub fn getUpdateWindowDimsSize(self: Self) usize {
return @intCast(c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(self.inner()));
}
pub fn getUpdateWindowDimsElem(self: Self, pos: usize) i64 {
return c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(self.inner(), @intCast(pos));
}
pub fn getInsertedWindowDimsSize(self: Self) usize {
return @intCast(c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(self.inner()));
}
pub fn getInsertedWindowDimsElem(self: Self, pos: usize) i64 {
return c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(self.inner(), @intCast(pos));
}
pub fn getInputBatchingDimsSize(self: Self) usize {
return @intCast(c.stablehloScatterDimensionNumbersGetInputBatchingDimsSize(self.inner()));
}
pub fn getInputBatchingDimsElem(self: Self, pos: usize) i64 {
return c.stablehloScatterDimensionNumbersGetInputBatchingDimsElem(self.inner(), @intCast(pos));
}
pub fn getScatterIndicesBatchingDimsSize(self: Self) usize {
return @intCast(c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsSize(self.inner()));
}
pub fn getScatterIndicesBatchingDimsElem(self: Self, pos: usize) i64 {
return c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsElem(self.inner(), @intCast(pos));
}
pub fn getIndexVectorDim(self: Self) i64 {
// There really is "Scatter" missing in the function name
return c.stablehloDimensionNumbersGetIndexVectorDim(self.inner());
}
};
pub const PrecisionAttribute = struct { pub const PrecisionAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,

View File

@ -416,6 +416,10 @@ pub const BoolAttribute = struct {
pub fn value(self: Self) bool { pub fn value(self: Self) bool {
return c.mlirBoolAttrGetValue(self.inner()); return c.mlirBoolAttrGetValue(self.inner());
} }
/// Upcasts this boolean attribute to the generic `mlir.Attribute`.
/// The unwrap (`.?`) cannot fail here since every BoolAttribute is an Attribute.
pub fn asAttr(self: Self) Attribute {
return self.as(Attribute).?;
}
}; };
pub const TypeAttribute = struct { pub const TypeAttribute = struct {

View File

@ -49,18 +49,24 @@ pub const Context = struct {
Context.mlir_once.call(); Context.mlir_once.call();
var platforms = PlatformsMap.initFill(null); var platforms = PlatformsMap.initFill(null);
var num_platforms: u8 = 0;
var it = Context.apis.iterator(); var it = Context.apis.iterator();
while (it.next()) |entry| { while (it.next()) |entry| {
if (entry.value.*) |api| { if (entry.value.*) |api| {
const target = entry.key; const target = entry.key;
const p = Platform.init(target, api) catch continue; const p = Platform.init(target, api) catch |err| {
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
continue;
};
if (p.getDevices().len == 0) { if (p.getDevices().len == 0) {
log.err("No device found for platform {} !", .{target}); log.err("No device found for platform {} !", .{target});
continue; continue;
} }
platforms.set(target, p); platforms.set(target, p);
num_platforms += 1;
} }
} }
if (num_platforms == 0) return error.NotFound;
return .{ return .{
.platforms = platforms, .platforms = platforms,
}; };
@ -121,13 +127,13 @@ pub const Context = struct {
pub fn autoPlatform(self: *Context) Platform { pub fn autoPlatform(self: *Context) Platform {
// the last platform is the one that with the high enum number, so considered // the last platform is the one that with the high enum number, so considered
// to be the "best" one // to be the "best" one
var platform_: Platform = undefined; var platform_: ?Platform = null;
var iterator = self.platforms.iterator(); var iterator = self.platforms.iterator();
while (iterator.next()) |entry| { while (iterator.next()) |entry| {
if (entry.value.*) |p| { if (entry.value.*) |p| {
platform_ = p; platform_ = p;
} }
} }
return platform_; return platform_ orelse @panic("No platform found !");
} }
}; };

View File

@ -234,11 +234,11 @@ pub const Data = union(DataType) {
/// If the `dtype` and `@TypeOf(value)` are incompatible /// If the `dtype` and `@TypeOf(value)` are incompatible
/// or a cast from `value` to `FieldType(dtype)` would /// or a cast from `value` to `FieldType(dtype)` would
/// be lossy, a panic occurs. /// be lossy, a panic occurs.
pub fn init(dtype: DataType, value: anytype) Data { pub fn init(dtype_: DataType, value: anytype) Data {
const T = @TypeOf(value); const T = @TypeOf(value);
const Ti = @typeInfo(T); const Ti = @typeInfo(T);
return switch (dtype) { return switch (dtype_) {
.bool => switch (Ti) { .bool => switch (Ti) {
.Bool => .{ .bool = value }, .Bool => .{ .bool = value },
.ComptimeInt, .Int, .ComptimeFloat, .Float => .{ .bool = value != 0 }, .ComptimeInt, .Int, .ComptimeFloat, .Float => .{ .bool = value != 0 },
@ -302,7 +302,7 @@ pub const Data = union(DataType) {
try std.testing.expectEqual(C128.init(1, 2), Data.init(.c128, C64.init(1, 2)).c128); try std.testing.expectEqual(C128.init(1, 2), Data.init(.c128, C64.init(1, 2)).c128);
} }
pub fn dataType(self: Data) DataType { pub fn dtype(self: Data) DataType {
return std.meta.activeTag(self); return std.meta.activeTag(self);
} }
@ -327,7 +327,7 @@ pub const Data = union(DataType) {
}, },
else => {}, else => {},
} }
std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dataType(), @typeName(T) }); std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dtype(), @typeName(T) });
} }
}; };

View File

@ -269,7 +269,10 @@ pub fn MapType(From: type, To: type) type {
[]const map(ptr_info.child) []const map(ptr_info.child)
else else
[]map(ptr_info.child), []map(ptr_info.child),
.One => *map(ptr_info.child), .One => if (ptr_info.is_const)
*const map(ptr_info.child)
else
*map(ptr_info.child),
else => T, else => T,
}, },
.Optional => |opt_info| ?map(opt_info.child), .Optional => |opt_info| ?map(opt_info.child),
@ -446,8 +449,9 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
const Callback = @TypeOf(cb); const Callback = @TypeOf(cb);
@compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T)); @compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T));
} }
const ptr_info = type_info_v.Pointer; const ptr_info = type_info_v.Pointer;
if (@typeInfo(ptr_info.child) == .Fn) return;
if (ptr_info.child == anyopaque) return;
// This is important, because with trivial types like void, // This is important, because with trivial types like void,
// Zig sometimes decide to call `visit` at comptime, but can't do // Zig sometimes decide to call `visit` at comptime, but can't do
// the pointer wrangling logic at comptime. // the pointer wrangling logic at comptime.

View File

@ -155,7 +155,7 @@ pub const ext = struct {
pub const DenseIntOrFPElementsAttribute = struct { pub const DenseIntOrFPElementsAttribute = struct {
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute { pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
return switch (data.dataType()) { return switch (data.dtype()) {
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?, .bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?,

View File

@ -140,8 +140,8 @@ pub const CompilationContext = struct {
/// `blkctx` represents values from outside the block that can be accessed inside the block. /// `blkctx` represents values from outside the block that can be accessed inside the block.
pub fn makeBlock( pub fn makeBlock(
self: *CompilationContext, self: *CompilationContext,
comptime func: anytype,
comptime S: ops.BlockSignature, comptime S: ops.BlockSignature,
func: *const S.Fn,
blkctx: S.BlkCtx, blkctx: S.BlkCtx,
args: S.Args, args: S.Args,
) mlir.Block { ) mlir.Block {
@ -996,7 +996,8 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
else => {}, else => {},
} }
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, try options.encode(arena)); const options_bytes = try options.encode(arena);
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
errdefer unreachable; // errdefer loaded_executable.deinit(); errdefer unreachable; // errdefer loaded_executable.deinit();
if (platform.compilation_options.cache_location) |compilation_cache_location| { if (platform.compilation_options.cache_location) |compilation_cache_location| {

View File

@ -44,8 +44,8 @@ pub fn while_(
@compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn))); @compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const cond_block = ctx.makeBlock(cond_fn, CondS, blkctx, inputs); const cond_block = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
const body_block = ctx.makeBlock(body_fn, BodyS, blkctx, inputs); const body_block = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
var input_values: [BodyS.nIn]mlir.Value = undefined; var input_values: [BodyS.nIn]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -136,7 +136,7 @@ pub fn reduce(
var init_values: [N]mlir.Value = undefined; var init_values: [N]mlir.Value = undefined;
ctx.extractValues(&inits, &init_values); ctx.extractValues(&inits, &init_values);
const body_block = ctx.makeBlock(body_fn, BodyS, {}, .{ inits, inits }); const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
@ -226,7 +226,7 @@ pub fn reduceWindow(
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const body_block = ctx.makeBlock(body_fn, BodyS, {}, .{ inits, inits }); const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
const N = comptime @divExact(BodyS.nIn, 2); const N = comptime @divExact(BodyS.nIn, 2);
var input_values: [N]mlir.Value = undefined; var input_values: [N]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -398,8 +398,8 @@ pub fn if_(
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const true_branch_block = ctx.makeBlock(true_branch_fn, TrueBlockSignature, blkctx, {}); const true_branch_block = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
const false_branch_block = ctx.makeBlock(false_branch_fn, TrueBlockSignature, blkctx, {}); const false_branch_block = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
@ -461,7 +461,7 @@ pub fn sort(
inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer }; inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const block = ctx.makeBlock(comp_fn, BodyS, blkctx, inits); const block = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);
var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined; var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -486,6 +486,7 @@ pub fn sort(
} }
pub const BlockSignature = struct { pub const BlockSignature = struct {
Fn: type,
BlkCtx: type, BlkCtx: type,
Args: type, Args: type,
FullArgs: type, FullArgs: type,
@ -560,7 +561,8 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
.no_args => void, .no_args => void,
}; };
const xx = .{ return .{
.Fn = @TypeOf(func),
.BlkCtx = BlkCtx, .BlkCtx = BlkCtx,
.Args = Args, .Args = Args,
.FullArgs = FullArgs, .FullArgs = FullArgs,
@ -568,7 +570,6 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
.nIn = n_tensors, .nIn = n_tensors,
.nOut = staticCountTensors(fn_info.return_type.?) orelse @compileError("Can't use " ++ @typeName(fn_info.return_type.?) ++ " in an MLIR function, because it has a variable number of tensors"), .nOut = staticCountTensors(fn_info.return_type.?) orelse @compileError("Can't use " ++ @typeName(fn_info.return_type.?) ++ " in an MLIR function, because it has a variable number of tensors"),
}; };
return xx;
} }
pub fn staticIsOnlyTensors(comptime T: type) bool { pub fn staticIsOnlyTensors(comptime T: type) bool {

View File

@ -290,14 +290,14 @@ pub const Shape = struct {
} }
fn axisFromInt(self: Shape, d: isize) u3 { fn axisFromInt(self: Shape, d: isize) u3 {
const rank_: i8 = self.rank(); const rk: i8 = self.rank();
if (d < 0) { if (d < -rk or d > rk) {
return @intCast(d + rank_); meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
} }
if (d > rank_) { return if (d < 0)
meta.panic("Tensor doesn't have dimension: {d}", .{d}); @intCast(d + rk)
} else
return @intCast(d); @intCast(d);
} }
fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 { fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 {

View File

@ -1682,12 +1682,12 @@ pub const Tensor = struct {
/// Returns a constant Tensor with the given value. /// Returns a constant Tensor with the given value.
pub fn constant(dimz: anytype, val: Data) Tensor { pub fn constant(dimz: anytype, val: Data) Tensor {
const sh = Shape.init(dimz, val.dataType()); const sh = Shape.init(dimz, val.dtype());
const singleton_sh = Shape.init(.{}, val.dataType()); const singleton_sh = Shape.init(.{}, val.dtype());
const ctx = CompilationContext.current().mlirCtx(); const ctx = CompilationContext.current().mlirCtx();
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val }); const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh); const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
const elem_type = mlir.ext.denseElementAttrType(val.dataType()); const elem_type = mlir.ext.denseElementAttrType(val.dtype());
var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc); var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc);
if (sh.rank() > 0) { if (sh.rank() > 0) {
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc); constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
@ -1925,20 +1925,14 @@ pub const Tensor = struct {
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d] /// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
/// ///
/// It is possible to use gatherValues without tags, but batching won't be available. /// It is possible to use gatherValues without tags, but batching won't be available.
pub fn gatherValues(self: Tensor, axes_: anytype, indices: Tensor, opts: GatherOpts) Tensor { pub fn gatherValues(self: Tensor, coord_axes: anytype, indices: Tensor, opts: GatherOpts) Tensor {
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, axes_, indices }); // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
const AxesT = @TypeOf(axes_); const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
const val_axes = if (axes_is_scalar) meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable for (coord_axes_.constSlice(), 0..) |a, i| {
else
self.axes(axes_);
meta.assert(val_axes.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
for (val_axes.constSlice(), 0..) |a, i| {
if (i > 0) { if (i > 0) {
meta.assert(a == val_axes.get(i - 1) + 1, "gatherValues expects 'axes_' too be sequential. But {any} aren't sequential in {}", .{ axes_, self }); meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
} }
} }
@ -1946,14 +1940,14 @@ pub const Tensor = struct {
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
var indices_batch_axes: Shape.DimsArray = .{}; var indices_batch_axes: Shape.DimsArray = .{};
for (self._shape.tags(), 0..self.rank()) |t, self_ax| { for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
const maybe_val_ax = std.mem.indexOfScalar(u3, val_axes.constSlice(), @intCast(self_ax)); const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
if (indices._shape.hasTag(t)) |id_ax| { if (indices._shape.hasTag(t)) |id_ax| {
// tag is both in self and indices -> it's a batching dim // tag is both in self and indices -> it's a batching dim
// Note: tags are required for batching. // Note: tags are required for batching.
self_kind.appendAssumeCapacity(.batching); self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax); indices_batch_axes.appendAssumeCapacity(id_ax);
meta.assert(maybe_val_ax == null, "gatherValues expects axes to be either batches or slices axes. Axis {s} has been found both in `axes={any}` and `indices={}`", .{ t, axes_, indices }); meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
} else if (maybe_val_ax) |_| { } else if (maybe_coord_ax) |_| {
// for gatherValues we collapsed all gathered axes // for gatherValues we collapsed all gathered axes
// (contrary to gatherSlices where we collapse none) // (contrary to gatherSlices where we collapse none)
self_kind.appendAssumeCapacity(.collapsed); self_kind.appendAssumeCapacity(.collapsed);
@ -1962,14 +1956,14 @@ pub const Tensor = struct {
} }
} }
// When we receive several axes_ we need an extra dimension to store // When we receive several coord_axes we need an extra dimension to store
// one index per axis, which makes the coordinates of one value. // one index per axis, which makes the coordinates of one value.
// Otherwise stablehlo uses the "indices.rank()" default value. // Otherwise stablehlo uses the "indices.rank()" default value.
const index_coord_axis = if (axes_is_scalar) const index_coord_axis = if (single_coord)
indices.rank() indices.rank()
else blk: { else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
meta.assert(indices.dim(ax) == val_axes.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ axes_, val_axes.len, indices }); meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
break :blk ax; break :blk ax;
}; };
@ -1978,7 +1972,7 @@ pub const Tensor = struct {
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
for (self_kind.constSlice(), 0..) |kind, ax_usize| { for (self_kind.constSlice(), 0..) |kind, ax_usize| {
const ax: u3 = @intCast(ax_usize); const ax: u3 = @intCast(ax_usize);
if (ax == val_axes.get(0)) { if (ax == coord_axes_.get(0)) {
// The first val_ax is special cause this is the place where we insert indices axes. // The first val_ax is special cause this is the place where we insert indices axes.
for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| { for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| {
if (id_ax == index_coord_axis) continue; if (id_ax == index_coord_axis) continue;
@ -2004,7 +1998,7 @@ pub const Tensor = struct {
// Sometimes the backend recognize this pattern, but not always. // Sometimes the backend recognize this pattern, but not always.
// So let us handle that. // So let us handle that.
if (indices.count() == 1) { if (indices.count() == 1) {
return self.dynamicSlice1d(val_axes.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape); return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
} }
var slice_dims: Shape.DimsArray = .{}; var slice_dims: Shape.DimsArray = .{};
@ -2247,6 +2241,256 @@ pub const Tensor = struct {
try zml.testing.expectClose(expected, result, 0); try zml.testing.expectClose(expected, result, 0);
} }
/// Options controlling `scatterSlices`.
pub const ScatterOpts = struct {
    /// Caller promise: every coordinate in `indices` is sorted with respect to
    /// its final in-memory offset. The result is undefined if this is violated.
    indices_are_sorted: bool = false,
    /// Caller promise: the scattered slices never overlap.
    /// The result is undefined if this is violated.
    indices_are_unique: bool = false,
    /// Combines the previous value in `self` with the matching value from `updates`.
    /// If this function is not associative (i.e. the execution order matters),
    /// make sure the slices don't overlap; otherwise the result depends on the
    /// backend-specific runtime scheduling of the operator.
    update_fn: *const fn (*const anyopaque, Tensor, Tensor) Tensor = increment,
    /// Opaque context forwarded to a custom `update_fn`.
    /// `increment` and `override` ignore it, so leaving it undefined is fine.
    update_fn_ctx: *const anyopaque = undefined,

    /// Adds the update onto the previous value.
    pub fn increment(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor {
        return old_value.add(new_value);
    }

    /// Discards the previous value and keeps the update.
    pub fn override(_: *const anyopaque, _: Tensor, new_value: Tensor) Tensor {
        return new_value;
    }
};
/// Update the given tensors, by copying `values` into self slices.
/// The slices are chosen at runtime by interpreting indices as coordinates into `self`.
/// * `indices` represents a set of coordinates into `self`.
/// For the sake of simplifying the creation of `indices` tensor,
/// it's allowed to not mention a specific axis if the coordinate for this axis is always `0`.
/// Similarly to `gatherValues`, the coordinates are read from the `.coord` axis, or last axis if `.coord` is not found.
/// The coordinates represent the "top-left" corner of the slice to extract.
/// `indices.dim(.coord)` must match `coord_axes.len`.
/// Other axes identify one "slice" and they must be found inside `updates`.
///
/// * the output tensor starts with axes from `indices`.
/// * if the input tensor has tagged axes, matching `indices` axes,
/// they will be considered "batching" axes.
///
/// Sample input/output shapes:
/// * scatterSlices([A, B, C, D], .{b, c}, [N, 2], [N, B', C']) -> [A, B, C, D]
/// * scatterSlices(x(a,b,c,d), g(n,m), y[n,b,c]) [A,B,C,D] {
/// var z = x;
/// for (0..N) |n| { z[a,g[n,0]+b',g[n,1]+c',d] = y[n,a,b',c',d]; }
/// }
///
/// **Warning**: if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound.
/// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`,
/// then the result will depend on the execution order that you don't control.
pub fn scatterSlices(self: Tensor, coord_axes: anytype, indices: Tensor, updates: Tensor, opts: ScatterOpts) Tensor {
const loc = @src();
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates });
meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() });
// `single_coord` is true when `coord_axes` was one scalar axis (no coordinate vector axis needed).
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
const AxisKind = enum { batching, update_window, inserted_window, window_id };
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
var indices_batch_axes: Shape.DimsArray = .{};
// Classify every axis of `self`:
// * tag in both `updates` and `indices` -> batching axis,
// * tag in `updates` only -> update-window axis,
// * tag absent from `updates` -> inserted (implicit size-1) window axis.
for (self._shape.tags()) |t| {
if (updates._shape.hasTag(t)) |_| {
if (indices._shape.hasTag(t)) |id_ax| {
// tag is in self, indices and updates -> it's a batching dim
self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax);
} else {
self_kind.appendAssumeCapacity(.update_window);
}
} else {
self_kind.appendAssumeCapacity(.inserted_window);
}
}
// scoped_log.warn(" self_kind -> {any}", .{self_kind.constSlice()});
// Axis of `indices` that stores the coordinate vector. With a single scalar
// coordinate we use `indices.rank()`, i.e. StableHLO's implicit trailing dim;
// otherwise the `.coord` axis (or the last axis) must match `coord_axes_.len`.
const index_coord_axis = if (single_coord)
indices.rank()
else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
break :blk ax;
};
// Fast path: a single index degenerates to a dynamic-update-slice.
// NOTE(review): this path always overwrites the slice and never calls
// `opts.update_fn` — confirm this is intended when `update_fn` isn't `override`.
if (indices.count() == 1) {
return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{}));
}
var up_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
// Note: we assume the scatter_dims appear in the same order inside indices and inside self.
for (updates._shape.tags(), 0..) |t, up_ax| {
if (self._shape.hasTag(t)) |self_ax| {
if (self_kind.get(self_ax) == .batching) {
up_kind.appendAssumeCapacity(.batching);
} else {
meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
up_kind.appendAssumeCapacity(.update_window);
}
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
up_kind.appendAssumeCapacity(.window_id);
} else {
std.debug.panic("scatterSlices expects 'updates' to be made of axes from 'self={}' and from 'indices={}', got unknown tag {s} in {}", .{ self, indices, t, updates });
}
}
// Every non-window axis of `updates` must come from `indices` (one axis per
// scatter site), minus the coordinate-vector axis when there is one.
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
if (single_coord) {
meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
} else {
meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
}
const ctx = self.getContext();
const mlir_ctx = ctx.mlirCtx();
// Build the MLIR region that combines one existing scalar with one update
// scalar, using the caller-provided update function and context.
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
const UpdateS = ops.BlockSign(ScatterOpts.increment);
const update_block = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
const op = dialect.stablehlo.scatter(
mlir_ctx,
&.{self.value()},
&.{indices.value()},
&.{updates.value()},
update_block,
.{
.update_window_dims = _collectAxes(AxisKind, up_kind, .update_window).constSlice(),
.inserted_window_dims = _collectAxes(AxisKind, self_kind, .inserted_window).constSlice(),
.input_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
.scatter_indices_batching_dims = indices_batch_axes.constSlice(),
.scatter_dims_to_operand_dims = toI64(coord_axes_.constSlice()),
.index_vector_dim = index_coord_axis,
.indices_are_sorted = opts.indices_are_sorted,
.unique_indices = opts.indices_are_unique,
},
mlir_ctx.location(loc),
);
// Scatter preserves the operand shape: the result has `self`'s shape.
return _result(self._shape, op.result(0));
}
test scatterSlices {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    // Wrapper with a fixed runtime signature so it can be compiled by
    // `zml.testing.compileAndCall`. It always uses `ScatterOpts.increment`,
    // i.e. scattered updates are *added* onto the existing values of `self`.
    const Local = struct {
        pub fn scatter(self: Tensor, coord_axes: Shape.AxesArray, indices: Tensor, updates: Tensor) Tensor {
            return self.scatterSlices(
                coord_axes.constSlice(),
                indices,
                updates,
                .{ .update_fn = ScatterOpts.increment },
            );
        }
    };
    {
        // Only test shapes
        var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
        defer comp.deinit();
        comp.activate();
        defer comp.deactivate();
        // Each testcase tuple is: .{ self_shape, coord_axes, indices_shape, updates_shape }.
        // scatterSlices must accept these shape combinations and return a tensor
        // with the same shape as `self`.
        inline for (.{
            .{ .{ .a = 10 }, .a, .{}, .{ .a = 3 } },
            .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8, .b = 2 } },
            // I'm not sure I like this variant, cause `b` is not mentioned in updates.
            // So 'stablehlo.scatter' is implicitly broadcasting the updates along `b` axis.
            // OTOH asking the user to do the broadcasting isn't trivial cause they will need to do shape wrangling and that's annoying.
            .{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .a = 2 } },
            .{ .{ .a = 10, .b = 20 }, .{ .b, .a }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 3, .b = 2 } },
            .{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .a = 3, .n = 8, .b = 2 } },
        }) |testcase| {
            const x_shape, const axes_, const idx_shape, const updates_shapes = testcase;
            const x = Tensor.constant(x_shape, .{ .f16 = 0 });
            const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
            const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 });
            const y = scatterSlices(x, axes_, idx, updates, .{});
            // Shape doesn't change with scatterSlices
            try zml.testing.expectEqualShapes(x.shape(), y.shape());
            try std.testing.expect(y.value().owner().verify());
        }
    }
    // Test with actual values, no batching.
    {
        // `a` is a 3x3 iota matrix; rows 0 and 2 (per scatter_indices) receive
        // the two update rows, added element-wise (increment semantics).
        const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
        const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b });
        defer a.deinit();
        a_host.deinit(std.testing.allocator);
        const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{0}, .{2} });
        const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } });
        // expected = iota(3x3) with row 0 += {10,20,30} and row 2 += {70,80,90}.
        const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } };
        const result = try zml.testing.compileAndCall(platform, Local.scatter, .{
            a,
            a.shape().axes(.{.a}),
            scatter_indices.withTags(.{ .n, .coord }),
            updates.withTags(.{ .n, .b }),
        });
        try std.testing.expect(a.shape().eql(result.shape()));
        try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
    }
    {
        // Test with actual values and batching along axis .a
        // Operand starts at all-zeros, updates are all-ones: the expected output
        // therefore counts how many (possibly overlapping) slices touch each cell.
        // Out-of-bounds coordinates (e.g. the {0, 9} entry) are dropped by scatter.
        const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0);
        defer operand.deinit();
        // Indices carry a batching axis .a matching the operand's .a axis, and a
        // trailing .coord axis holding (c, b) coordinate pairs.
        const start_indices = (try zml.Buffer.fromArray(
            platform,
            [2][2][3][2]i32{
                .{
                    .{ .{ 0, 0 }, .{ 1, 0 }, .{ 2, 1 } },
                    .{ .{ 0, 1 }, .{ 1, 1 }, .{ 0, 9 } },
                },
                .{
                    .{ .{ 0, 0 }, .{ 2, 1 }, .{ 2, 2 } },
                    .{ .{ 1, 2 }, .{ 0, 1 }, .{ 1, 0 } },
                },
            },
        )).withTags(.{ .n, .a, .m, .coord });
        defer start_indices.deinit();
        const values = try zml.Buffer.constant(
            platform,
            Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16),
            1,
        );
        defer values.deinit();
        const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values });
        const expected = [2][3][4][2]u16{
            .{
                .{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } },
                .{ .{ 0, 0 }, .{ 0, 0 }, .{ 2, 2 }, .{ 2, 2 } },
                .{ .{ 0, 0 }, .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 } },
            },
            .{
                .{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } },
                .{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } },
                .{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } },
            },
        };
        try std.testing.expect(operand.shape().eql(result.shape()));
        try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
    }
}
/// Returns a Tensor containing the maximum over a given axis. /// Returns a Tensor containing the maximum over a given axis.
pub fn max(self: Tensor, axis_: anytype) Tensor { pub fn max(self: Tensor, axis_: anytype) Tensor {
const a = self.axis(axis_); const a = self.axis(axis_);
@ -3546,6 +3790,17 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. /// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
pub fn Bufferized(comptime T: type) type { pub fn Bufferized(comptime T: type) type {
const M = meta.MapType(Tensor, Buffer); return meta.MapType(Tensor, Buffer).map(T);
return M.map(T); }
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
const AxesT = @TypeOf(axes_);
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
const coord_axes = if (axes_is_scalar)
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
else
self.axes(axes_);
return .{ axes_is_scalar, coord_axes };
} }