Implement scatterSlices functionality.
This commit is contained in:
parent
934acb35a8
commit
24a7c98476
@ -273,21 +273,68 @@ pub fn gather(
|
||||
);
|
||||
}
|
||||
|
||||
pub const ScatterArgs = struct {
|
||||
update_window_dims: []const i64,
|
||||
inserted_window_dims: []const i64,
|
||||
scatter_dims_to_operand_dims: []const i64,
|
||||
index_vector_dim: i64,
|
||||
indices_are_sorted: bool = false,
|
||||
unique_indices: bool = false,
|
||||
};
|
||||
|
||||
fn elementTypeOrSelf(typ: mlir.Type) mlir.Type {
|
||||
return if (typ.as(mlir.ShapedType)) |shaped| {
|
||||
return shaped.elementType();
|
||||
} else typ;
|
||||
}
|
||||
|
||||
pub const ScatterArgs = struct {
|
||||
update_window_dims: []const i64,
|
||||
inserted_window_dims: []const i64,
|
||||
input_batching_dims: []const i64,
|
||||
scatter_indices_batching_dims: []const i64,
|
||||
scatter_dims_to_operand_dims: []const i64,
|
||||
index_vector_dim: i64,
|
||||
indices_are_sorted: bool = false,
|
||||
unique_indices: bool = false,
|
||||
|
||||
pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute {
|
||||
return mlir.Attribute.wrap(
|
||||
c.stablehloScatterDimensionNumbersGet(
|
||||
ctx.inner(),
|
||||
@intCast(self.update_window_dims.len),
|
||||
self.update_window_dims.ptr,
|
||||
@intCast(self.inserted_window_dims.len),
|
||||
self.inserted_window_dims.ptr,
|
||||
@intCast(self.input_batching_dims.len),
|
||||
self.input_batching_dims.ptr,
|
||||
@intCast(self.scatter_indices_batching_dims.len),
|
||||
self.scatter_indices_batching_dims.ptr,
|
||||
@intCast(self.scatter_dims_to_operand_dims.len),
|
||||
self.scatter_dims_to_operand_dims.ptr,
|
||||
self.index_vector_dim,
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
pub fn scatter(
|
||||
ctx: mlir.Context,
|
||||
inputs: []const mlir.Value,
|
||||
scatter_indices: []const mlir.Value,
|
||||
updates: []const mlir.Value,
|
||||
update_block: mlir.Block,
|
||||
args: ScatterArgs,
|
||||
location: mlir.Location,
|
||||
) mlir.Operation {
|
||||
return mlir.Operation.make(
|
||||
ctx,
|
||||
"stablehlo.scatter",
|
||||
.{
|
||||
.variadic_operands = &.{ inputs, scatter_indices, updates },
|
||||
.blocks = &.{update_block},
|
||||
.attributes = &.{
|
||||
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
|
||||
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
|
||||
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? },
|
||||
},
|
||||
.result_type_inference = true,
|
||||
.location = location,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "stablehlo.iota", .{
|
||||
.operands = &.{},
|
||||
@ -915,82 +962,6 @@ pub const OutputOperandAliasAttribute = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const ScatterDimensionNumbersAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(ScatterDimensionNumbersAttribute, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAScatterDimensionNumbers,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
const Self = ScatterDimensionNumbersAttribute;
|
||||
|
||||
pub fn init(
|
||||
ctx: mlir.Context,
|
||||
update_window_dims: []const i64,
|
||||
inserted_window_dims: []const i64,
|
||||
input_batching_dims: []const i64,
|
||||
scatter_indices_batching_dims: []const i64,
|
||||
scatter_dims_to_operand_dims: []const i64,
|
||||
index_vector_dim: i64,
|
||||
) Self {
|
||||
return Self.wrap(
|
||||
c.stablehloScatterDimensionNumbersGet(
|
||||
ctx.inner(),
|
||||
@intCast(update_window_dims.len),
|
||||
update_window_dims.ptr,
|
||||
@intCast(inserted_window_dims.len),
|
||||
inserted_window_dims.ptr,
|
||||
@intCast(input_batching_dims.len),
|
||||
input_batching_dims.ptr,
|
||||
@intCast(scatter_indices_batching_dims.len),
|
||||
scatter_indices_batching_dims.ptr,
|
||||
@intCast(scatter_dims_to_operand_dims.len),
|
||||
scatter_dims_to_operand_dims.ptr,
|
||||
index_vector_dim,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn getUpdateWindowDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(self.inner()));
|
||||
}
|
||||
|
||||
pub fn getUpdateWindowDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(self.inner(), @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getInsertedWindowDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(self.inner()));
|
||||
}
|
||||
|
||||
pub fn getInsertedWindowDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(self.inner(), @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getInputBatchingDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloScatterDimensionNumbersGetInputBatchingDimsSize(self.inner()));
|
||||
}
|
||||
|
||||
pub fn getInputBatchingDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloScatterDimensionNumbersGetInputBatchingDimsElem(self.inner(), @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getScatterIndicesBatchingDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsSize(self.inner()));
|
||||
}
|
||||
|
||||
pub fn getScatterIndicesBatchingDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsElem(self.inner(), @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getIndexVectorDim(self: Self) i64 {
|
||||
// There really is "Scatter" missing in the function name
|
||||
return c.stablehloDimensionNumbersGetIndexVectorDim(self.inner());
|
||||
}
|
||||
};
|
||||
|
||||
pub const PrecisionAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
|
||||
@ -416,6 +416,10 @@ pub const BoolAttribute = struct {
|
||||
pub fn value(self: Self) bool {
|
||||
return c.mlirBoolAttrGetValue(self.inner());
|
||||
}
|
||||
|
||||
pub fn asAttr(self: Self) Attribute {
|
||||
return self.as(Attribute).?;
|
||||
}
|
||||
};
|
||||
|
||||
pub const TypeAttribute = struct {
|
||||
|
||||
@ -49,18 +49,24 @@ pub const Context = struct {
|
||||
Context.mlir_once.call();
|
||||
|
||||
var platforms = PlatformsMap.initFill(null);
|
||||
var num_platforms: u8 = 0;
|
||||
var it = Context.apis.iterator();
|
||||
while (it.next()) |entry| {
|
||||
if (entry.value.*) |api| {
|
||||
const target = entry.key;
|
||||
const p = Platform.init(target, api) catch continue;
|
||||
const p = Platform.init(target, api) catch |err| {
|
||||
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
||||
continue;
|
||||
};
|
||||
if (p.getDevices().len == 0) {
|
||||
log.err("No device found for platform {} !", .{target});
|
||||
continue;
|
||||
}
|
||||
platforms.set(target, p);
|
||||
num_platforms += 1;
|
||||
}
|
||||
}
|
||||
if (num_platforms == 0) return error.NotFound;
|
||||
return .{
|
||||
.platforms = platforms,
|
||||
};
|
||||
@ -121,13 +127,13 @@ pub const Context = struct {
|
||||
pub fn autoPlatform(self: *Context) Platform {
|
||||
// the last platform is the one that with the high enum number, so considered
|
||||
// to be the "best" one
|
||||
var platform_: Platform = undefined;
|
||||
var platform_: ?Platform = null;
|
||||
var iterator = self.platforms.iterator();
|
||||
while (iterator.next()) |entry| {
|
||||
if (entry.value.*) |p| {
|
||||
platform_ = p;
|
||||
}
|
||||
}
|
||||
return platform_;
|
||||
return platform_ orelse @panic("No platform found !");
|
||||
}
|
||||
};
|
||||
|
||||
@ -234,11 +234,11 @@ pub const Data = union(DataType) {
|
||||
/// If the `dtype` and `@TypeOf(value)` are incompatible
|
||||
/// or a cast from `value` to `FieldType(dtype)` would
|
||||
/// be lossy, a panic occurs.
|
||||
pub fn init(dtype: DataType, value: anytype) Data {
|
||||
pub fn init(dtype_: DataType, value: anytype) Data {
|
||||
const T = @TypeOf(value);
|
||||
const Ti = @typeInfo(T);
|
||||
|
||||
return switch (dtype) {
|
||||
return switch (dtype_) {
|
||||
.bool => switch (Ti) {
|
||||
.Bool => .{ .bool = value },
|
||||
.ComptimeInt, .Int, .ComptimeFloat, .Float => .{ .bool = value != 0 },
|
||||
@ -302,7 +302,7 @@ pub const Data = union(DataType) {
|
||||
try std.testing.expectEqual(C128.init(1, 2), Data.init(.c128, C64.init(1, 2)).c128);
|
||||
}
|
||||
|
||||
pub fn dataType(self: Data) DataType {
|
||||
pub fn dtype(self: Data) DataType {
|
||||
return std.meta.activeTag(self);
|
||||
}
|
||||
|
||||
@ -327,7 +327,7 @@ pub const Data = union(DataType) {
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dataType(), @typeName(T) });
|
||||
std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dtype(), @typeName(T) });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -269,7 +269,10 @@ pub fn MapType(From: type, To: type) type {
|
||||
[]const map(ptr_info.child)
|
||||
else
|
||||
[]map(ptr_info.child),
|
||||
.One => *map(ptr_info.child),
|
||||
.One => if (ptr_info.is_const)
|
||||
*const map(ptr_info.child)
|
||||
else
|
||||
*map(ptr_info.child),
|
||||
else => T,
|
||||
},
|
||||
.Optional => |opt_info| ?map(opt_info.child),
|
||||
@ -446,8 +449,9 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
const Callback = @TypeOf(cb);
|
||||
@compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T));
|
||||
}
|
||||
|
||||
const ptr_info = type_info_v.Pointer;
|
||||
if (@typeInfo(ptr_info.child) == .Fn) return;
|
||||
if (ptr_info.child == anyopaque) return;
|
||||
// This is important, because with trivial types like void,
|
||||
// Zig sometimes decide to call `visit` at comptime, but can't do
|
||||
// the pointer wrangling logic at comptime.
|
||||
|
||||
@ -155,7 +155,7 @@ pub const ext = struct {
|
||||
|
||||
pub const DenseIntOrFPElementsAttribute = struct {
|
||||
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
|
||||
return switch (data.dataType()) {
|
||||
return switch (data.dtype()) {
|
||||
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
||||
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
||||
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
||||
|
||||
@ -140,8 +140,8 @@ pub const CompilationContext = struct {
|
||||
/// `blkctx` represents values from outside the block that can be accessed inside the block.
|
||||
pub fn makeBlock(
|
||||
self: *CompilationContext,
|
||||
comptime func: anytype,
|
||||
comptime S: ops.BlockSignature,
|
||||
func: *const S.Fn,
|
||||
blkctx: S.BlkCtx,
|
||||
args: S.Args,
|
||||
) mlir.Block {
|
||||
@ -996,7 +996,8 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
else => {},
|
||||
}
|
||||
|
||||
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, try options.encode(arena));
|
||||
const options_bytes = try options.encode(arena);
|
||||
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
|
||||
errdefer unreachable; // errdefer loaded_executable.deinit();
|
||||
|
||||
if (platform.compilation_options.cache_location) |compilation_cache_location| {
|
||||
|
||||
19
zml/ops.zig
19
zml/ops.zig
@ -44,8 +44,8 @@ pub fn while_(
|
||||
@compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const cond_block = ctx.makeBlock(cond_fn, CondS, blkctx, inputs);
|
||||
const body_block = ctx.makeBlock(body_fn, BodyS, blkctx, inputs);
|
||||
const cond_block = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
|
||||
const body_block = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
|
||||
var input_values: [BodyS.nIn]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
|
||||
@ -136,7 +136,7 @@ pub fn reduce(
|
||||
var init_values: [N]mlir.Value = undefined;
|
||||
ctx.extractValues(&inits, &init_values);
|
||||
|
||||
const body_block = ctx.makeBlock(body_fn, BodyS, {}, .{ inits, inits });
|
||||
const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
|
||||
@ -226,7 +226,7 @@ pub fn reduceWindow(
|
||||
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const body_block = ctx.makeBlock(body_fn, BodyS, {}, .{ inits, inits });
|
||||
const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const N = comptime @divExact(BodyS.nIn, 2);
|
||||
var input_values: [N]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
@ -398,8 +398,8 @@ pub fn if_(
|
||||
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const true_branch_block = ctx.makeBlock(true_branch_fn, TrueBlockSignature, blkctx, {});
|
||||
const false_branch_block = ctx.makeBlock(false_branch_fn, TrueBlockSignature, blkctx, {});
|
||||
const true_branch_block = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||
const false_branch_block = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
|
||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
|
||||
@ -461,7 +461,7 @@ pub fn sort(
|
||||
inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const block = ctx.makeBlock(comp_fn, BodyS, blkctx, inits);
|
||||
const block = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);
|
||||
var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
|
||||
@ -486,6 +486,7 @@ pub fn sort(
|
||||
}
|
||||
|
||||
pub const BlockSignature = struct {
|
||||
Fn: type,
|
||||
BlkCtx: type,
|
||||
Args: type,
|
||||
FullArgs: type,
|
||||
@ -560,7 +561,8 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
|
||||
.no_args => void,
|
||||
};
|
||||
|
||||
const xx = .{
|
||||
return .{
|
||||
.Fn = @TypeOf(func),
|
||||
.BlkCtx = BlkCtx,
|
||||
.Args = Args,
|
||||
.FullArgs = FullArgs,
|
||||
@ -568,7 +570,6 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
|
||||
.nIn = n_tensors,
|
||||
.nOut = staticCountTensors(fn_info.return_type.?) orelse @compileError("Can't use " ++ @typeName(fn_info.return_type.?) ++ " in an MLIR function, because it has a variable number of tensors"),
|
||||
};
|
||||
return xx;
|
||||
}
|
||||
|
||||
pub fn staticIsOnlyTensors(comptime T: type) bool {
|
||||
|
||||
@ -290,14 +290,14 @@ pub const Shape = struct {
|
||||
}
|
||||
|
||||
fn axisFromInt(self: Shape, d: isize) u3 {
|
||||
const rank_: i8 = self.rank();
|
||||
if (d < 0) {
|
||||
return @intCast(d + rank_);
|
||||
const rk: i8 = self.rank();
|
||||
if (d < -rk or d > rk) {
|
||||
meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
|
||||
}
|
||||
if (d > rank_) {
|
||||
meta.panic("Tensor doesn't have dimension: {d}", .{d});
|
||||
}
|
||||
return @intCast(d);
|
||||
return if (d < 0)
|
||||
@intCast(d + rk)
|
||||
else
|
||||
@intCast(d);
|
||||
}
|
||||
|
||||
fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 {
|
||||
|
||||
305
zml/tensor.zig
305
zml/tensor.zig
@ -1682,12 +1682,12 @@ pub const Tensor = struct {
|
||||
|
||||
/// Returns a constant Tensor with the given value.
|
||||
pub fn constant(dimz: anytype, val: Data) Tensor {
|
||||
const sh = Shape.init(dimz, val.dataType());
|
||||
const singleton_sh = Shape.init(.{}, val.dataType());
|
||||
const sh = Shape.init(dimz, val.dtype());
|
||||
const singleton_sh = Shape.init(.{}, val.dtype());
|
||||
const ctx = CompilationContext.current().mlirCtx();
|
||||
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
||||
const elem_type = mlir.ext.denseElementAttrType(val.dataType());
|
||||
const elem_type = mlir.ext.denseElementAttrType(val.dtype());
|
||||
var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc);
|
||||
if (sh.rank() > 0) {
|
||||
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
|
||||
@ -1925,20 +1925,14 @@ pub const Tensor = struct {
|
||||
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
|
||||
///
|
||||
/// It is possible to use gatherValues without tags, but batching won't be available.
|
||||
pub fn gatherValues(self: Tensor, axes_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, axes_, indices });
|
||||
const AxesT = @TypeOf(axes_);
|
||||
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
|
||||
pub fn gatherValues(self: Tensor, coord_axes: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
|
||||
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
||||
|
||||
const val_axes = if (axes_is_scalar)
|
||||
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
||||
else
|
||||
self.axes(axes_);
|
||||
|
||||
meta.assert(val_axes.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
||||
for (val_axes.constSlice(), 0..) |a, i| {
|
||||
meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
||||
for (coord_axes_.constSlice(), 0..) |a, i| {
|
||||
if (i > 0) {
|
||||
meta.assert(a == val_axes.get(i - 1) + 1, "gatherValues expects 'axes_' too be sequential. But {any} aren't sequential in {}", .{ axes_, self });
|
||||
meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
|
||||
}
|
||||
}
|
||||
|
||||
@ -1946,14 +1940,14 @@ pub const Tensor = struct {
|
||||
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||
var indices_batch_axes: Shape.DimsArray = .{};
|
||||
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
||||
const maybe_val_ax = std.mem.indexOfScalar(u3, val_axes.constSlice(), @intCast(self_ax));
|
||||
const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
|
||||
if (indices._shape.hasTag(t)) |id_ax| {
|
||||
// tag is both in self and indices -> it's a batching dim
|
||||
// Note: tags are required for batching.
|
||||
self_kind.appendAssumeCapacity(.batching);
|
||||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||||
meta.assert(maybe_val_ax == null, "gatherValues expects axes to be either batches or slices axes. Axis {s} has been found both in `axes={any}` and `indices={}`", .{ t, axes_, indices });
|
||||
} else if (maybe_val_ax) |_| {
|
||||
meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
|
||||
} else if (maybe_coord_ax) |_| {
|
||||
// for gatherValues we collapsed all gathered axes
|
||||
// (contrary to gatherSlices where we collapse none)
|
||||
self_kind.appendAssumeCapacity(.collapsed);
|
||||
@ -1962,14 +1956,14 @@ pub const Tensor = struct {
|
||||
}
|
||||
}
|
||||
|
||||
// When we receive several axes_ we need an extra dimension to store
|
||||
// When we receive several coord_axes we need an extra dimension to store
|
||||
// one index per axis, which makes the coordinates of one value.
|
||||
// Otherwi se stablehlo uses the "indices.rank()" default value.
|
||||
const index_coord_axis = if (axes_is_scalar)
|
||||
const index_coord_axis = if (single_coord)
|
||||
indices.rank()
|
||||
else blk: {
|
||||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||
meta.assert(indices.dim(ax) == val_axes.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ axes_, val_axes.len, indices });
|
||||
meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
|
||||
break :blk ax;
|
||||
};
|
||||
|
||||
@ -1978,7 +1972,7 @@ pub const Tensor = struct {
|
||||
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
|
||||
const ax: u3 = @intCast(ax_usize);
|
||||
if (ax == val_axes.get(0)) {
|
||||
if (ax == coord_axes_.get(0)) {
|
||||
// The first val_ax is special cause this is the place where we insert indices axes.
|
||||
for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| {
|
||||
if (id_ax == index_coord_axis) continue;
|
||||
@ -2004,7 +1998,7 @@ pub const Tensor = struct {
|
||||
// Sometimes the backend recognize this pattern, but not always.
|
||||
// So let us handle that.
|
||||
if (indices.count() == 1) {
|
||||
return self.dynamicSlice1d(val_axes.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
|
||||
return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
|
||||
}
|
||||
|
||||
var slice_dims: Shape.DimsArray = .{};
|
||||
@ -2247,6 +2241,256 @@ pub const Tensor = struct {
|
||||
try zml.testing.expectClose(expected, result, 0);
|
||||
}
|
||||
|
||||
pub const ScatterOpts = struct {
|
||||
/// Promise scatter that all coordinates in `indices` are sorted, wrt to the final in memory offset.
|
||||
/// Result is undefined if the promise is violated.
|
||||
indices_are_sorted: bool = false,
|
||||
|
||||
/// Promise scatter that slices don't overlap.
|
||||
/// Result is undefined if the promise is violated.
|
||||
indices_are_unique: bool = false,
|
||||
|
||||
/// Function used to update previous value in `self` with values from `updates`.
|
||||
/// If `update_fn` is not associative (ie the order of execution matters),
|
||||
/// then you should make sure the slices don't overlap,
|
||||
/// otherwise the result will depend on the runtime scheduling
|
||||
/// of the operator which is backend specific.
|
||||
update_fn: *const fn (*const anyopaque, Tensor, Tensor) Tensor = increment,
|
||||
|
||||
/// Extra data that may be needed for a custom update function.
|
||||
/// `override` and `increment` don't need it, leaving it to undefined works.
|
||||
update_fn_ctx: *const anyopaque = undefined,
|
||||
|
||||
pub fn increment(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor {
|
||||
return old_value.add(new_value);
|
||||
}
|
||||
|
||||
pub fn override(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor {
|
||||
_ = old_value;
|
||||
return new_value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Update the given tensors, by copying `values` into self slices.
|
||||
/// The slices are chosen at runtime by interpreting indices as coordinates into `self`.
|
||||
/// * `indices` represents a set of coordinates into `self`.
|
||||
/// For the sake of simplifying the creation of `indices` tensor,
|
||||
/// it's allowed to not mention a specific axis if the coordinate for this axis is always `0`.
|
||||
/// Similarly to `gatherValues`, the coordinates are read from the `.coord` axis, or last axis if `.coord` is not found.
|
||||
/// The coordinates represent the "top-left" corner of the slice to extract.
|
||||
/// `indices.dim(.coord)` must match `coord_axes.len`.
|
||||
/// Other axes identify one "slice" and they must be found inside `updates`.
|
||||
///
|
||||
/// * the output tensor starts with axes from `indices`.
|
||||
/// * if the input tensor has tagged axes, matching `indices` axes,
|
||||
/// they will be considered "batching" axes.
|
||||
///
|
||||
/// Sample input/output shapes:
|
||||
/// * scatterSlices([A, B, C, D], .{b, c}, [N, 2], [N, B', C']) -> [A, B, C, D]
|
||||
/// * scatterSlices(x(a,b,c,d), g(n,m), y[n,b,c]) [A,B,C,D] {
|
||||
/// var z = x;
|
||||
/// for (0..N) |n| { z[a,g[n,0]+b',g[n,1]+c',d] = y[n,a,b',c',d]; }
|
||||
/// }
|
||||
///
|
||||
/// **Warning**: if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound.
|
||||
/// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`,
|
||||
/// then the result will depend on the execution order that you don't control.
|
||||
pub fn scatterSlices(self: Tensor, coord_axes: anytype, indices: Tensor, updates: Tensor, opts: ScatterOpts) Tensor {
|
||||
const loc = @src();
|
||||
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates });
|
||||
|
||||
meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() });
|
||||
|
||||
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
||||
const AxisKind = enum { batching, update_window, inserted_window, window_id };
|
||||
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||
var indices_batch_axes: Shape.DimsArray = .{};
|
||||
for (self._shape.tags()) |t| {
|
||||
if (updates._shape.hasTag(t)) |_| {
|
||||
if (indices._shape.hasTag(t)) |id_ax| {
|
||||
// tag is in self, indices and updates -> it's a batching dim
|
||||
self_kind.appendAssumeCapacity(.batching);
|
||||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||||
} else {
|
||||
self_kind.appendAssumeCapacity(.update_window);
|
||||
}
|
||||
} else {
|
||||
self_kind.appendAssumeCapacity(.inserted_window);
|
||||
}
|
||||
}
|
||||
// scoped_log.warn(" self_kind -> {any}", .{self_kind.constSlice()});
|
||||
|
||||
const index_coord_axis = if (single_coord)
|
||||
indices.rank()
|
||||
else blk: {
|
||||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||
meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
|
||||
|
||||
break :blk ax;
|
||||
};
|
||||
if (indices.count() == 1) {
|
||||
return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{}));
|
||||
}
|
||||
|
||||
var up_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||
// Note: we assume the scatter_dims appear in the same order inside indices and inside self.
|
||||
for (updates._shape.tags(), 0..) |t, up_ax| {
|
||||
if (self._shape.hasTag(t)) |self_ax| {
|
||||
if (self_kind.get(self_ax) == .batching) {
|
||||
up_kind.appendAssumeCapacity(.batching);
|
||||
} else {
|
||||
meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
|
||||
up_kind.appendAssumeCapacity(.update_window);
|
||||
}
|
||||
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
|
||||
up_kind.appendAssumeCapacity(.window_id);
|
||||
} else {
|
||||
std.debug.panic("scatterSlices expects 'updates' to be made of axes from 'self={}' and from 'indices={}', got unknown tag {s} in {}", .{ self, indices, t, updates });
|
||||
}
|
||||
}
|
||||
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
|
||||
if (single_coord) {
|
||||
meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
||||
} else {
|
||||
meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
||||
}
|
||||
|
||||
const ctx = self.getContext();
|
||||
const mlir_ctx = ctx.mlirCtx();
|
||||
|
||||
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
|
||||
const UpdateS = ops.BlockSign(ScatterOpts.increment);
|
||||
const update_block = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
|
||||
|
||||
const op = dialect.stablehlo.scatter(
|
||||
mlir_ctx,
|
||||
&.{self.value()},
|
||||
&.{indices.value()},
|
||||
&.{updates.value()},
|
||||
update_block,
|
||||
.{
|
||||
.update_window_dims = _collectAxes(AxisKind, up_kind, .update_window).constSlice(),
|
||||
.inserted_window_dims = _collectAxes(AxisKind, self_kind, .inserted_window).constSlice(),
|
||||
.input_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
|
||||
.scatter_indices_batching_dims = indices_batch_axes.constSlice(),
|
||||
.scatter_dims_to_operand_dims = toI64(coord_axes_.constSlice()),
|
||||
.index_vector_dim = index_coord_axis,
|
||||
.indices_are_sorted = opts.indices_are_sorted,
|
||||
.unique_indices = opts.indices_are_unique,
|
||||
},
|
||||
mlir_ctx.location(loc),
|
||||
);
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
|
||||
// Exercises scatterSlices at three levels: shape-only compilation checks,
// concrete values without batching, and concrete values with batching dims.
test scatterSlices {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    // Thin wrapper so compileAndCall can invoke scatterSlices with a fixed
    // update function (increment: add updates into the scattered slices).
    const Local = struct {
        pub fn scatter(self: Tensor, coord_axes: Shape.AxesArray, indices: Tensor, updates: Tensor) Tensor {
            return self.scatterSlices(
                coord_axes.constSlice(),
                indices,
                updates,
                .{ .update_fn = ScatterOpts.increment },
            );
        }
    };

    {
        // Only test shapes
        var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
        defer comp.deinit();
        comp.activate();
        defer comp.deactivate();

        // Each testcase is: operand shape, scattered axes, indices shape, updates shape.
        inline for (.{
            .{ .{ .a = 10 }, .a, .{}, .{ .a = 3 } },
            .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8, .b = 2 } },
            // I'm not sure I like this variant, cause `b` is not mentioned in updates.
            // So 'stablehlo.scatter' is implicitly broadcasting the updates along `b` axis.
            // OTOH asking the user to do the broadcasting isn't trivial cause they will need to do shape wrangling and that's annoying.
            .{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .a = 2 } },
            .{ .{ .a = 10, .b = 20 }, .{ .b, .a }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 3, .b = 2 } },
            .{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .a = 3, .n = 8, .b = 2 } },
        }) |testcase| {
            const x_shape, const axes_, const idx_shape, const updates_shapes = testcase;
            const x = Tensor.constant(x_shape, .{ .f16 = 0 });
            const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
            const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 });

            const y = scatterSlices(x, axes_, idx, updates, .{});
            // Shape doesn't change with scatterSlices
            try zml.testing.expectEqualShapes(x.shape(), y.shape());
            try std.testing.expect(y.value().owner().verify());
        }
    }
    // Test with actual values, no batching.
    {
        const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
        const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b });
        defer a.deinit();
        a_host.deinit(std.testing.allocator);

        // Scatter rows 0 and 2 of the 3x3 operand.
        const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{0}, .{2} });
        const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } });

        // With the increment update_fn, result rows are original + update.
        const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } };
        const result = try zml.testing.compileAndCall(platform, Local.scatter, .{
            a,
            a.shape().axes(.{.a}),
            scatter_indices.withTags(.{ .n, .coord }),
            updates.withTags(.{ .n, .b }),
        });
        try std.testing.expect(a.shape().eql(result.shape()));
        try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
    }
    {
        // Test with actual values and batching along axis .a
        const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0);
        defer operand.deinit();
        // Indices carry one (c, b)-coordinate pair per (n, a, m) entry; note the
        // out-of-bounds row { 0, 9 } which scatter is expected to drop/clamp.
        const start_indices = (try zml.Buffer.fromArray(
            platform,
            [2][2][3][2]i32{
                .{
                    .{ .{ 0, 0 }, .{ 1, 0 }, .{ 2, 1 } },
                    .{ .{ 0, 1 }, .{ 1, 1 }, .{ 0, 9 } },
                },
                .{
                    .{ .{ 0, 0 }, .{ 2, 1 }, .{ 2, 2 } },
                    .{ .{ 1, 2 }, .{ 0, 1 }, .{ 1, 0 } },
                },
            },
        )).withTags(.{ .n, .a, .m, .coord });
        defer start_indices.deinit();

        // All-ones updates: the result counts how many slices touched each cell.
        const values = try zml.Buffer.constant(
            platform,
            Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16),
            1,
        );
        defer values.deinit();

        const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values });

        const expected = [2][3][4][2]u16{
            .{
                .{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } },
                .{ .{ 0, 0 }, .{ 0, 0 }, .{ 2, 2 }, .{ 2, 2 } },
                .{ .{ 0, 0 }, .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 } },
            },
            .{
                .{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } },
                .{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } },
                .{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } },
            },
        };
        try std.testing.expect(operand.shape().eql(result.shape()));
        try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
    }
}
|
||||
|
||||
/// Returns a Tensor containing the maximum over a given axis.
|
||||
pub fn max(self: Tensor, axis_: anytype) Tensor {
|
||||
const a = self.axis(axis_);
|
||||
@ -3546,6 +3790,17 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
///
/// Used to derive the host-side ("bufferized") representation of a model struct.
pub fn Bufferized(comptime T: type) type {
    // The original body contained a duplicated, unreachable second `return`
    // (Zig rejects statements after a return); keep a single return statement.
    return meta.MapType(Tensor, Buffer).map(T);
}
|
||||
|
||||
/// Normalizes a gather/scatter axes argument into a list of axis indices.
///
/// `axes_` may be a single axis (enum literal or integer) or a collection of
/// axes. Returns whether the input was a single scalar axis, plus the resolved
/// axis indices.
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
    const AxesT = @TypeOf(axes_);
    // A runtime int, comptime int, or bare enum literal designates one axis.
    const single_axis = @typeInfo(AxesT) == .Int or AxesT == comptime_int or AxesT == EnumLiteral;

    // NOTE: this must stay an if/else *expression* on a comptime-known
    // condition, so that only the taken branch is analyzed — `self.axes`
    // would not type-check for a scalar `axes_`, and vice versa.
    const resolved = if (!single_axis)
        self.axes(axes_)
    else
        std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable;

    return .{ single_axis, resolved };
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user