From d45a667ee5aed605cf6a796a529428d1e0f73a5f Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 26 Sep 2025 13:38:11 +0000 Subject: [PATCH] Revamp gather API with named indices (and add gather_ variant), improve topK handling, and add Yarn rope embedding support across core modules (buffer, nn, pjrtx, quantization, shape, tensor, testing, tokenizer, torch). --- zml/buffer.zig | 32 +- zml/nn.zig | 496 ++++++------------------ zml/pjrtx.zig | 59 ++- zml/quantization.zig | 38 +- zml/shape.zig | 8 + zml/tensor.zig | 367 +++++++++++------- zml/testing.zig | 2 +- zml/tokenizer/BUILD.bazel | 7 +- zml/tokenizer/sentencepiece/BUILD.bazel | 11 + zml/tokenizer/tokenizer.zig | 7 +- zml/torch.zig | 17 +- 11 files changed, 444 insertions(+), 600 deletions(-) diff --git a/zml/buffer.zig b/zml/buffer.zig index c9373e5..a4a9b19 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -54,7 +54,7 @@ pub const Buffer = struct { break :cs @divExact(host_buffer.dim(ax), n_partitions); } else 0; - const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); + const buffer_type = pjrt.bufferTypeFromDtype(host_buffer.shape().dtype()); const byte_strides = host_buffer.strides(); const devices = platform.getDevices(); @@ -256,7 +256,7 @@ pub const Buffer = struct { pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer { const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ .data = device_data, - .element_type = bufferTypeFromDtype(shape_.dtype()), + .element_type = pjrt.bufferTypeFromDtype(shape_.dtype()), .dims = shape_.dims(), // TODO: exposes sharding in the API. .device = platform.getDevices()[0], @@ -437,7 +437,7 @@ pub const Buffer = struct { var args = pjrt.Client.CreateUninitializedBufferArgs{ .dims = shard_shape.dims(), - .element_type = bufferTypeFromDtype(shape_.dtype()), + .element_type = pjrt.bufferTypeFromDtype(shape_.dtype()), .layout = .{ .tiled = .{ .minor_to_major = minorToMajor(shape_.rank()), @@ -487,32 +487,6 @@ pub const Buffer = struct { } }; -pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { - return switch (dt) { - inline else => |tag| @field(pjrt.BufferType, @tagName(tag)), - }; -} - -pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType { - return switch (pjrt_type) { - .invalid => @panic("Found an invalid pjrt buffer"), - inline else => |tag| @field(DataType, @tagName(tag)), - }; -} - -test bufferTypeFromDtype { - inline for (@typeInfo(DataType).@"enum".fields) |field| { - const dt: DataType = @enumFromInt(field.value); - try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt))); - } - - inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| { - const dt: pjrt.BufferType = @enumFromInt(field.value); - if (dt == .invalid) continue; - try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt))); - } -} - const _MINOR_TO_MAJOR = blk: { var ret: [Shape.MAX_RANK]i64 = undefined; for (0..Shape.MAX_RANK) |i| { diff --git a/zml/nn.zig b/zml/nn.zig index 631bce1..4fbb17c 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -32,7 +32,7 @@ pub const Linear = struct { } // log.debug("Linear({*}): {d} -> {d} -> {d}", .{ self, x.dims(), y.dims(), if (self.bias) |bias| y.add(bias).dims() else y.dims() }); - return if (self.bias) |bias| y.add(bias.broadcastLeft(y.shape())) else y; + return if (self.bias) |bias| y.add(bias.broadcast(y.shape(), &.{y.axis(-1)})) else y; } }; @@ -42,7 +42,7 @@ pub const TokenEmbedding = struct { pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor { stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {f}", .{idx}); stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {f}", .{self.weight}); - return self.weight.gatherValues(0, idx, .{}); + return self.weight.withTags(.{ .voc, .d }).gather(.{ .voc = idx }, .{}); } }; @@ -125,18 +125,18 @@ pub fn normalizeVariance(x: Tensor, eps: f32) Tensor { // Upcast to improve precision const xf32 = x.convert(.f32); const mean = xf32.sum(-1).scale(1.0 / N); - const mean_dev = xf32.sub(mean.broadcastRight(xf32.shape())); - const variance = mean_dev.mul(mean_dev).sum(-1).scale(1.0 / N); + const mean_dev = xf32.sub(mean); + const variance = mean_dev.mul(mean_dev).sum(-1).divByConst(N); const rsqrt = Tensor.rsqrt(variance.addConstant(eps)); - return mean_dev.mul(rsqrt.broadcastRight(mean_dev.shape())).convert(x.dtype()); + return mean_dev.mul(rsqrt).convert(x.dtype()); } // ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html // Implementation equivalent to `nn.functional.normalize(tensor, dim=-1)` call pub fn normalizeL2(input: Tensor, eps: f32) Tensor { - const inv_norm = input.pow(Tensor.scalar(2, input.dtype())).sum(-1).addConstant(eps).rsqrt(); - return input.mul(inv_norm.broad(input.shape())); + const inv_norm = input.powByConst(2).sum(-1).addConstant(eps).rsqrt(); + return input.mul(inv_norm); } test normalizeL2 { @@ -165,12 +165,22 @@ pub const RopeOpts = struct { default: void, custom: []const f32, llama3: Llama3, + yarn: Yarn, pub const Llama3 = struct { factor: f32, high_freq_factor: f32, low_freq_factor: f32, original_max_position_embeddings: u32, + truncate: bool = true, + }; + + pub const Yarn = struct { + beta_fast: f32 = 32.0, + beta_slow: f32 = 1.0, + factor: f32, + truncate: bool = true, + original_max_position_embeddings: u32, }; /// Read a Rope scaling config from HF config.json format. @@ -185,12 +195,21 @@ pub const RopeOpts = struct { if (impl != .string) return error.InvalidEnumTag; if (std.mem.eql(u8, impl.string, "llama3")) { // Note: leaky is fine here cause Llama3 struct don't need to allocate memory. - return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, undefined, content, .{ .ignore_unknown_fields = true }) }; + return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, stdx.noalloc, content, .{ .ignore_unknown_fields = true }) }; + } else if (std.mem.eql(u8, impl.string, "yarn")) { + return .{ .yarn = try std.json.parseFromValueLeaky(Yarn, stdx.noalloc, content, .{ .ignore_unknown_fields = true }) }; } else { log.warn("Unsupported Rope implementation: {s}, will use the default one which will produce altered results", .{impl.string}); return .{ .default = {} }; } } + + pub fn attentionScaling(scaling: Scaling) f32 { + return switch (scaling) { + .yarn => |yarn| 0.1 * @log(yarn.factor) + 1.0, + else => 1.0, + }; + } }; }; @@ -218,8 +237,9 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor { // compute sin and cos in f32 before downcasting to x type. const inv_freq = invFreq(x.dim(.hd), opts).withTags(.{.hd}); const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq); - const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape()); - const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape()); + const scaling = opts.scaling.attentionScaling(); + const cos = inv_freq_pos.cos().scale(scaling).convert(x.dtype()).broad(x_real.shape()); + const sin = inv_freq_pos.sin().scale(scaling).convert(x.dtype()).broad(x_real.shape()); // apply rotation const y_real = x_real.mul(cos).sub(x_imag.mul(sin)); @@ -250,7 +270,7 @@ pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, layout: RopeOpts.Layout) Ten .interleaved => Tensor.concatenate(&.{ x_real.appendAxes(.{.interleaved_real_img}), x_imag.appendAxes(.{.interleaved_real_img}), - }, -1).flatten(-2), + }, -1).reshape(x_imag.shape().setDim(-1, -1)), }; } @@ -259,6 +279,8 @@ pub fn invFreq(N: i64, opts: RopeOpts) Tensor { const allocator = zml.module.CompilationContext.current().allocator(); const N_half: usize = @intCast(@divExact(N, 2)); const inv_freq = allocator.alloc(f32, N_half) catch @panic("OOM"); + defer allocator.free(inv_freq); + _invFreq(opts, inv_freq); return zml.Tensor.constantTensor(.fromSlice(.{@divExact(N, 2)}, inv_freq)); } @@ -299,10 +321,38 @@ fn _invFreq(opts: RopeOpts, inv_freq: []f32) void { } } }, + .yarn => |s| { + const N_f: f64 = @floatFromInt(inv_freq.len); + const M: f64 = @floatFromInt(s.original_max_position_embeddings); + const f_high = s.beta_fast * (2 * std.math.pi) / M; + const f_low = s.beta_slow * (2 * std.math.pi) / M; + const downscaling = 1.0 / s.factor; + + // This isn't a typo: low n have a high frequency, high n have a low frequency. + var n_low: f64 = -@log(f_high) / @log(opts.freq_base) * N_f; + var n_high: f64 = -@log(f_low) / @log(opts.freq_base) * N_f; + if (s.truncate) { + n_high = std.math.ceil(n_high); + n_low = std.math.floor(n_low); + } + std.debug.assert(n_high > n_low); + for (0..N, inv_freq) |n, f| { + if (f > f_high) { + // High freq match default implem + } else if (f < f_low) { + // Downscaling for low freq + inv_freq[n] *= downscaling; + } else { + // Yarn use lerp too but not in the frequency space, in the time space. + const lerp: f64 = (n_high - @as(f64, @floatFromInt(n))) / (n_high - n_low); + inv_freq[n] *= @floatCast(lerp + (1 - lerp) * downscaling); + } + } + }, } } -test invFreq { +test "invFreq Llama3" { // Llama 3.2-1B config const llama_conf: RopeOpts = .{ .freq_base = 500_000, @@ -323,6 +373,28 @@ test invFreq { } } +test "invFreq Yarn" { + const yarn_conf: RopeOpts = .{ + .freq_base = 150_000, + .scaling = .{ .yarn = .{ + .factor = 32.0, + .beta_fast = 32.0, + .beta_slow = 1.0, + .original_max_position_embeddings = 4096, + .truncate = true, + } }, + }; + const yarn_freq = [_]f32{ 1.000000000000e+00, 6.890442967415e-01, 4.747820496559e-01, 3.271458745003e-01, 2.254180014133e-01, 1.553229838610e-01, 1.070244237781e-01, 7.374456524849e-02, 5.081327259541e-02, 3.162075206637e-02, 1.945096626878e-02, 1.179219130427e-02, 7.015713956207e-03, 4.069554619491e-03, 2.277272054926e-03, 1.206130953506e-03, 5.809474969283e-04, 2.279478358105e-04, 3.830881178146e-05, 2.639646845637e-05, 1.818833698053e-05, 1.253256959899e-05, 8.635495760245e-06, 5.950239483354e-06, 4.099978468730e-06, 2.825066758305e-06, 1.946596285052e-06, 1.341290953860e-06, 9.242089618056e-07, 6.368209142238e-07, 4.387978549403e-07, 3.023511396805e-07 }; + + var inv_freq: @TypeOf(yarn_freq) = undefined; + _invFreq(yarn_conf, &inv_freq); + for (yarn_freq, inv_freq, 0..) |expected, actual, i| { + errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, stdx.fmt.slice(&yarn_freq), stdx.fmt.slice(&inv_freq) }); + try std.testing.expectApproxEqRel(expected, actual, 1e-5); + } + try std.testing.expectApproxEqRel(1.3465735902799727, yarn_conf.scaling.attentionScaling(), 1e-5); +} + test "real/img" { const platform = zml.testing.env(); @@ -332,8 +404,8 @@ test "real/img" { const real, const imag = splitRealImg(x, layout); const y = mergeRealImg(real, imag, layout); const real2, const imag2 = splitRealImg(y, layout); - return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add( - imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1), + return real.cmp(.EQ, real2).flatten().convert(.i32).sum(-1).add( + imag.cmp(.EQ, imag2).flatten().convert(.i32).sum(-1), ); } @@ -349,8 +421,8 @@ test "real/img" { Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }), }, 1); - return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add( - imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1), + return real.cmp(.EQ, x_real).flatten().convert(.i32).sum(-1).add( + imag.cmp(.EQ, x_imag).flatten().convert(.i32).sum(-1), ); } @@ -366,8 +438,8 @@ test "real/img" { Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }), }, 1); - return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add( - imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1), + return real.cmp(.EQ, x_real).flatten().convert(.i32).sum(-1).add( + imag.cmp(.EQ, x_imag).flatten().convert(.i32).sum(-1), ); } @@ -377,8 +449,8 @@ test "real/img" { const x_real = Tensor.arange(.{ .start = 0, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 }); const x_imag = Tensor.arange(.{ .start = 1, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 }); - return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add( - imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1), + return real.cmp(.EQ, x_real).flatten().convert(.i32).sum(-1).add( + imag.cmp(.EQ, x_imag).flatten().convert(.i32).sum(-1), ); } }; @@ -484,21 +556,21 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor { out_shape._dims.set(i + 2, @intFromFloat(@floor(@as(f64, @floatFromInt(out_shape.dim(i + 2))) * sf))); } // TODO(james): remove this implicit two batching dims - var sd: [3]usize = undefined; + var sd: [3]u3 = undefined; var len_sd: usize = 0; for (2..input.rank()) |i| { if (input.dim(i) != out_shape.dim(i)) { - sd[len_sd] = i; + sd[len_sd] = @intCast(i); len_sd += 1; } } - const spatial_dims = sd[0..len_sd]; + const spatial_axes = sd[0..len_sd]; var res = input; - for (spatial_dims) |d| { - const n = out_shape.dim(d); - const ratio = stdx.math.divFloat(f32, input.dim(d), n); + for (spatial_axes) |ax| { + const n = out_shape.dim(ax); + const ratio = stdx.math.divFloat(f32, input.dim(ax), n); const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32); - res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true }); + res = res.gather_(&.{ax}, &.{offsets}, .{ .indices_are_sorted = true }); } return res; } @@ -670,10 +742,11 @@ test resizeBilinear { } pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor { - const res_shape = image.shape().set(axis, new_len); + const ax = image.axis(axis); + const res_shape = image.shape().set(ax, new_len); const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); - const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); + const og_len = opt.original_len orelse Tensor.scalar(image.dim(ax), dtype); const ratio = og_len.convert(dtype).scale(stdx.math.divFloat(f32, 1, new_len)); const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const left = scaled.floor(); @@ -682,11 +755,11 @@ pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Te // TODO: check that two gather isn't too bad perf wise. // Normally we should use gatherSlices to collect the values 2 by 2, // but gatherSlices messes up with the order of axes. - const left_val = image.gatherValues(axis, left.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype); - const right_val = image.gatherValues(axis, right.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype); + const left_val = image.gather_(&.{ax}, &.{left.convert(.i32)}, .{ .indices_are_sorted = true }).convert(dtype); + const right_val = image.gather_(&.{ax}, &.{right.convert(.i32)}, .{ .indices_are_sorted = true }).convert(dtype); - const left_weight = right.sub(scaled).broadcast(res_shape, &.{axis}); - const right_weight = scaled.sub(left).broadcast(res_shape, &.{axis}); + const left_weight = right.sub(scaled).broadcast(res_shape, &.{ax}); + const right_weight = scaled.sub(left).broadcast(res_shape, &.{ax}); const res = left_val.mul(left_weight).add(right_val.mul(right_weight)); return res.convert(image.dtype()).withTags(image.shape().tags()); @@ -853,13 +926,10 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { return cuda.sdpa(q, k, v, opts); } - if (q.dim(.h) != k.dim(.h)) { - stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args); - // Note: we don't try to repeat queries. - // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage. - const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h))); - k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) }; - } + // Handle different numbers of head by splitting q heads. + // This is a bit error prone in the sense that it depends of the layout of q heads. + // This is the Llama convention though. + q = q.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto }); const attn_mask = if (opts.attn_mask) |m| m else null; const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch { @@ -875,339 +945,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype()); var attn = attn_weights.dot(v, .{.k}); - return attn.transpose(q.shape()); -} - -pub const SdpaChunks = struct { q_chunk_size: u32, k_chunk_size: u32 }; - -pub fn sdpaMemEfficient( - q: Tensor, - k: Tensor, - v: Tensor, - sdpa_opts: SdpaOpts, - chunking: SdpaChunks, -) Tensor { - const sdpa_mem_efficient: SdpaMemEfficient = .{ - .q = q, - .k = k, - .v = v, - .sdpa_opts = sdpa_opts, - .chunking = .{ - .q_chunk_size = @intCast(@min(q.dim(.q), chunking.q_chunk_size)), - .k_chunk_size = @intCast(@min(k.dim(.k), chunking.k_chunk_size)), - }, - }; - - return sdpa_mem_efficient.forward(); -} - -const SdpaMemEfficient = struct { - q: Tensor, - k: Tensor, - v: Tensor, - sdpa_opts: SdpaOpts, - chunking: SdpaChunks, - - fn forward(self: SdpaMemEfficient) Tensor { - stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.q, self.chunking }); - stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.k, self.chunking }); - const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size)); - - const ctx = zml.module.CompilationContext.current(); - const q_chunks = ctx.allocator().alloc(zml.Tensor, n_q_chunks) catch unreachable; - defer ctx.allocator().free(q_chunks); - for (0..n_q_chunks) |i| { - const idx: u32 = @intCast(i); - const q_slice: zml.Tensor.DynSlice = .{ - .start = Tensor.scalar(idx * self.chunking.q_chunk_size, .i32), - .len = self.chunking.q_chunk_size, - }; - const q_chunk = self.q.dynamicSlice(.{ .q = q_slice }); - const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null; - - var chunk: SdpaMemEfficient = self; - chunk.q = q_chunk; - chunk.sdpa_opts.attn_mask = attn_chunk; - q_chunks[i] = chunk.scanKeyVal(); - } - - const res = zml.Tensor.concatenate(q_chunks, .q); - return res.transpose(self.q.shape()); - } - - fn nextQueriesChunk(self: SdpaMemEfficient, idx: Tensor) Tensor { - const q_slice: zml.Tensor.DynSlice = .{ - .start = idx.scale(self.chunking.q_chunk_size), - .len = self.chunking.q_chunk_size, - }; - const q_chunk = self.q.dynamicSlice(.{ .q = q_slice }); - const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null; - - var chunk: SdpaMemEfficient = self; - chunk.q = q_chunk; - chunk.sdpa_opts.attn_mask = attn_chunk; - return chunk.scanKeyVal(); - } - - fn scanKeyVal(self: SdpaMemEfficient) Tensor { - const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size); - return if (n_chunks <= 4) { - // Unrolled version - var partial_softmax: ?PartialSoftmax = null; - for (0..@intCast(n_chunks)) |idx| { - const next = self.nextKeyValChunk(Tensor.scalar(idx, .i32)); - partial_softmax = if (partial_softmax) |prev| prev.merge(next) else next; - } - return partial_softmax.?.finalize(); - } else { - // stablehlo.while version - const partial_softmax, _ = zml.ops.while_(hasNextKeyValChunk, nextKeyValChunkMerge, self, .{ PartialSoftmax.zeros(self.q.shape(), .f32), Tensor.scalar(0, .i32) }); - return partial_softmax.finalize(); - }; - } - - fn nextKeyValChunkMerge(self: SdpaMemEfficient, prev: PartialSoftmax, idx: Tensor) struct { PartialSoftmax, Tensor } { - const next = self.nextKeyValChunk(idx); - return .{ prev.merge(next), idx.addConstant(1) }; - } - - fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialSoftmax { - const k_slice: zml.Tensor.DynSlice = .{ - .start = idx.scale(self.chunking.k_chunk_size), - .len = self.chunking.k_chunk_size, - }; - - const k_chunk = self.k.dynamicSlice(.{ .k = k_slice }); - const v_chunk = self.v.dynamicSlice(.{ .k = k_slice }); - const attn_chunk = if (self.sdpa_opts.attn_mask) |mask| mask.dynamicSlice(.{ .k = k_slice }) else null; - - return sdpaChunk(self.q, k_chunk, v_chunk, .{ .attn_mask = attn_chunk }); - } - - pub fn hasNextKeyValChunk(self: SdpaMemEfficient, _: PartialSoftmax, idx: Tensor) zml.Tensor { - const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size); - return idx.cmp(.LT, Tensor.scalar(n_chunks, idx.dtype())); - } -}; - -pub const PartialSoftmax = struct { - values: Tensor, - exp_sum: Tensor, - max_value: Tensor, - - pub fn zeros(q_shape: Shape, exp_sum_precision: DataType) PartialSoftmax { - return .{ - .values = Tensor.constant(q_shape, q_shape.dtype().zero()), - .exp_sum = Tensor.constant(q_shape.setDim(.hd, 1), exp_sum_precision.zero()), - .max_value = Tensor.constant(q_shape.setDim(.hd, 1), q_shape.dtype().minValue()), - }; - } - - pub fn merge(self: PartialSoftmax, other: PartialSoftmax) PartialSoftmax { - // Rescale self and other using the new global_max. - const global_max = self.max_value.maximum(other.max_value); - const new_self = self.rescale(global_max); - const new_other = other.rescale(global_max); - - // Now that self and other are using the same scale, we can just add them: - return .{ - .max_value = global_max, - .values = new_self.values.add(new_other.values), - .exp_sum = new_self.exp_sum.add(new_other.exp_sum), - }; - } - - /// Update max_value and rescale attn and exp_sum accordingly. - pub fn rescale(self: PartialSoftmax, max_value: Tensor) PartialSoftmax { - const max_diff_exp = self.max_value.sub(max_value).exp(); - const sum_dtype = self.exp_sum.dtype(); - return .{ - .max_value = max_value, - .values = self.values.mul(max_diff_exp.broad(self.values.shape())), - .exp_sum = self.exp_sum.mul(max_diff_exp.convert(sum_dtype)), - }; - } - - /// Divides the intermediary results by the exp_sum to get the proper attention values. - pub fn finalize(self: PartialSoftmax) Tensor { - return self.values.div(self.exp_sum.broad(self.values.shape()).convert(self.values.dtype())); - } -}; - -/// Compute softmax over a chunk. -/// Returns intermediary results to allow aggregating later. -pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax { - const a = self.axis(axis); - const max_val = self.max(a).maximum(Tensor.scalar(-1e16, self.dtype())); - const out = self.sub(max_val.broad(self.shape())).exp(); - return .{ - .values = out, - .exp_sum = out.convert(.f32).sum(a), - .max_value = max_val, - }; -} - -/// Compute sdpa on a chunk, and computes a partial softmax. -/// q: (B, H, Sq, H_dim) ⊙ k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk) -pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoftmax { - // this is a dupe of sdpa, but return the PartialSoftmax instead of true Attn. - // Consider implementing sdpa from sdpaChunk. - var q, var k, var v = .{ q_, k_, v_ }; - - const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! "; - const err_args = .{ q, k, v, opts.attn_mask }; - stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); - stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); - stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); - - if (q.dim(.h) != k.dim(.h)) { - stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args); - // Note: we don't try to repeat queries. - // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage. - const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h))); - k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) }; - } - const attn_mask = if (opts.attn_mask) |m| m else null; - - const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch { - stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args); - }; - const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd))); - const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); - k = k.mul(head_scaling.convert(k.dtype())); - - var attn_weights = q.dot(k, .{.hd}); - // log.debug("attn_weights : {}", .{attn_weights}); - // log.debug("attn_mask : {?}", .{attn_mask}); - if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); - - const partial = partialSoftmax(attn_weights, .k); - const attn = partial.values.dot(v, .{.k}).transpose(q.shape()); - - return .{ - .values = attn, - // The renaming is because the above dot projected values.k into .hd, - // do the same thing on the other tensors. - // This work because dot is a linear operation, and commutes with `PartialSoftmax.finalize` - .exp_sum = partial.exp_sum.rename(.{ .k = .hd }).transpose(attn.shape()), - .max_value = partial.max_value.rename(.{ .k = .hd }).transpose(attn.shape()), - }; -} - -test sdpaMemEfficient { - const platform = zml.testing.env(); - const allocator = std.testing.allocator; - - // Note we use small input vectors to have the tests run reasonably fast, - // but don't expect speed ups with this small sizes. - const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); - defer rng.deinit(); - - const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); - defer rng_mask.deinit(); - - // Note: we pass void here, cause Rng.normal doesn't take any runtime inputs. - const q = rng.call({}).withTags(.{ .b, .h, .q, .hd }); - const k = rng.call({}).withTags(.{ .b, .h, .k, .hd }); - const v = rng.call({}).withTags(.{ .b, .h, .k, .hd }); - const mask = rng_mask.call({}).withTags(.{ .q, .k }); - - const ref_res = try zml.testing.compileAndCall( - platform, - sdpa, - .{ q, k, v, .{ .attn_mask = mask, .scale = null } }, - ); - try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); - { - // 4 k_chunks - const res = try zml.testing.compileAndCall( - platform, - sdpaMemEfficient, - .{ - q, - k, - v, - .{ .attn_mask = mask, .scale = null }, - .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) }, - }, - ); - - try zml.testing.expectClose(ref_res, res, 2e-3); - } - { - // 16 k_chunks - const res = try zml.testing.compileAndCall( - platform, - sdpaMemEfficient, - .{ - q, - k, - v, - .{ .attn_mask = mask, .scale = null }, - .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) }, - }, - ); - - try zml.testing.expectClose(ref_res, res, 2e-3); - } -} - -test "sdpaMemEfficient transposed" { - const platform = zml.testing.env(); - const allocator = std.testing.allocator; - - // Note we use small input vectors to have the tests run reasonably fast, - // but don't expect speed ups with this small sizes. - const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 512, 10, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); - defer rng.deinit(); - - const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); - defer rng_mask.deinit(); - - // Note: we pass void here, cause Rng.normal doesn't take any runtime inputs. - const q = rng.call({}).withTags(.{ .b, .q, .h, .hd }); - const k = rng.call({}).withTags(.{ .b, .k, .h, .hd }); - const v = rng.call({}).withTags(.{ .b, .k, .h, .hd }); - const mask = rng_mask.call({}).withTags(.{ .q, .k }); - - const ref_res = try zml.testing.compileAndCall( - platform, - sdpa, - .{ q, k, v, .{ .attn_mask = mask, .scale = null } }, - ); - try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); - - { - const res = try zml.testing.compileAndCall( - platform, - sdpaMemEfficient, - .{ - q, - k, - v, - .{ .attn_mask = mask, .scale = null }, - .{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) }, - }, - ); - - try zml.testing.expectClose(ref_res, res, 1e-3); - } - - { - const res = try zml.testing.compileAndCall( - platform, - sdpaMemEfficient, - .{ - q, - k, - v, - .{ .attn_mask = mask, .scale = null }, - .{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) }, - }, - ); - - try zml.testing.expectClose(ref_res, res, 1e-3); - } + return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } }); } /// Options controlling generation. The default values correspond to greedy decoding. @@ -1225,9 +963,9 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng return .{ next_tokens, rng }; } - const topk = activations.topK(opts.topk, .voc, .{}); - // After the topk, we don't have .voc values, anymore, only topk. - var x = topk.values.rename(.{ .voc = .topk }); + const topk = activations.topK(.{ .topk = .voc }, opts.topk, .{}); + // After the topk, we don't have .voc values, anymore, only .topk. + var x = topk.values; if (opts.temperature != 1.0) { x = x.scale(1 / opts.temperature); } @@ -1242,7 +980,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng // topk_idx is indices into topk.values ! so in the range [0, topk] // Convert for the original indices from the full [0, voc] range. - const next_tokens = topk.indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{}); + const next_tokens = topk.indices.gather(.{ .topk = topk_idx.squeeze(.topk) }, .{}); // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens }); return .{ next_tokens, next_rng }; } @@ -1330,7 +1068,7 @@ pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: T x = x.add(gumbel_noise); const topk_idx = x.argMax(.topk).indices; - const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{}); + const next_tokens = topk_indices.gather(.{ .voc = topk_idx.squeeze(.topk) }, .{}); return .{ next_tokens, next_rng }; } @@ -1339,7 +1077,7 @@ fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor { // First reduce the vocab size to a reasonable sub set of candidate. const full_topk = if (opts.max_top_k > 0) - logits.topK(opts.max_top_k, .voc, .{ .descending = true }) + logits.topK(.{ .voc = .voc }, opts.max_top_k, .{ .descending = true }) else logits.sort(.voc, .{ .descending = true }); diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 46c4052..f09ffa1 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -5,25 +5,30 @@ const dialects = @import("mlir/dialects"); const mlir = @import("mlir"); const pjrt = @import("pjrt"); pub const ffi = pjrt.ffi; -pub const ApiError = pjrt.ApiError; -pub const ErrorCode = pjrt.ErrorCode; -pub const ExecuteContext = pjrt.ExecuteContext; -pub const BufferType = pjrt.BufferType; -pub const Device = pjrt.Device; -pub const MemoryStats = pjrt.MemoryStats; -pub const DeviceDescription = pjrt.DeviceDescription; pub const Api = pjrt.Api; -pub const NamedValue = pjrt.NamedValue; +pub const ApiError = pjrt.ApiError; +pub const BufferType = pjrt.BufferType; pub const ClientInitError = pjrt.ClientInitError; -pub const Error = pjrt.Error; -pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; -pub const SerializeResult = pjrt.SerializeResult; -pub const Executable = pjrt.Executable; pub const CompiledMemoryStats = pjrt.CompiledMemoryStats; +pub const Device = pjrt.Device; +pub const DeviceDescription = pjrt.DeviceDescription; +pub const Error = pjrt.Error; +pub const ErrorCode = pjrt.ErrorCode; +pub const Executable = pjrt.Executable; +pub const ExecuteContext = pjrt.ExecuteContext; pub const ExecuteError = ApiError; +pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; pub const Memory = pjrt.Memory; +pub const MemoryStats = pjrt.MemoryStats; +pub const NamedValue = pjrt.NamedValue; +pub const Profiler = pjrt.Profiler; +pub const SerializeResult = pjrt.SerializeResult; pub const Stream = pjrt.Stream; +const zml = struct { + pub const DataType = @import("dtype.zig").DataType; +}; + const log = std.log.scoped(.zml); pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError; @@ -165,8 +170,8 @@ pub const Buffer = opaque { return @ptrCast(try self.inner().toHostBuffer(api, dst)); } - pub fn getElementType(self: *const Buffer, api: *const Api) BufferType { - return self.inner().getElementType(api); + pub fn getElementType(self: *const Buffer, api: *const Api) zml.DataType { + return dtypeFromBufferType(self.inner().getElementType(api)); } pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 { @@ -329,3 +334,29 @@ pub const AsyncHostToDeviceTransferManager = opaque { return self.inner().addMetadata(api, transfer_metadata); } }; + +pub fn bufferTypeFromDtype(dt: zml.DataType) pjrt.BufferType { + return switch (dt) { + inline else => |tag| @field(pjrt.BufferType, @tagName(tag)), + }; +} + +pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) zml.DataType { + return switch (pjrt_type) { + .invalid => @panic("Found an invalid pjrt buffer"), + inline else => |tag| @field(zml.DataType, @tagName(tag)), + }; +} + +test bufferTypeFromDtype { + inline for (@typeInfo(zml.DataType).@"enum".fields) |field| { + const dt: zml.DataType = @enumFromInt(field.value); + try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt))); + } + + inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| { + const dt: pjrt.BufferType = @enumFromInt(field.value); + if (dt == .invalid) continue; + try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt))); + } +} diff --git a/zml/quantization.zig b/zml/quantization.zig index 2a850a4..90a35bd 100644 --- a/zml/quantization.zig +++ b/zml/quantization.zig @@ -1,9 +1,11 @@ const std = @import("std"); -const zml = @import("zml.zig"); - const Allocator = std.mem.Allocator; + +const zml = @import("zml.zig"); const module = zml.module; +// TODO add tests, use modern zml + pub fn Q4_0(comptime dtype: zml.DataType) type { return struct { const Self = @This(); @@ -38,38 +40,12 @@ pub fn Q4_0(comptime dtype: zml.DataType) type { const scales = extractScales(block_count, input); const weights = extractWeights(block_count, input); - return scales.reshape(.{ block_count, 1 }) - .broadcastLeft(zml.Shape.init(.{ block_count, 32 }, .f32)) + return scales.broadcast(weights.shape(), &.{0}) .mul(weights) .convert(dtype) - .reshape(.{block_count * 32}) .reshape(shape); } - pub fn scaleIndices(block_count: u63) zml.Tensor { - // indices1 is the offsets of the scale bytes, repeated block_count times. - const indices1 = zml.Tensor.arange(.{ .start = 0, .end = 2 }, .i32).repeat1d(0, block_count); - - // indices2 is the offsets of the blocks, repeated for each scale byte, repeated block_count times. - const indices2 = zml.Tensor.arange(.{ .start = 0, .end = block_stride * block_count, .step = block_stride }, .i32) - .reshape(.{ block_count, 1 }).broadcastLeft(zml.Shape.init(.{ block_count, 2 }, .i32)).reshape(.{2 * block_count}); - - // indices is the sum of the two, which is the offsets to all the bytes we are interested in. - return indices1.add(indices2); - } - - pub fn weightIndices(block_count: u63) zml.Tensor { - // indices1 is the offsets of the data bytes, repeated block_count times. - const indices1 = zml.Tensor.arange(.{ .start = 2, .end = 18 }, .i32).repeat1d(0, block_count); - - // indices2 is the offsets of the blocks, repeated for each data byte, repeated block_count times. - const indices2 = zml.Tensor.arange(.{ .start = 0, .end = block_stride * block_count, .step = block_stride }, .i32) - .reshape(.{ block_count, 1 }).broadcastLeft(zml.Shape.init(.{ block_count, 16 }, .i32)).reshape(.{16 * block_count}); - - // indices is the sum of the two, which is the offsets to all the bytes we are interested in. - return indices1.add(indices2); - } - pub fn extractScales(block_count: u63, input: zml.Tensor) zml.Tensor { // The goal here is to get the first two bytes of every 18-bytes block in the input. For that, // we generate a list of indices that we will use to gather from the input. @@ -85,7 +61,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type { const indices = indices1.add(indices2); // We select the values we are interested in with the indices, group them by pair and bitcast them to f16, then convert them to f32. - const scales = input.gatherValues(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32); + const scales = input.gather_(&.{0}, &.{indices}, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32); return scales; } @@ -107,7 +83,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type { // NOTE(Corendos): i4 is not supported by bitcast convert, so we need the following workaround. // We select the values we are interested in with the indices, these are our quantized_weights. - const quantized_weights = input.gatherValues(0, indices, .{ .indices_are_sorted = true }); + const quantized_weights = input.gather_(&.{0}, &.{indices}, .{ .indices_are_sorted = true }); const lb_weights = quantized_weights .logical(.And, zml.Tensor.constant(.{16 * block_count}, zml.Data.init(.u8, 0xf))) .bitCast(.i8); diff --git a/zml/shape.zig b/zml/shape.zig index a9daffa..564f33b 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -413,6 +413,13 @@ pub const Shape = struct { // Already the right shape if (std.mem.eql(i64, self.dims(), other.dims())) return true; + if (std.mem.eql(Tag, self.tags(), other.tags())) { + for (0..self.rank()) |i| { + if (self.dim(i) != 1 and self.dim(i) != other.dim(i)) return false; + } + return true; + } + // Non ambiguous broadcasting // TODO: broad is error prone because of this: // it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 } @@ -424,6 +431,7 @@ pub const Shape = struct { } for (self.dims(), self.tags()) |d, t| { + // TODO this is also wrong when the axes are in different order in the two shapes. const other_ax = other.hasTag(t) orelse return false; if (d != 1 and d != other.dim(other_ax)) return false; } diff --git a/zml/tensor.zig b/zml/tensor.zig index 373d5ef..78219cd 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -311,7 +311,7 @@ pub const Tensor = struct { /// Returns the given tensor as one contiguous buffer of bytes. pub fn bytes(self: Tensor) Tensor { - return self.bitCast(.u8).flattenAll().withTags(.{.bytes}); + return self.bitCast(.u8).flatten().withTags(.{.bytes}); } /// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor. @@ -700,7 +700,7 @@ pub const Tensor = struct { pub fn gumbelStats(rand: Rng, target_dist: Tensor) struct { Rng, Stats } { const s = Shape.init(.{ .n = 1024, .d = 4 }, .f32); const rng, const data = rand.gumbel(s); - const flat = data.flattenAll(); + const flat = data.flatten(); const mean_ = flat.mean(0); const variance = flat.sub(mean_.broad(flat.shape())).pow(Tensor.scalar(2, .f32)).mean(0); @@ -719,7 +719,7 @@ pub const Tensor = struct { break :blk powers; }; const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d}); - const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16); + const counts = values.gather(.{ .d = samples }, .{}).sum(.n).bitCast(.u16); const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n)); return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } }; } @@ -993,17 +993,22 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise addition of the input Tensor with a constant. pub fn addConstant(self: Tensor, b: anytype) Tensor { - return self.add(Tensor.scalar(b, self.dtype())); + return self.add(.scalar(b, self.dtype())); } /// Returns a Tensor containing the element-wise division of the input Tensor by a constant. pub fn divByConst(self: Tensor, b: anytype) Tensor { - return self.div(Tensor.scalar(b, self.dtype())); + return self.div(.scalar(b, self.dtype())); + } + + /// Returns a Tensor containing the element-wise power of the input Tensor by a constant. + pub fn powByConst(self: Tensor, b: anytype) Tensor { + return self.pow(.scalar(b, self.dtype())); } /// Returns a Tensor containing the element-wise multiplication of the input Tensor by a constant. pub inline fn scale(self: Tensor, val: anytype) Tensor { - return self.mul(Tensor.scalar(val, self.dtype())); + return self.mul(.scalar(val, self.dtype())); } pub const LogicalOp = enum { OR, XOR, AND }; @@ -1325,7 +1330,7 @@ pub const Tensor = struct { const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype())); const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp(); - const res = exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape)); + const res = exp_diff_max.div(exp_diff_max.sum(a)); // If a row is full -inf return full 0 instead of full nan, // this fix attention when mask hides a full row. @@ -1519,27 +1524,7 @@ pub const Tensor = struct { return self.transpose(perm.constSlice()); } - /// Flattens the given axis and the next one, into one new axis. - pub fn flatten(self: Tensor, axis_: anytype) Tensor { - // TODO: move to torch.zig, this is equivalent to merge - const old_shape = self._shape; - const a = self.axis(axis_); - // stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis }); - const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1)); - - const loc = self.getContext().location(@src(), "flatten({f},{})", .{ self, axis_ }); - const reshaped_val = dialect.stablehlo.reshape( - self.getContext().mlirCtx(), - self.value(), - mlirx.tensorType(self.getContext().mlirCtx(), new_shape), - loc, - ); - // log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] }); - return _result(new_shape, reshaped_val.result(0)); - } - - pub inline fn flattenAll(self: Tensor) Tensor { - // TODO: rename to just flatten, once flatten is moved to torch + pub inline fn flatten(self: Tensor) Tensor { return self.reshape(.{self.count()}); } @@ -1727,14 +1712,12 @@ pub const Tensor = struct { } const a = self.axis(axis_); - const broadshape = self._shape.insert(a + 1, .{n_rep}); - const repeat_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1); + const res_shape = self._shape.setDim(a, self.dim(a) * n_rep); - var res = self.broadcast(broadshape, repeat_dims.dims()).flatten(a); - // Restor the tag that has been lost by flatten. - res._shape._tags.set(a, self._shape.tag(a)); + const broadshape = self._shape.insert(a, .{n_rep}); + const repeat_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a); - return res; + return self.broadcast(broadshape, repeat_dims.dims()).reshape(res_shape); } /// Repeats a Tensor several times along the given axes. @@ -1751,15 +1734,47 @@ pub const Tensor = struct { return res; } + test repeat1d { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const Local = struct { + fn repeat1d(x: Tensor, axis_: u3, n_reps: u32) Tensor { + return x.repeat1d(axis_, n_reps); + } + }; + + { + const inputs: [3]u8 = .{ 1, 2, 3 }; + const expectations: [6]u8 = .{ 1, 2, 3, 1, 2, 3 }; + + const input = try zml.Buffer.fromArray(platform, inputs); + const output = try zml.testing.compileAndCall(platform, Local.repeat1d, .{ input, 0, 2 }); + + try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations))); + } + { + const inputs: [2][3]u8 = .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } }; + const expectations: [2][6]u8 = .{ .{ 1, 2, 3, 1, 2, 3 }, .{ 4, 5, 6, 4, 5, 6 } }; + + const input = try zml.Buffer.fromArray(platform, inputs); + const output = try zml.testing.compileAndCall(platform, Local.repeat1d, .{ input, 1, 2 }); + + try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations))); + } + } + /// Repeats in line each value along the given axis. /// - /// * stutter1d([0, 1, 2, 3], 0, 2) = [0, 0, 1, 1, 2, 2, 3, 3] - pub fn stutter1d(self: Tensor, axis_: i64, n_rep: u63) Tensor { + /// * stutter1d([0, 1, 2, 3], -1, 2) = [0, 0, 1, 1, 2, 2, 3, 3] + /// This is equivalent to repeat(ax+1) unless ax is the last axis. + pub fn stutter1d(self: Tensor, axis_: i8, n_rep: u63) Tensor { const a = self.axis(axis_); const broadshape = self._shape.insert(a + 1, .{n_rep}); - const stutter_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1); + const res_shape = self._shape.setDim(a, self.dim(a) * n_rep); - return self.broadcast(broadshape, stutter_dims.dims()).flatten(a); + const stutter_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1); + return self.broadcast(broadshape, stutter_dims.dims()).reshape(res_shape); } /// Repeats in line each value along the given axes. @@ -1775,6 +1790,45 @@ pub const Tensor = struct { return res; } + test stutter1d { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const Local = struct { + fn stutter1d(x: Tensor, axis_: u3, n_reps: u32) Tensor { + return x.stutter1d(axis_, n_reps); + } + }; + + { + const inputs: [3]u8 = .{ 1, 2, 3 }; + const expectations: [6]u8 = .{ 1, 1, 2, 2, 3, 3 }; + + const input = try zml.Buffer.fromArray(platform, inputs); + const output = try zml.testing.compileAndCall(platform, Local.stutter1d, .{ input, 0, 2 }); + + try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations))); + } + { + const inputs: [2][3]u8 = .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } }; + const expectations: [2][6]u8 = .{ .{ 1, 1, 2, 2, 3, 3 }, .{ 4, 4, 5, 5, 6, 6 } }; + + const input = try zml.Buffer.fromArray(platform, inputs); + const output = try zml.testing.compileAndCall(platform, Local.stutter1d, .{ input, 1, 2 }); + + try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations))); + } + { + const inputs: [2][3]u8 = .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } }; + const expectations: [2][6]u8 = .{ .{ 1, 2, 3, 1, 2, 3 }, .{ 4, 5, 6, 4, 5, 6 } }; + + const input = try zml.Buffer.fromArray(platform, inputs); + const output = try zml.testing.compileAndCall(platform, Local.stutter1d, .{ input, 0, 2 }); + + try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations))); + } + } + /// Returns a Tensor containing the element-wise negation of the input Tensor. pub fn negate(self: Tensor) Tensor { const loc = self.getContext().mlirCtx().location(@src()); @@ -2160,88 +2214,98 @@ pub const Tensor = struct { pub const GatherOpts = struct { indices_are_sorted: bool = false }; - /// For each coordinate in `indices`, - /// `gatherValues` extracts a single value of the given tensor. + /// `gather` extracts slices from the given tensor at the specified offsets. + /// example: `values.gather(.{ .a = idx }, .{})` + /// + /// * indices is a named list of integer tensors: eg `.{ .a = idx }`. + /// Each names specify a gathering axis, it must refer to an axis of self, + /// and the corresponding idx Tensor must contains valid indices into axis .a. + /// All indices must have the same shape or broadcast to the same shape. /// - /// * axes_ is a single axis, or a tuple of axis: .b, or .{ .b, .c } - /// * indices is an integer tensor /// * result is a tensor whose shape is similar to the input shape /// where the gathered axes have been replaced by axes from 'indices'. /// /// Some example input for the base case where we work on one axis: - /// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]] - /// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b] - /// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c] + /// - gather(f:[a], .{ .a = idx:[n]})[n] == f[idx[n]] + /// - gather(f:[a, b], .a, idx:[n])[n, b] == f[idx[n], b] + /// - gather(f:[a,b,c], .{.b = idx:[n,m]})[a, n, m, c] == f[a, idx[n, m], c] /// /// If an axis in common between `self` and `indices`, /// it is treated as a "batching" axis, meaning that semantically - /// the operator is doing a gatherValues one time per dimension of this axis: - /// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]] + /// the operator is doing a gather one time per dimension of this axis: + /// - gather(f: [a,b,c], .{.b=idx: [a,n]})[a, n] == f[a, idx[a, n]] /// - /// It is an error to have an axis present in `self`, `axes_` and `indices`. + /// It's possible to pass several indices: + /// - gather(f: [a,b,c], .{.b=idx_b[n], .c=idx_c[n]})[a, n] == f[a, idx_b[n], idx_c[n]] + /// - gather(f: [a,b,c,d], .{.b=idx_b[a,n], .c=idx_c[a, n]})[a, n, d] == f[a, idx_b[a, n], idx_c[a, n], d] /// - /// If several axes are passed, then the last axis of indices is treated as coordinates: - /// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]] - /// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d] + /// If `self` isn't tagged, you can use `gather_` to specify gathered axis by their position but batching won't be available. /// - /// It is possible to use gatherValues without tags, but batching won't be available. - pub fn gatherValues(self: Tensor, coord_axes: anytype, indices: Tensor, opts: GatherOpts) Tensor { - // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices }); - const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); + /// For performance it's better to have batching and gathering axes of `self` be the first one, + /// so that gather can + pub fn gather(self: Tensor, _indices: anytype, opts: GatherOpts) Tensor { + const idx_per_axis, const idx_tags = Shape.parseStruct(Tensor, _indices); + var idx_axes: Shape.AxesArray = .{}; + for (idx_tags.slice()) |t| { + idx_axes.appendAssumeCapacity(self.axis(t)); + } - stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); - for (coord_axes_.constSlice(), 0..) |a, i| { + // TODO: sort indices following self.shape instead of asking the user to do it. + return self.gather_(idx_axes.slice(), idx_per_axis.slice(), opts); + } + + pub fn gather_(self: Tensor, idx_axes: []const u3, idx_per_axis: []const Tensor, opts: GatherOpts) Tensor { + stdx.debug.assert(idx_axes.len > 0, "gather expects 1 or more axes to operate one, received none. Example: `x.gather(.a, indices, .{{}})`", .{}); + for (idx_axes, 0..) |a, i| { if (i > 0) { - stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {f}", .{ coord_axes, self }); + stdx.debug.assert(a == idx_axes[i - 1] + 1, "gather expects 'idx_axes' to be sequential. But {any} aren't sequential in {f}", .{ idx_axes, self }); } } + var indices_shape = idx_per_axis[0].shape(); + for (idx_per_axis[1..]) |idx| { + if (idx.rank() > indices_shape.rank()) { + indices_shape = idx.shape(); + } + } + for (idx_per_axis) |idx| { + stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "gather indices can't be broadcasted together {any}", .{idx_per_axis}); + } + + var idx_batch_axes: Shape.DimsArray = .{}; const AxisKind = enum { batching, offset, collapsed, indices }; - var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{}; - var indices_batch_axes: Shape.DimsArray = .{}; + var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{ .buffer = @splat(.offset), .len = self.rank() }; + for (self._shape.tags(), 0..self.rank()) |t, self_ax| { - const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax)); - if (indices._shape.hasTag(t)) |id_ax| { + const is_gather_axis = std.mem.containsAtLeastScalar(u3, idx_axes, 1, @intCast(self_ax)); + if (indices_shape.hasTag(t)) |id_ax| { // tag is both in self and indices -> it's a batching dim // Note: tags are required for batching. - self_kind.appendAssumeCapacity(.batching); - indices_batch_axes.appendAssumeCapacity(id_ax); - stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={f}', in 'coord_axes_={any}' and in 'indices={f}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); - } else if (maybe_coord_ax) |_| { - // for gatherValues we collapsed all gathered axes - // (contrary to gatherSlices where we collapse none) - self_kind.appendAssumeCapacity(.collapsed); + self_kind.buffer[self_ax] = .batching; + idx_batch_axes.appendAssumeCapacity(id_ax); + stdx.debug.assert(!is_gather_axis, "gather expects axes to appear at most twice. Axis {s} has been found both in 'self={f}', in 'idx_axes={any}' and in 'indices={f}'", .{ t, self, idx_axes, indices_shape }); + } else if (is_gather_axis) { + // we collapsed all gathered axes + self_kind.buffer[self_ax] = .collapsed; + // idx_kind.buffer[id_ax] = .indices; } else { - self_kind.appendAssumeCapacity(.offset); + self_kind.buffer[self_ax] = .offset; } } - // When we receive several coord_axes we need an extra dimension to store - // one index per axis, which makes the coordinates of one value. - // Otherwi se stablehlo uses the "indices.rank()" default value. - const index_coord_axis = if (single_coord) - indices.rank() - else blk: { - const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {f}", .{ coord_axes, coord_axes_.len, indices }); - break :blk ax; - }; - // compute res shape var res_shape = Shape.init(.{}, self.dtype()); var res_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{}; - for (self_kind.constSlice(), 0..) |kind, ax_usize| { + for (self_kind.slice(), 0..) |kind, ax_usize| { const ax: u3 = @intCast(ax_usize); - if (ax == coord_axes_.get(0)) { + if (ax == idx_axes[0]) { // The first val_ax is special cause this is the place where we insert indices axes. - for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| { - if (id_ax == index_coord_axis) continue; - if (std.mem.indexOfScalar(i64, indices_batch_axes.constSlice(), @intCast(id_ax))) |_| { - // batching dim are already in res - continue; - } + for (0.., indices_shape.tags(), indices_shape.dims()) |id_axis_order, id_axis, id_inserted_dim| { + const is_batching_axis = std.mem.containsAtLeastScalar(i64, idx_batch_axes.constSlice(), 1, @intCast(id_axis_order)); + // Batching axis is already in self. + if (is_batching_axis) continue; - res_shape = res_shape.appendDim(indices.dim(id_ax), t); + res_shape = res_shape.appendDim(id_inserted_dim, id_axis); res_kind.appendAssumeCapacity(.indices); } } @@ -2257,12 +2321,12 @@ pub const Tensor = struct { // This is not a gather, but a dynamicSlice. // Sometimes the backend recognize this pattern, but not always. // So let us handle that. - if (indices.count() == 1) { - return self.dynamicSlice1d(coord_axes_.get(0), .{ .start = indices.flattenAll().squeeze(0), .len = 1 }).reshape(res_shape); + if (indices_shape.count() == 1 and idx_axes.len == 1) { + return self.dynamicSlice1d(idx_axes[0], .{ .start = idx_per_axis[0].asScalar(), .len = 1 }).reshape(res_shape); } var slice_dims: Shape.DimsArray = .{}; - for (self_kind.constSlice(), self.dims()) |k, d| { + for (self_kind.slice(), self.dims()) |k, d| { slice_dims.appendAssumeCapacity(switch (k) { .batching, .collapsed => 1, .offset => d, @@ -2270,7 +2334,9 @@ pub const Tensor = struct { }); } - // scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() }); + // TODO: try changing .last by other axis and see the perf impact. + const indices = Tensor.stack(idx_per_axis, .last, .coord); + // scoped_log.debug("gather --> {} {any}", .{ res_shape, res_kind.constSlice() }); const loc = self.getContext().mlirCtx().location(@src()); const gather_op = dialect.stablehlo.gather( self.getContext().mlirCtx(), @@ -2282,22 +2348,30 @@ pub const Tensor = struct { .offset_dims = _collectAxes(AxisKind, res_kind, .offset).constSlice(), .collapsed_slice_dims = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(), .operand_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(), - .start_indices_batching_dims = indices_batch_axes.constSlice(), + .start_indices_batching_dims = idx_batch_axes.constSlice(), .start_index_map = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(), - .index_vector_dim = index_coord_axis, + .index_vector_dim = indices.axis(.coord), .indices_are_sorted = opts.indices_are_sorted, }, ); const mlir_shape = fromMlirValue(gather_op.result(0)).shape(); - stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices }); + stdx.debug.assert(mlir_shape.eql(res_shape), "gather expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices }); return _result(res_shape, gather_op.result(0)); } - test gatherValues { + test gather { const zml = @import("zml.zig"); const platform = zml.testing.env(); + const Local = struct { + pub fn _idx(idx_shape: anytype) Tensor { + return Tensor.constant(idx_shape, .{ .i32 = 0 }); + } + }; + + const idx = Local._idx; + { // Only test shapes var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); @@ -2306,32 +2380,38 @@ pub const Tensor = struct { defer comp.deactivate(); inline for (.{ - .{ .{ .a = 10 }, .a, .{}, .{} }, - .{ .{ .a = 10 }, .a, .{ .n = 8 }, .{ .n = 8 } }, - .{ .{ .a = 10, .b = 20 }, .a, .{}, .{ .b = 20 } }, - .{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .b = 20 } }, - .{ .{ .a = 10, .b = 20 }, 0, .{ .n = 8 }, .{ .n = 8, .b = 20 } }, + .{ .{ .a = 10 }, .{ .a = idx(.{}) }, .{} }, + .{ .{ .a = 10 }, .{ .a = idx(.{ .n = 8 }) }, .{ .n = 8 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{}) }, .{ .b = 20 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }) }, .{ .n = 8, .b = 20 } }, + // .{ .{ .a = 10, .b = 20 }, 0, idx(.{ .n = 8 }), .{ .n = 8, .b = 20 } }, // Favor val shape, instead of indices shape. - .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } }, - .{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } }, + .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .n = 8 }) }, .{ .a = 10, .n = 8 } }, + .{ .{ .a = 10, .b = 20, .c = 30 }, .{ .b = idx(.{ .n = 8 }) }, .{ .a = 10, .n = 8, .c = 30 } }, // batching axes are implicits. - .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } }, - .{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } }, - .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } }, + .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10 }) }, .{ .a = 10 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .b = 20 }) }, .{ .b = 20 } }, + .{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .n = 8 } }, // stablehlo.gather is biased toward indices shape (like gatherSlice). // This make it awkward to use when you have both batching dimension and new indices dimensions. // For now we reject those, and let user explicitly transpose self or indices if needed. - // .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } }, + // .{ .{ .a = 10, .b = 20 }, .{.b = idx(.{ .n = 8, .a = 10 })}, .{ .a = 10, .n = 8 } }, // Also handle tuples - .{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .n = 8 } }, - .{ .{ 10, 20 }, .{ -2, -1 }, .{ 8, 2 }, .{8} }, - // and 1-tuple - .{ .{ .a = 10, .b = 20 }, .{.b}, .{ .n = 8, ._ = 1 }, .{ .a = 10, .n = 8 } }, + .{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }), .b = idx(.{ .n = 8 }) }, .{ .n = 8 } }, }) |testcase| { - const x_shape, const tag, const idx_shape, const res_shape = testcase; + const x_shape, const indices, const res_shape = testcase; const x = Tensor.constant(x_shape, .{ .f16 = 0 }); - const idx = Tensor.constant(idx_shape, .{ .i32 = 0 }); - const y = gatherValues(x, tag, idx, .{}); + const y = gather(x, indices, .{}); + try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape()); + try std.testing.expect(y.value().owner().verify()); + } + + inline for (.{ + .{ .{ 10, 20 }, &[_]u3{ 0, 1 }, &[_]Tensor{ idx(.{8}), idx(.{8}) }, .{8} }, + }) |testcase| { + const x_shape, const idx_axes, const idx_per_axis, const res_shape = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const y = gather_(x, idx_axes, idx_per_axis, .{}); try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape()); try std.testing.expect(y.value().owner().verify()); } @@ -2966,13 +3046,24 @@ pub const Tensor = struct { } /// Returns a Tensor representing the result of Top-K over the given axis. - pub fn topK(self: Tensor, k: u32, axis_: anytype, opts: struct { descending: bool = true }) SortRes { - const a = self.axis(axis_); - const result = self.sort(a, .{ .descending = opts.descending }); - return .{ - .values = result.values.slice1d(a, .{ .end = k }), - .indices = result.indices.slice1d(a, .{ .end = k }), + pub fn topK(self: Tensor, named_axis_: anytype, k: u32, opts: struct { descending: bool = true }) SortRes { + const err_msg = "topK named axis should be an integer or a named axis, eg `x.topK(.{{ .best_token = .token }}, 16)` or `x.topK(-1, 16)`"; + const has_name: ?[:0]const u8, const a = switch (@typeInfo(@TypeOf(named_axis_))) { + .int, .comptime_int => .{ null, self.axis(@as(i64, @intCast(named_axis_))) }, + .@"struct" => |info| blk: { + stdx.debug.assertComptime(info.fields.len == 1, err_msg, .{}); + break :blk .{ info.fields[0].name, self.axis(@field(named_axis_, info.fields[0].name)) }; + }, + else => stdx.debug.compileError(err_msg, .{}), }; + var result = self.sort(a, .{ .descending = opts.descending }); + result.values = result.values.slice1d(a, .{ .end = k }); + result.indices = result.indices.slice1d(a, .{ .end = k }); + if (has_name) |new_name| { + result.values._shape._tags.set(a, new_name.ptr); + result.indices._shape._tags.set(a, new_name.ptr); + } + return result; } pub const MaxPoolRes = ArgMaxRes; @@ -3827,11 +3918,20 @@ pub const Tensor = struct { return binaryOpHelper(self, other.broad(self._shape)); } - stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape }); + var other_ = other; + var same_shape = self._shape.eql(other._shape); + if (!same_shape and std.mem.eql(Shape.Tag, self._shape.tags(), other._shape.tags()) and other._shape.canBroadcastTo(self._shape)) { + // Only a restrictive version of broadcasting is allowed here, where all the tags matches already. + // Typical use case: `x.div(x.sum(.a))` + same_shape = true; + other_ = other.broad(self._shape); + } + + stdx.debug.assert(same_shape, "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape }); const ctx = self.getContext(); - const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other }); - const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location }); + const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other_ }); + const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other_.value(), location }); return _result(self._shape, ret.result(0)); } }.binaryOpHelper; @@ -3904,21 +4004,14 @@ test "Tensor.maxPool1d" { const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5 }, &data); const result = try zml.testing.compileAndCall(platform, MaxPool._fwd, .{x}); - try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .f32), result.values.shape()); - try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .i32), result.indices.shape()); - const buffer = result.values.getValue([2][2][2]f32); + try zml.testing.expectEqualShapes(.init(.{ 2, 2, 2 }, .f32), result.values.shape()); + try zml.testing.expectEqualShapes(.init(.{ 2, 2, 2 }, .i32), result.indices.shape()); try std.testing.expectEqualDeep( [2][2][2]f32{ - [2][2]f32{ - [2]f32{ 2, 4 }, - [2]f32{ 7, 9 }, - }, - [2][2]f32{ - [2]f32{ 12, 14 }, - [2]f32{ 17, 19 }, - }, + .{ .{ 2, 4 }, .{ 7, 9 } }, + .{ .{ 12, 14 }, .{ 17, 19 } }, }, - buffer, + result.values.getValue([2][2][2]f32), ); } diff --git a/zml/testing.zig b/zml/testing.zig index 851bd8c..38952dd 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -33,7 +33,7 @@ pub fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool return closeRel or closeAbs; } -/// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the +/// Testing utility. Accepts both zml.Buffer and zml.HostBuffer but zml.Buffer will be copied to the /// host for comparison ! pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { const allocator = if (builtin.is_test) std.testing.allocator else std.heap.smp_allocator; diff --git a/zml/tokenizer/BUILD.bazel b/zml/tokenizer/BUILD.bazel index b7af9e2..012e218 100644 --- a/zml/tokenizer/BUILD.bazel +++ b/zml/tokenizer/BUILD.bazel @@ -1,4 +1,5 @@ load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_test") +load("@zml//bazel:zig_srcs.bzl", "zig_srcs") zig_library( name = "tokenizer", @@ -7,7 +8,6 @@ zig_library( main = "tokenizer.zig", visibility = ["//visibility:public"], deps = [ - "//async", "//ffi:zig", "//stdx", "//zml/tokenizer/hftokenizers", @@ -36,3 +36,8 @@ zig_test( "//stdx", ], ) + +zig_srcs( + name = "sources", + zig_lib = ":tokenizer", +) diff --git a/zml/tokenizer/sentencepiece/BUILD.bazel b/zml/tokenizer/sentencepiece/BUILD.bazel index dba86dd..10ac080 100644 --- a/zml/tokenizer/sentencepiece/BUILD.bazel +++ b/zml/tokenizer/sentencepiece/BUILD.bazel @@ -1,5 +1,6 @@ load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:swig.bzl", "swig_cc_library") +load("//bazel:zig_srcs.bzl", "zig_srcs") swig_cc_library( name = "sentencepiece_swig", @@ -22,3 +23,13 @@ zig_library( "//stdx", ], ) + +zig_srcs( + name = "sources", + zig_lib = ":sentencepiece", +) + +cc_static_library( + name="sentencepiece_static", + deps = [":sentencepiece_swig"] +) diff --git a/zml/tokenizer/tokenizer.zig b/zml/tokenizer/tokenizer.zig index c8dfb12..5109ce3 100644 --- a/zml/tokenizer/tokenizer.zig +++ b/zml/tokenizer/tokenizer.zig @@ -1,6 +1,5 @@ const std = @import("std"); -const async = @import("async"); const hftokenizers = @import("hftokenizers"); const sentencepiece = @import("sentencepiece"); @@ -98,15 +97,15 @@ pub const Tokenizer = union(Tokenizers) { pub fn fromFile(allocator: std.mem.Allocator, model: []const u8) !Tokenizer { if (std.mem.endsWith(u8, model, ".pb")) { - return .{ .sentencepiece = try async.callBlocking(sentencepiece.SentencePieceProcessor.fromFile, .{model}) }; + return .{ .sentencepiece = try sentencepiece.SentencePieceProcessor.fromFile(model) }; } if (std.mem.endsWith(u8, model, ".json")) { - return .{ .hftokenizers = try async.callBlocking(hftokenizers.HFTokenizer.fromFile, .{model}) }; + return .{ .hftokenizers = try hftokenizers.HFTokenizer.fromFile(model) }; } if (std.mem.endsWith(u8, model, ".tinyllama")) { const tokenizer = try allocator.create(homemade.Tokenizer); - tokenizer.* = try async.callBlocking(homemade.fromTinyLlamaFile, .{ allocator, model, 32000 }); + tokenizer.* = try homemade.fromTinyLlamaFile(allocator, model, 32000); return .{ .homemade = tokenizer }; } diff --git a/zml/torch.zig b/zml/torch.zig index e982aae..09888ae 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -169,7 +169,7 @@ test pixelShuffle { /// /// Note: at the difference of Pytorch, shifts need to be explicitly repeated, even if they are the same for all axes. /// ref: https://pytorch.org/docs/stable/generated/torch.roll.html -pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor { +pub fn roll(self: Tensor, shifts: []const i64, axes_: []const i8) Tensor { // TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 }) stdx.debug.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len }); @@ -180,11 +180,11 @@ pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor { return roll(first_dim_rolled, tail_shifts, tail_dims); } - const a = axes_[0]; + const a = self.axis(axes_[0]); const start = @mod(self.dim(a) - shifts[0], self.dim(a)); const idx = Tensor.arange(.{ .start = start, .end = start + self.dim(a) }, .f32); const divisor: f32 = @floatFromInt(self.dim(a)); - return self.gatherValues(a, idx.fmod(divisor).convert(.i32), .{}); + return self.gather_(&.{a}, &.{idx.fmod(divisor).convert(.i32)}, .{}); } test roll { @@ -194,7 +194,7 @@ test roll { const res = try zml.testing.compileAndCall( platform, roll, - .{ input, &[_]i64{ 2, 1 }, &[_]u8{ 0, 1 } }, + .{ input, &[_]i64{ 2, 1 }, &[_]i8{ 0, 1 } }, ); const expectation = zml.HostBuffer.fromSlice(.{ 4, 2 }, &[_]f32{ 6, 5, 8, 7, 2, 1, 4, 3 }); @@ -274,3 +274,12 @@ test meshgrid { ); } } + +/// Flattens the given axis and the next one, into one new axis. +pub fn flatten(self: Tensor, axis_: anytype) Tensor { + const old_shape = self._shape; + const a = self.axis(axis_); + stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {f} on the last axis {}.", .{ self, axis_ }); + const new_shape = old_shape.mergeAxis(a, .{ a, a + 1 }); + return self.reshape(new_shape); +}