Add multi‑axis, batched gatherValues support to tensor, shape, nn, quantization, and torch modules.
This commit is contained in:
parent
16e066ec69
commit
ccdf218961
@ -38,7 +38,7 @@ pub const TokenEmbedding = struct {
|
||||
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
|
||||
meta.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
|
||||
meta.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
|
||||
return self.weight.gather1d(0, idx, .{});
|
||||
return self.weight.gatherValues(0, idx, .{});
|
||||
}
|
||||
};
|
||||
|
||||
@ -393,7 +393,7 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
|
||||
const n = out_shape.dim(d);
|
||||
const ratio = meta.divFloat(f32, input.dim(d), n);
|
||||
const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
|
||||
res = res.gather1d(d, offsets, .{ .indices_are_sorted = true });
|
||||
res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -927,7 +927,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
|
||||
|
||||
// topk_idx is indices into topk.values ! so in the range [0, topk]
|
||||
// Convert for the original indices from the full [0, voc] range.
|
||||
const next_tokens = topk.indices.gather1d(.voc, topk_idx.squeeze(.topk), .{}).squeeze(.voc);
|
||||
const next_tokens = topk.indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
|
||||
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
|
||||
return .{ next_tokens, next_rng };
|
||||
}
|
||||
|
||||
@ -85,7 +85,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
|
||||
const indices = indices1.add(indices2);
|
||||
|
||||
// We select the values we are interested in with the indices, group them by pair and bitcast them to f16, then convert them to f32.
|
||||
const scales = input.gather1d(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32);
|
||||
const scales = input.gatherValues(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32);
|
||||
|
||||
return scales;
|
||||
}
|
||||
@ -107,7 +107,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
|
||||
// NOTE(Corendos): i4 is not supported by bitcast convert, so we need the following workaround.
|
||||
|
||||
// We select the values we are interested in with the indices, these are our quantized_weights.
|
||||
const quantized_weights = input.gather1d(0, indices, .{ .indices_are_sorted = true });
|
||||
const quantized_weights = input.gatherValues(0, indices, .{ .indices_are_sorted = true });
|
||||
const lb_weights = quantized_weights
|
||||
.logical(.And, zml.Tensor.constant(.{16 * block_count}, zml.Data.init(.u8, 0xf)))
|
||||
.bitCast(.i8);
|
||||
|
||||
@ -20,11 +20,9 @@ pub const Shape = struct {
|
||||
pub const TagUnknown = "_".ptr;
|
||||
const TagLast = "last".ptr;
|
||||
|
||||
// Note: we can't make those public otherwise `refAllDeclsRecursive`
|
||||
// will try to compile `std.BoundedArray.Writer` and will produce a compile error.
|
||||
const DimsArray = std.BoundedArray(i64, MAX_RANK);
|
||||
const TagsArray = std.BoundedArray(Tag, MAX_RANK);
|
||||
const AxesArray = std.BoundedArray(u3, MAX_RANK);
|
||||
pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
|
||||
pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
|
||||
pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
|
||||
|
||||
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
|
||||
|
||||
|
||||
210
zml/tensor.zig
210
zml/tensor.zig
@ -654,7 +654,7 @@ pub const Tensor = struct {
|
||||
break :blk powers;
|
||||
};
|
||||
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
|
||||
const counts = values.gather1d(.d, samples, .{}).sum(.d).bitCast(.u16);
|
||||
const counts = values.gatherValues(.d, samples, .{}).sum(.d).bitCast(.u16);
|
||||
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
|
||||
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
|
||||
}
|
||||
@ -1877,50 +1877,191 @@ pub const Tensor = struct {
|
||||
|
||||
pub const GatherOpts = struct { indices_are_sorted: bool = false };
|
||||
|
||||
/// Gathers along a given axis with 1d indices.
|
||||
/// ([A, B, C, D], .c, [N]) -> (A, B, N, D)
|
||||
/// ([A, B, C, D], .c, [N, M]) -> (A, B, N, M, D)
|
||||
pub fn gather1d(self: Tensor, axis_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
// TODO: handle batching dims
|
||||
meta.assert(self.rank() > 0 and self.rank() - 1 < MAX_RANK - indices.rank(), "Can't gather1d({}, {}) the resulting shape would have too many axes", .{ self, indices });
|
||||
/// For each coordinate in `indices`,
|
||||
/// `gatherValues` extracts a single value of the given tensor.
|
||||
///
|
||||
/// * axes_ is a single axis, or a tuple of axes: .b, or .{ .b, .c }
|
||||
/// * indices is an integer tensor
|
||||
/// * result is a tensor whose shape is similar to the input shape
|
||||
/// where the gathered axes have been replaced by axes from 'indices'.
|
||||
///
|
||||
/// Some example input for the base case where we work on one axis:
|
||||
/// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]]
|
||||
/// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b]
|
||||
/// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c]
|
||||
///
|
||||
/// If an axis is in common between `self` and `indices`,
|
||||
/// it is treated as a "batching" axis, meaning that semantically
|
||||
/// the operator is doing a gatherValues one time per dimension of this axis:
|
||||
/// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]]
|
||||
///
|
||||
/// It is an error to have an axis present in `self`, `axes_` and `indices`.
|
||||
///
|
||||
/// If several axes are passed, then the last axis of indices is treated as coordinates:
|
||||
/// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]]
|
||||
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
|
||||
///
|
||||
/// It is possible to use gatherValues without tags, but batching won't be available.
|
||||
pub fn gatherValues(self: Tensor, axes_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, axes_, indices });
|
||||
const AxesT = @TypeOf(axes_);
|
||||
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
|
||||
|
||||
const a = self.axis(axis_);
|
||||
if (indices.rank() > 1) {
|
||||
const flattened_gather = self.gather1d(a, indices.flattenAll(), opts);
|
||||
var tgt_shape = self._shape.drop(a);
|
||||
for (0..indices.rank()) |i| {
|
||||
tgt_shape._dims.insert(@intCast(a + i), indices.dim(i)) catch unreachable;
|
||||
tgt_shape._tags.insert(@intCast(a + i), indices._shape._tags.get(i)) catch unreachable;
|
||||
const val_axes = if (axes_is_scalar)
|
||||
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
||||
else
|
||||
self.axes(axes_);
|
||||
|
||||
meta.assert(val_axes.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
||||
for (val_axes.constSlice(), 0..) |a, i| {
|
||||
if (i > 0) {
|
||||
meta.assert(a == val_axes.get(i - 1) + 1, "gatherValues expects 'axes_' too be sequential. But {any} aren't sequential in {}", .{ axes_, self });
|
||||
}
|
||||
return flattened_gather.reshape(tgt_shape);
|
||||
}
|
||||
|
||||
meta.assert(indices.rank() == 1, "gather1d expects 'indices' tensor rank to be equal to 1, got {}", .{indices.rank()});
|
||||
const AxisKind = enum { batching, offset, collapsed, indices };
|
||||
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||
var indices_batch_axes: Shape.DimsArray = .{};
|
||||
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
||||
const maybe_val_ax = std.mem.indexOfScalar(u3, val_axes.constSlice(), @intCast(self_ax));
|
||||
if (indices._shape.hasTag(t)) |id_ax| {
|
||||
// tag is both in self and indices -> it's a batching dim
|
||||
// Note: tags are required for batching.
|
||||
self_kind.appendAssumeCapacity(.batching);
|
||||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||||
meta.assert(maybe_val_ax == null, "gatherValues expects axes to be either batches or slices axes. Axis {s} has been found both in `axes={any}` and `indices={}`", .{ t, axes_, indices });
|
||||
} else if (maybe_val_ax) |_| {
|
||||
// for gatherValues we collapse all gathered axes
|
||||
// (contrary to gatherSlices where we collapse none)
|
||||
self_kind.appendAssumeCapacity(.collapsed);
|
||||
} else {
|
||||
self_kind.appendAssumeCapacity(.offset);
|
||||
}
|
||||
}
|
||||
|
||||
const res_shape = self._shape.set(a, indices.dim(0));
|
||||
const slice_sizes = self._shape.set(a, 1);
|
||||
const offset_dims = Shape.range(self.rank(), self.dtype()).drop(a);
|
||||
// When we receive several axes_ we need an extra dimension to store
|
||||
// one index per axis, which makes the coordinates of one value.
|
||||
// Otherwise stablehlo uses the "indices.rank()" default value.
|
||||
const index_coord_axis = if (axes_is_scalar)
|
||||
indices.rank()
|
||||
else blk: {
|
||||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||
meta.assert(indices.dim(ax) == val_axes.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ axes_, val_axes.len, indices });
|
||||
break :blk ax;
|
||||
};
|
||||
|
||||
// compute res shape
|
||||
var res_shape = Shape.init(.{}, self.dtype());
|
||||
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
|
||||
const ax: u3 = @intCast(ax_usize);
|
||||
if (ax == val_axes.get(0)) {
|
||||
// The first val_ax is special cause this is the place where we insert indices axes.
|
||||
for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| {
|
||||
if (id_ax == index_coord_axis) continue;
|
||||
if (std.mem.indexOfScalar(i64, indices_batch_axes.constSlice(), @intCast(id_ax))) |_| {
|
||||
// batching dims are already in res
|
||||
continue;
|
||||
}
|
||||
|
||||
res_shape = res_shape.appendDim(indices.dim(id_ax), t);
|
||||
res_kind.appendAssumeCapacity(.indices);
|
||||
}
|
||||
}
|
||||
switch (kind) {
|
||||
.collapsed => continue,
|
||||
else => {
|
||||
res_shape = res_shape.appendDim(self.dim(ax), self._shape.tag(ax));
|
||||
res_kind.appendAssumeCapacity(kind);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// This is not a gather, but a dynamicSlice.
|
||||
// Sometimes the backend recognizes this pattern, but not always.
|
||||
// So let us handle that.
|
||||
if (indices.count() == 1) {
|
||||
return self.dynamicSlice1d(val_axes.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
|
||||
}
|
||||
|
||||
var slice_dims: Shape.DimsArray = .{};
|
||||
for (self_kind.constSlice(), self.dims()) |k, d| {
|
||||
slice_dims.appendAssumeCapacity(switch (k) {
|
||||
.batching, .collapsed => 1,
|
||||
.offset => d,
|
||||
.indices => unreachable,
|
||||
});
|
||||
}
|
||||
|
||||
// scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() });
|
||||
const loc = self.getContext().mlirCtx().location(@src());
|
||||
const gather_op = dialect.stablehlo.gather(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
indices.value(),
|
||||
slice_sizes.dims(),
|
||||
slice_dims.constSlice(),
|
||||
loc,
|
||||
.{
|
||||
.offset_dims = offset_dims.dims(),
|
||||
.collapsed_slice_dims = &.{a},
|
||||
.operand_batching_dims = &.{},
|
||||
.start_indices_batching_dims = &.{},
|
||||
.start_index_map = &.{a},
|
||||
.index_vector_dim = indices.rank(),
|
||||
.offset_dims = _collectAxes(AxisKind, res_kind, .offset).constSlice(),
|
||||
.collapsed_slice_dims = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
|
||||
.operand_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
|
||||
.start_indices_batching_dims = indices_batch_axes.constSlice(),
|
||||
.start_index_map = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
|
||||
.index_vector_dim = index_coord_axis,
|
||||
.indices_are_sorted = opts.indices_are_sorted,
|
||||
},
|
||||
);
|
||||
|
||||
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
|
||||
meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
|
||||
return _result(res_shape, gather_op.result(0));
|
||||
}
|
||||
|
||||
test gatherValues {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
{
|
||||
// Only test shapes
|
||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||
defer comp.deinit();
|
||||
comp.activate();
|
||||
defer comp.deactivate();
|
||||
|
||||
inline for (.{
|
||||
.{ .{ .a = 10 }, .a, .{}, .{} },
|
||||
.{ .{ .a = 10 }, .a, .{ .n = 8 }, .{ .n = 8 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .a, .{}, .{ .b = 20 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
|
||||
.{ .{ .a = 10, .b = 20 }, 0, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
|
||||
// Favor val shape, instead of indices shape.
|
||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
|
||||
// batching axes are implicit.
|
||||
// TODO: batched gather doesn't compile https://github.com/zml/zml/issues/400
|
||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||
// stablehlo.gather is biased toward indices shape (like gatherSlice).
|
||||
// This makes it awkward to use when you have both batching dimension and new indices dimensions.
|
||||
// For now we reject those, and let user explicitly transpose self or indices if needed.
|
||||
// .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } },
|
||||
// Also handle tuples
|
||||
.{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .n = 8 } },
|
||||
.{ .{ 10, 20 }, .{ -2, -1 }, .{ 8, 2 }, .{8} },
|
||||
// and 1-tuple
|
||||
.{ .{ .a = 10, .b = 20 }, .{.b}, .{ .n = 8, ._ = 1 }, .{ .a = 10, .n = 8 } },
|
||||
}) |testcase| {
|
||||
const x_shape, const tag, const idx_shape, const res_shape = testcase;
|
||||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
|
||||
const y = gatherValues(x, tag, idx, .{});
|
||||
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
||||
try std.testing.expect(y.value().owner().verify());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gathers slices along the given axes with runtime indices.
|
||||
/// * slice_shape represents the shape of the slices to extract,
|
||||
/// it must be smaller than original shape.
|
||||
@ -1938,9 +2079,9 @@ pub const Tensor = struct {
|
||||
/// * gatherSlices([A, B, C, D], .{.b=B', .c=C'}, [N, 2]) -> [N, A, B', C', D]
|
||||
/// * gatherSlices(x(a,b,c,d), .{.b=B', .c=C'}, g(n,m)) = z(n, a, b', c', d) = x(a, g(n, 0) + b', g(n, 1) + c', d)
|
||||
///
|
||||
/// Note: the axis order of the result is different from gather1d.
|
||||
/// Note: the axis order of the result is different from gatherValues.
|
||||
/// This is because gatherSlices favors contiguous copies of the extracted slices,
|
||||
/// while gather1d, always copy values one by one, and as such don't have the same issues.
|
||||
/// while gatherValues always copies values one by one, and as such doesn't have the same issues.
|
||||
/// In our example the contiguous dimension .d is not sliced
|
||||
/// and gatherSlices can copy data by group of C'*D elements.
|
||||
pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
@ -2390,6 +2531,8 @@ pub const Tensor = struct {
|
||||
return _result(new_shape, op.result(0));
|
||||
}
|
||||
|
||||
pub const DynSlice = struct { start: Tensor, len: i64 };
|
||||
|
||||
/// Slices a Tensor across many axes, with runtime known offsets.
|
||||
///
|
||||
/// Due to the nature of stablehlo, the length of the slices need to be known when compiling the IR.
|
||||
@ -2407,7 +2550,6 @@ pub const Tensor = struct {
|
||||
pub fn dynamicSlice(self: Tensor, slices_: anytype) Tensor {
|
||||
// TODO: the untagged api is a bit verbose. Should I allow: `Tensor(.{ 20,30,40}).dynamicSlice(.{.{}, .{ .start = b_off, .len = 12 }, .{}});` ??
|
||||
//
|
||||
const DynSlice = struct { start: Tensor, len: i64 };
|
||||
const slices, const slices_tags = Shape.parseStruct(DynSlice, slices_);
|
||||
|
||||
// TODO use slices and slices_tags for the format.
|
||||
@ -3328,3 +3470,13 @@ test shapesOf {
|
||||
try std.testing.expectEqual(fc2_bias_shape, shapes.fc2.bias);
|
||||
}
|
||||
}
|
||||
|
||||
fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) {
|
||||
var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
||||
for (bounded_array.constSlice(), 0..) |v, ax| {
|
||||
if (v == value) {
|
||||
res.appendAssumeCapacity(@intCast(ax));
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -186,7 +186,7 @@ pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
|
||||
const start = @mod(self.dim(a) - shifts[0], self.dim(a));
|
||||
const idx = Tensor.arange(.{ .start = start, .end = start + self.dim(a) }, .f32);
|
||||
const divisor: f32 = @floatFromInt(self.dim(a));
|
||||
return self.gather1d(a, idx.fmod(divisor).convert(.i32), .{});
|
||||
return self.gatherValues(a, idx.fmod(divisor).convert(.i32), .{});
|
||||
}
|
||||
|
||||
test roll {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user