From ccdf218961f7ffa34b1391ed4ec9477e84753113 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 18 Jan 2023 12:03:48 +0000 Subject: [PATCH] =?UTF-8?q?Add=20multi=E2=80=91axis,=20batched=20`gatherVa?= =?UTF-8?q?lues`=20support=20to=20tensor,=20shape,=20nn,=20quantization,?= =?UTF-8?q?=20and=20torch=20modules.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zml/nn.zig | 6 +- zml/quantization.zig | 4 +- zml/shape.zig | 8 +- zml/tensor.zig | 210 +++++++++++++++++++++++++++++++++++++------ zml/torch.zig | 2 +- 5 files changed, 190 insertions(+), 40 deletions(-) diff --git a/zml/nn.zig b/zml/nn.zig index 9fb426b..4af9695 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -38,7 +38,7 @@ pub const TokenEmbedding = struct { pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor { meta.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx}); meta.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight}); - return self.weight.gather1d(0, idx, .{}); + return self.weight.gatherValues(0, idx, .{}); } }; @@ -393,7 +393,7 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor { const n = out_shape.dim(d); const ratio = meta.divFloat(f32, input.dim(d), n); const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32); - res = res.gather1d(d, offsets, .{ .indices_are_sorted = true }); + res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true }); } return res; } @@ -927,7 +927,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng // topk_idx is indices into topk.values ! so in the range [0, topk] // Convert for the original indices from the full [0, voc] range. 
- const next_tokens = topk.indices.gather1d(.voc, topk_idx.squeeze(.topk), .{}).squeeze(.voc); + const next_tokens = topk.indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{}); // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens }); return .{ next_tokens, next_rng }; } diff --git a/zml/quantization.zig b/zml/quantization.zig index 5fa09ab..2a850a4 100644 --- a/zml/quantization.zig +++ b/zml/quantization.zig @@ -85,7 +85,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type { const indices = indices1.add(indices2); // We select the values we are interested in with the indices, group them by pair and bitcast them to f16, then convert them to f32. - const scales = input.gather1d(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32); + const scales = input.gatherValues(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32); return scales; } @@ -107,7 +107,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type { // NOTE(Corendos): i4 is not supported by bitcast convert, so we need the following workaround. // We select the values we are interested in with the indices, these are our quantized_weights. - const quantized_weights = input.gather1d(0, indices, .{ .indices_are_sorted = true }); + const quantized_weights = input.gatherValues(0, indices, .{ .indices_are_sorted = true }); const lb_weights = quantized_weights .logical(.And, zml.Tensor.constant(.{16 * block_count}, zml.Data.init(.u8, 0xf))) .bitCast(.i8); diff --git a/zml/shape.zig b/zml/shape.zig index 784ca67..842e1ed 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -20,11 +20,9 @@ pub const Shape = struct { pub const TagUnknown = "_".ptr; const TagLast = "last".ptr; - // Note: we can't make those public otherwise `refAllDeclsRecursive` - // will try to compile `std.BoundedArray.Writer` and will produce a compile error. 
- const DimsArray = std.BoundedArray(i64, MAX_RANK); - const TagsArray = std.BoundedArray(Tag, MAX_RANK); - const AxesArray = std.BoundedArray(u3, MAX_RANK); + pub const DimsArray = std.BoundedArray(i64, MAX_RANK); + pub const TagsArray = std.BoundedArray(Tag, MAX_RANK); + pub const AxesArray = std.BoundedArray(u3, MAX_RANK); const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK }; diff --git a/zml/tensor.zig b/zml/tensor.zig index 901450b..0c92825 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -654,7 +654,7 @@ pub const Tensor = struct { break :blk powers; }; const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d}); - const counts = values.gather1d(.d, samples, .{}).sum(.d).bitCast(.u16); + const counts = values.gatherValues(.d, samples, .{}).sum(.d).bitCast(.u16); const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n)); return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } }; } @@ -1877,50 +1877,191 @@ pub const Tensor = struct { pub const GatherOpts = struct { indices_are_sorted: bool = false }; - /// Gathers along a given axis with 1d indices. - /// ([A, B, C, D], .c, [N]) -> (A, B, N, D) - /// ([A, B, C, D], .c, [N, M]) -> (A, B, N, M, D) - pub fn gather1d(self: Tensor, axis_: anytype, indices: Tensor, opts: GatherOpts) Tensor { - // TODO: handle batching dims - meta.assert(self.rank() > 0 and self.rank() - 1 < MAX_RANK - indices.rank(), "Can't gather1d({}, {}) the resulting shape would have too many axes", .{ self, indices }); + /// For each coordinate in `indices`, + /// `gatherValues` extracts a single value of the given tensor. + /// + /// * axes_ is a single axis, or a tuple of axis: .b, or .{ .b, .c } + /// * indices is an integer tensor + /// * result is a tensor whose shape is similar to the input shape + /// where the gathered axes have been replaced by axes from 'indices'. 
+ /// + /// Some example input for the base case where we work on one axis: + /// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]] + /// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b] + /// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c] + /// + /// If an axis is in common between `self` and `indices`, + /// it is treated as a "batching" axis, meaning that semantically + /// the operator is doing a gatherValues one time per dimension of this axis: + /// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]] + /// + /// It is an error to have an axis present in `self`, `axes_` and `indices`. + /// + /// If several axes are passed, then the last axis of indices is treated as coordinates: + /// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]] + /// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d] + /// + /// It is possible to use gatherValues without tags, but batching won't be available. + pub fn gatherValues(self: Tensor, axes_: anytype, indices: Tensor, opts: GatherOpts) Tensor { + // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, axes_, indices }); + const AxesT = @TypeOf(axes_); + const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int; - const a = self.axis(axis_); - if (indices.rank() > 1) { - const flattened_gather = self.gather1d(a, indices.flattenAll(), opts); - var tgt_shape = self._shape.drop(a); - for (0..indices.rank()) |i| { - tgt_shape._dims.insert(@intCast(a + i), indices.dim(i)) catch unreachable; - tgt_shape._tags.insert(@intCast(a + i), indices._shape._tags.get(i)) catch unreachable; + const val_axes = if (axes_is_scalar) + std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable + else + self.axes(axes_); + + meta.assert(val_axes.len > 0, "gatherValues expects 1 or more axes to operate on, received none. 
Example: `x.gatherValues(.a, indices, .{{}})`", .{}); + for (val_axes.constSlice(), 0..) |a, i| { + if (i > 0) { + meta.assert(a == val_axes.get(i - 1) + 1, "gatherValues expects 'axes_' too be sequential. But {any} aren't sequential in {}", .{ axes_, self }); } - return flattened_gather.reshape(tgt_shape); } - meta.assert(indices.rank() == 1, "gather1d expects 'indices' tensor rank to be equal to 1, got {}", .{indices.rank()}); + const AxisKind = enum { batching, offset, collapsed, indices }; + var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; + var indices_batch_axes: Shape.DimsArray = .{}; + for (self._shape.tags(), 0..self.rank()) |t, self_ax| { + const maybe_val_ax = std.mem.indexOfScalar(u3, val_axes.constSlice(), @intCast(self_ax)); + if (indices._shape.hasTag(t)) |id_ax| { + // tag is both in self and indices -> it's a batching dim + // Note: tags are required for batching. + self_kind.appendAssumeCapacity(.batching); + indices_batch_axes.appendAssumeCapacity(id_ax); + meta.assert(maybe_val_ax == null, "gatherValues expects axes to be either batches or slices axes. Axis {s} has been found both in `axes={any}` and `indices={}`", .{ t, axes_, indices }); + } else if (maybe_val_ax) |_| { + // for gatherValues we collapsed all gathered axes + // (contrary to gatherSlices where we collapse none) + self_kind.appendAssumeCapacity(.collapsed); + } else { + self_kind.appendAssumeCapacity(.offset); + } + } - const res_shape = self._shape.set(a, indices.dim(0)); - const slice_sizes = self._shape.set(a, 1); - const offset_dims = Shape.range(self.rank(), self.dtype()).drop(a); + // When we receive several axes_ we need an extra dimension to store + // one index per axis, which makes the coordinates of one value. + // Otherwi se stablehlo uses the "indices.rank()" default value. 
+ const index_coord_axis = if (axes_is_scalar) + indices.rank() + else blk: { + const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); + meta.assert(indices.dim(ax) == val_axes.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ axes_, val_axes.len, indices }); + break :blk ax; + }; + // compute res shape + var res_shape = Shape.init(.{}, self.dtype()); + var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; + for (self_kind.constSlice(), 0..) |kind, ax_usize| { + const ax: u3 = @intCast(ax_usize); + if (ax == val_axes.get(0)) { + // The first val_ax is special cause this is the place where we insert indices axes. + for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| { + if (id_ax == index_coord_axis) continue; + if (std.mem.indexOfScalar(i64, indices_batch_axes.constSlice(), @intCast(id_ax))) |_| { + // batching dim are already in res + continue; + } + + res_shape = res_shape.appendDim(indices.dim(id_ax), t); + res_kind.appendAssumeCapacity(.indices); + } + } + switch (kind) { + .collapsed => continue, + else => { + res_shape = res_shape.appendDim(self.dim(ax), self._shape.tag(ax)); + res_kind.appendAssumeCapacity(kind); + }, + } + } + + // This is not a gather, but a dynamicSlice. + // Sometimes the backend recognize this pattern, but not always. + // So let us handle that. 
+ if (indices.count() == 1) { + return self.dynamicSlice1d(val_axes.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape); + } + + var slice_dims: Shape.DimsArray = .{}; + for (self_kind.constSlice(), self.dims()) |k, d| { + slice_dims.appendAssumeCapacity(switch (k) { + .batching, .collapsed => 1, + .offset => d, + .indices => unreachable, + }); + } + + // scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() }); const loc = self.getContext().mlirCtx().location(@src()); const gather_op = dialect.stablehlo.gather( self.getContext().mlirCtx(), self.value(), indices.value(), - slice_sizes.dims(), + slice_dims.constSlice(), loc, .{ - .offset_dims = offset_dims.dims(), - .collapsed_slice_dims = &.{a}, - .operand_batching_dims = &.{}, - .start_indices_batching_dims = &.{}, - .start_index_map = &.{a}, - .index_vector_dim = indices.rank(), + .offset_dims = _collectAxes(AxisKind, res_kind, .offset).constSlice(), + .collapsed_slice_dims = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(), + .operand_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(), + .start_indices_batching_dims = indices_batch_axes.constSlice(), + .start_index_map = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(), + .index_vector_dim = index_coord_axis, .indices_are_sorted = opts.indices_are_sorted, }, ); + + const mlir_shape = fromMlirValue(gather_op.result(0)).shape(); + meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. 
You should transpose one or the other.", .{ self, indices }); return _result(res_shape, gather_op.result(0)); } + test gatherValues { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + { + // Only test shapes + var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + inline for (.{ + .{ .{ .a = 10 }, .a, .{}, .{} }, + .{ .{ .a = 10 }, .a, .{ .n = 8 }, .{ .n = 8 } }, + .{ .{ .a = 10, .b = 20 }, .a, .{}, .{ .b = 20 } }, + .{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .b = 20 } }, + .{ .{ .a = 10, .b = 20 }, 0, .{ .n = 8 }, .{ .n = 8, .b = 20 } }, + // Favor val shape, instead of indices shape. + .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } }, + .{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } }, + // batching axes are implicits. + // TODO: batched gather don't compile https://github.com/zml/zml/issues/400 + .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } }, + .{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } }, + .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } }, + // stablehlo.gather is biased toward indices shape (like gatherSlice). + // This make it awkward to use when you have both batching dimension and new indices dimensions. + // For now we reject those, and let user explicitly transpose self or indices if needed. 
+ // .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } }, + // Also handle tuples + .{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .n = 8 } }, + .{ .{ 10, 20 }, .{ -2, -1 }, .{ 8, 2 }, .{8} }, + // and 1-tuple + .{ .{ .a = 10, .b = 20 }, .{.b}, .{ .n = 8, ._ = 1 }, .{ .a = 10, .n = 8 } }, + }) |testcase| { + const x_shape, const tag, const idx_shape, const res_shape = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const idx = Tensor.constant(idx_shape, .{ .i32 = 0 }); + const y = gatherValues(x, tag, idx, .{}); + try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape()); + try std.testing.expect(y.value().owner().verify()); + } + } + } + /// Gathers slices along the given axes with runtime indices. /// * slice_shape represents the shape of the slices to extract, /// it must be smaller than original shape. @@ -1938,9 +2079,9 @@ pub const Tensor = struct { /// * gatherSlices([A, B, C, D], .{.b=B', .c=C'}, [N, 2]) -> [N, A, B', C', D] /// * gatherSlices(x(a,b,c,d), .{.b=B', .c=C'}, g(n,m)) = z(n, a, b', c', d) = x(a, g(n, 0) + b', g(n, 1) + c', d) /// - /// Note: the axis order of the result is different from gather1d. + /// Note: the axis order of the result is different from gatherValues. /// This is because gatherSlices, favorizes contiguous copy of the extracted slices, - /// while gather1d, always copy values one by one, and as such don't have the same issues. + /// while gatherValues, always copy values one by one, and as such don't have the same issues. /// In our example the contiguous dimension .d is not sliced /// and gatherSlices can copy data by group of C'*D elements. pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor { @@ -2390,6 +2531,8 @@ pub const Tensor = struct { return _result(new_shape, op.result(0)); } + pub const DynSlice = struct { start: Tensor, len: i64 }; + /// Slices a Tensor across many axes, with runtime known offsets. 
/// /// Due to the nature of stablehlo, the length of the slices need to be known when compiling the IR. @@ -2407,7 +2550,6 @@ pub const Tensor = struct { pub fn dynamicSlice(self: Tensor, slices_: anytype) Tensor { // TODO: the untagged api is a bit verbose. Should I allow: `Tensor(.{ 20,30,40}).dynamicSlice(.{.{}, .{ .start = b_off, .len = 12 }, .{}});` ?? // - const DynSlice = struct { start: Tensor, len: i64 }; const slices, const slices_tags = Shape.parseStruct(DynSlice, slices_); // TODO use slices and slices_tags for the format. @@ -3328,3 +3470,13 @@ test shapesOf { try std.testing.expectEqual(fc2_bias_shape, shapes.fc2.bias); } } + +fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) { + var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; + for (bounded_array.constSlice(), 0..) |v, ax| { + if (v == value) { + res.appendAssumeCapacity(@intCast(ax)); + } + } + return res; +} diff --git a/zml/torch.zig b/zml/torch.zig index 6fd1742..c85cd23 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -186,7 +186,7 @@ pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor { const start = @mod(self.dim(a) - shifts[0], self.dim(a)); const idx = Tensor.arange(.{ .start = start, .end = start + self.dim(a) }, .f32); const divisor: f32 = @floatFromInt(self.dim(a)); - return self.gather1d(a, idx.fmod(divisor).convert(.i32), .{}); + return self.gatherValues(a, idx.fmod(divisor).convert(.i32), .{}); } test roll {