Add multi‑axis, batched gatherValues support to tensor, shape, nn, quantization, and torch modules.

This commit is contained in:
Tarry Singh 2023-01-18 12:03:48 +00:00
parent 16e066ec69
commit ccdf218961
5 changed files with 190 additions and 40 deletions

View File

@ -38,7 +38,7 @@ pub const TokenEmbedding = struct {
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
meta.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
meta.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
return self.weight.gather1d(0, idx, .{});
return self.weight.gatherValues(0, idx, .{});
}
};
@ -393,7 +393,7 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
const n = out_shape.dim(d);
const ratio = meta.divFloat(f32, input.dim(d), n);
const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
res = res.gather1d(d, offsets, .{ .indices_are_sorted = true });
res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
}
return res;
}
@ -927,7 +927,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
// topk_idx is indices into topk.values ! so in the range [0, topk]
// Convert for the original indices from the full [0, voc] range.
const next_tokens = topk.indices.gather1d(.voc, topk_idx.squeeze(.topk), .{}).squeeze(.voc);
const next_tokens = topk.indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
return .{ next_tokens, next_rng };
}

View File

@ -85,7 +85,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
const indices = indices1.add(indices2);
// We select the values we are interested in with the indices, group them by pair and bitcast them to f16, then convert them to f32.
const scales = input.gather1d(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32);
const scales = input.gatherValues(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32);
return scales;
}
@ -107,7 +107,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
// NOTE(Corendos): i4 is not supported by bitcast convert, so we need the following workaround.
// We select the values we are interested in with the indices, these are our quantized_weights.
const quantized_weights = input.gather1d(0, indices, .{ .indices_are_sorted = true });
const quantized_weights = input.gatherValues(0, indices, .{ .indices_are_sorted = true });
const lb_weights = quantized_weights
.logical(.And, zml.Tensor.constant(.{16 * block_count}, zml.Data.init(.u8, 0xf)))
.bitCast(.i8);

View File

@ -20,11 +20,9 @@ pub const Shape = struct {
pub const TagUnknown = "_".ptr;
const TagLast = "last".ptr;
// Note: we can't make those public otherwise `refAllDeclsRecursive`
// will try to compile `std.BoundedArray.Writer` and will produce a compile error.
const DimsArray = std.BoundedArray(i64, MAX_RANK);
const TagsArray = std.BoundedArray(Tag, MAX_RANK);
const AxesArray = std.BoundedArray(u3, MAX_RANK);
pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };

View File

@ -654,7 +654,7 @@ pub const Tensor = struct {
break :blk powers;
};
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
const counts = values.gather1d(.d, samples, .{}).sum(.d).bitCast(.u16);
const counts = values.gatherValues(.d, samples, .{}).sum(.d).bitCast(.u16);
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
}
@ -1877,50 +1877,191 @@ pub const Tensor = struct {
pub const GatherOpts = struct { indices_are_sorted: bool = false };
/// Gathers along a given axis with 1d indices.
/// ([A, B, C, D], .c, [N]) -> (A, B, N, D)
/// ([A, B, C, D], .c, [N, M]) -> (A, B, N, M, D)
pub fn gather1d(self: Tensor, axis_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
// TODO: handle batching dims
meta.assert(self.rank() > 0 and self.rank() - 1 < MAX_RANK - indices.rank(), "Can't gather1d({}, {}) the resulting shape would have too many axes", .{ self, indices });
/// For each coordinate in `indices`,
/// `gatherValues` extracts a single value of the given tensor.
///
/// * axes_ is a single axis, or a tuple of axes: .b, or .{ .b, .c }
/// * indices is an integer tensor
/// * result is a tensor whose shape is similar to the input shape
/// where the gathered axes have been replaced by axes from 'indices'.
///
/// Some example input for the base case where we work on one axis:
/// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]]
/// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b]
/// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c]
///
/// If an axis is common to both `self` and `indices`,
/// it is treated as a "batching" axis, meaning that semantically
/// the operator semantically performs a separate gatherValues for each index along this axis:
/// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]]
///
/// It is an error to have an axis present in `self`, `axes_` and `indices`.
///
/// If several axes are passed, then the last axis of indices is treated as coordinates:
/// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]]
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
///
/// It is possible to use gatherValues without tags, but batching won't be available.
pub fn gatherValues(self: Tensor, axes_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, axes_, indices });
const AxesT = @TypeOf(axes_);
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
const a = self.axis(axis_);
if (indices.rank() > 1) {
const flattened_gather = self.gather1d(a, indices.flattenAll(), opts);
var tgt_shape = self._shape.drop(a);
for (0..indices.rank()) |i| {
tgt_shape._dims.insert(@intCast(a + i), indices.dim(i)) catch unreachable;
tgt_shape._tags.insert(@intCast(a + i), indices._shape._tags.get(i)) catch unreachable;
const val_axes = if (axes_is_scalar)
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
else
self.axes(axes_);
meta.assert(val_axes.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
for (val_axes.constSlice(), 0..) |a, i| {
if (i > 0) {
meta.assert(a == val_axes.get(i - 1) + 1, "gatherValues expects 'axes_' too be sequential. But {any} aren't sequential in {}", .{ axes_, self });
}
return flattened_gather.reshape(tgt_shape);
}
meta.assert(indices.rank() == 1, "gather1d expects 'indices' tensor rank to be equal to 1, got {}", .{indices.rank()});
const AxisKind = enum { batching, offset, collapsed, indices };
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
var indices_batch_axes: Shape.DimsArray = .{};
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
const maybe_val_ax = std.mem.indexOfScalar(u3, val_axes.constSlice(), @intCast(self_ax));
if (indices._shape.hasTag(t)) |id_ax| {
// tag is both in self and indices -> it's a batching dim
// Note: tags are required for batching.
self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax);
meta.assert(maybe_val_ax == null, "gatherValues expects axes to be either batches or slices axes. Axis {s} has been found both in `axes={any}` and `indices={}`", .{ t, axes_, indices });
} else if (maybe_val_ax) |_| {
// for gatherValues we collapsed all gathered axes
// (contrary to gatherSlices where we collapse none)
self_kind.appendAssumeCapacity(.collapsed);
} else {
self_kind.appendAssumeCapacity(.offset);
}
}
const res_shape = self._shape.set(a, indices.dim(0));
const slice_sizes = self._shape.set(a, 1);
const offset_dims = Shape.range(self.rank(), self.dtype()).drop(a);
// When we receive several axes_ we need an extra dimension to store
// one index per axis, which makes the coordinates of one value.
// Otherwise stablehlo uses the "indices.rank()" default value.
const index_coord_axis = if (axes_is_scalar)
indices.rank()
else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
meta.assert(indices.dim(ax) == val_axes.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ axes_, val_axes.len, indices });
break :blk ax;
};
// compute res shape
var res_shape = Shape.init(.{}, self.dtype());
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
const ax: u3 = @intCast(ax_usize);
if (ax == val_axes.get(0)) {
// The first val_ax is special cause this is the place where we insert indices axes.
for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| {
if (id_ax == index_coord_axis) continue;
if (std.mem.indexOfScalar(i64, indices_batch_axes.constSlice(), @intCast(id_ax))) |_| {
// batching dim are already in res
continue;
}
res_shape = res_shape.appendDim(indices.dim(id_ax), t);
res_kind.appendAssumeCapacity(.indices);
}
}
switch (kind) {
.collapsed => continue,
else => {
res_shape = res_shape.appendDim(self.dim(ax), self._shape.tag(ax));
res_kind.appendAssumeCapacity(kind);
},
}
}
// This is not a gather, but a dynamicSlice.
// Sometimes the backend recognize this pattern, but not always.
// So let us handle that.
if (indices.count() == 1) {
return self.dynamicSlice1d(val_axes.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
}
var slice_dims: Shape.DimsArray = .{};
for (self_kind.constSlice(), self.dims()) |k, d| {
slice_dims.appendAssumeCapacity(switch (k) {
.batching, .collapsed => 1,
.offset => d,
.indices => unreachable,
});
}
// scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() });
const loc = self.getContext().mlirCtx().location(@src());
const gather_op = dialect.stablehlo.gather(
self.getContext().mlirCtx(),
self.value(),
indices.value(),
slice_sizes.dims(),
slice_dims.constSlice(),
loc,
.{
.offset_dims = offset_dims.dims(),
.collapsed_slice_dims = &.{a},
.operand_batching_dims = &.{},
.start_indices_batching_dims = &.{},
.start_index_map = &.{a},
.index_vector_dim = indices.rank(),
.offset_dims = _collectAxes(AxisKind, res_kind, .offset).constSlice(),
.collapsed_slice_dims = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
.operand_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
.start_indices_batching_dims = indices_batch_axes.constSlice(),
.start_index_map = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
.index_vector_dim = index_coord_axis,
.indices_are_sorted = opts.indices_are_sorted,
},
);
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
return _result(res_shape, gather_op.result(0));
}
// Shape-level checks for gatherValues: each testcase tuple is
// (input shape, gathered axis/axes, indices shape, expected result shape).
// Only resulting shapes are verified; no computation is executed.
test gatherValues {
const zml = @import("zml.zig");
const platform = zml.testing.env();
{
// Only test shapes
// NOTE(review): the CompilationContext appears to be activated so that
// Tensor ops can record IR without running it — confirm against zml.module docs.
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
defer comp.deinit();
comp.activate();
defer comp.deactivate();
inline for (.{
// Scalar-axis gathers: the gathered axis is replaced by the indices' axes.
.{ .{ .a = 10 }, .a, .{}, .{} },
.{ .{ .a = 10 }, .a, .{ .n = 8 }, .{ .n = 8 } },
.{ .{ .a = 10, .b = 20 }, .a, .{}, .{ .b = 20 } },
.{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
.{ .{ .a = 10, .b = 20 }, 0, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
// Favor val shape, instead of indices shape.
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
// batching axes are implicits.
// TODO: batched gather don't compile https://github.com/zml/zml/issues/400
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
// stablehlo.gather is biased toward indices shape (like gatherSlice).
// This make it awkward to use when you have both batching dimension and new indices dimensions.
// For now we reject those, and let user explicitly transpose self or indices if needed.
// .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } },
// Also handle tuples
// Multi-axis gathers: last indices axis holds one coordinate per gathered axis.
.{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .n = 8 } },
.{ .{ 10, 20 }, .{ -2, -1 }, .{ 8, 2 }, .{8} },
// and 1-tuple
.{ .{ .a = 10, .b = 20 }, .{.b}, .{ .n = 8, ._ = 1 }, .{ .a = 10, .n = 8 } },
}) |testcase| {
const x_shape, const tag, const idx_shape, const res_shape = testcase;
// Constant fill values are arbitrary: only the shapes matter here.
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
const y = gatherValues(x, tag, idx, .{});
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
// Verify the emitted gather op is structurally valid IR.
try std.testing.expect(y.value().owner().verify());
}
}
}
/// Gathers slices along the given axes with runtime indices.
/// * slice_shape represents the shape of the slices to extract,
/// it must be smaller than original shape.
@ -1938,9 +2079,9 @@ pub const Tensor = struct {
/// * gatherSlices([A, B, C, D], .{.b=B', .c=C'}, [N, 2]) -> [N, A, B', C', D]
/// * gatherSlices(x(a,b,c,d), .{.b=B', .c=C'}, g(n,m)) = z(n, a, b', c', d) = x(a, g(n, 0) + b', g(n, 1) + c', d)
///
/// Note: the axis order of the result is different from gather1d.
/// Note: the axis order of the result is different from gatherValues.
/// This is because gatherSlices favors contiguous copies of the extracted slices,
/// while gather1d, always copy values one by one, and as such don't have the same issues.
/// while gatherValues always copies values one by one, and as such doesn't have the same issues.
/// In our example the contiguous dimension .d is not sliced
/// and gatherSlices can copy data by group of C'*D elements.
pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
@ -2390,6 +2531,8 @@ pub const Tensor = struct {
return _result(new_shape, op.result(0));
}
pub const DynSlice = struct { start: Tensor, len: i64 };
/// Slices a Tensor across many axes, with runtime known offsets.
///
/// Due to the nature of stablehlo, the length of the slices need to be known when compiling the IR.
@ -2407,7 +2550,6 @@ pub const Tensor = struct {
pub fn dynamicSlice(self: Tensor, slices_: anytype) Tensor {
// TODO: the untagged api is a bit verbose. Should I allow: `Tensor(.{ 20,30,40}).dynamicSlice(.{.{}, .{ .start = b_off, .len = 12 }, .{}});` ??
//
const DynSlice = struct { start: Tensor, len: i64 };
const slices, const slices_tags = Shape.parseStruct(DynSlice, slices_);
// TODO use slices and slices_tags for the format.
@ -3328,3 +3470,13 @@ test shapesOf {
try std.testing.expectEqual(fc2_bias_shape, shapes.fc2.bias);
}
}
/// Collects the 0-based positions (as i64) of every element in
/// `bounded_array` equal to `value`, preserving their order.
/// Used to turn per-axis kind markers into stablehlo dimension lists.
fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) {
    var matching: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
    var ax: usize = 0;
    while (ax < bounded_array.len) : (ax += 1) {
        if (bounded_array.get(ax) == value) {
            // Capacity is bounded by MAX_RANK, same as the input array.
            matching.appendAssumeCapacity(@intCast(ax));
        }
    }
    return matching;
}

View File

@ -186,7 +186,7 @@ pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
const start = @mod(self.dim(a) - shifts[0], self.dim(a));
const idx = Tensor.arange(.{ .start = start, .end = start + self.dim(a) }, .f32);
const divisor: f32 = @floatFromInt(self.dim(a));
return self.gather1d(a, idx.fmod(divisor).convert(.i32), .{});
return self.gatherValues(a, idx.fmod(divisor).convert(.i32), .{});
}
test roll {