Revamp the gather API to use named indices (and add a gather_ variant), improve topK handling, and add YaRN RoPE embedding support across core modules (buffer, nn, pjrtx, quantization, shape, tensor, testing, tokenizer, torch).

This commit is contained in:
Tarry Singh 2025-09-26 13:38:11 +00:00
parent 7264fff493
commit d45a667ee5
11 changed files with 444 additions and 600 deletions

View File

@ -54,7 +54,7 @@ pub const Buffer = struct {
break :cs @divExact(host_buffer.dim(ax), n_partitions);
} else 0;
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
const buffer_type = pjrt.bufferTypeFromDtype(host_buffer.shape().dtype());
const byte_strides = host_buffer.strides();
const devices = platform.getDevices();
@ -256,7 +256,7 @@ pub const Buffer = struct {
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
.data = device_data,
.element_type = bufferTypeFromDtype(shape_.dtype()),
.element_type = pjrt.bufferTypeFromDtype(shape_.dtype()),
.dims = shape_.dims(),
// TODO: exposes sharding in the API.
.device = platform.getDevices()[0],
@ -437,7 +437,7 @@ pub const Buffer = struct {
var args = pjrt.Client.CreateUninitializedBufferArgs{
.dims = shard_shape.dims(),
.element_type = bufferTypeFromDtype(shape_.dtype()),
.element_type = pjrt.bufferTypeFromDtype(shape_.dtype()),
.layout = .{
.tiled = .{
.minor_to_major = minorToMajor(shape_.rank()),
@ -487,32 +487,6 @@ pub const Buffer = struct {
}
};
/// Maps a zml `DataType` to the equivalent `pjrt.BufferType`.
/// Works by name: the two enums share tag names, so the switch resolves
/// each tag at comptime via @field (round-trip checked by the test below).
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
return switch (dt) {
inline else => |tag| @field(pjrt.BufferType, @tagName(tag)),
};
}
/// Maps a `pjrt.BufferType` back to the zml `DataType` with the same tag name.
/// Panics on `.invalid`, which has no DataType counterpart.
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
return switch (pjrt_type) {
.invalid => @panic("Found an invalid pjrt buffer"),
inline else => |tag| @field(DataType, @tagName(tag)),
};
}
// Round-trip check in both directions: every DataType survives
// DataType -> BufferType -> DataType, and every BufferType (except the
// unmapped `.invalid` sentinel) survives the reverse trip.
test bufferTypeFromDtype {
inline for (@typeInfo(DataType).@"enum".fields) |field| {
const dt: DataType = @enumFromInt(field.value);
try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
}
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
const dt: pjrt.BufferType = @enumFromInt(field.value);
if (dt == .invalid) continue;
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
}
}
const _MINOR_TO_MAJOR = blk: {
var ret: [Shape.MAX_RANK]i64 = undefined;
for (0..Shape.MAX_RANK) |i| {

View File

@ -32,7 +32,7 @@ pub const Linear = struct {
}
// log.debug("Linear({*}): {d} -> {d} -> {d}", .{ self, x.dims(), y.dims(), if (self.bias) |bias| y.add(bias).dims() else y.dims() });
return if (self.bias) |bias| y.add(bias.broadcastLeft(y.shape())) else y;
return if (self.bias) |bias| y.add(bias.broadcast(y.shape(), &.{y.axis(-1)})) else y;
}
};
@ -42,7 +42,7 @@ pub const TokenEmbedding = struct {
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {f}", .{idx});
stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {f}", .{self.weight});
return self.weight.gatherValues(0, idx, .{});
return self.weight.withTags(.{ .voc, .d }).gather(.{ .voc = idx }, .{});
}
};
@ -125,18 +125,18 @@ pub fn normalizeVariance(x: Tensor, eps: f32) Tensor {
// Upcast to improve precision
const xf32 = x.convert(.f32);
const mean = xf32.sum(-1).scale(1.0 / N);
const mean_dev = xf32.sub(mean.broadcastRight(xf32.shape()));
const variance = mean_dev.mul(mean_dev).sum(-1).scale(1.0 / N);
const mean_dev = xf32.sub(mean);
const variance = mean_dev.mul(mean_dev).sum(-1).divByConst(N);
const rsqrt = Tensor.rsqrt(variance.addConstant(eps));
return mean_dev.mul(rsqrt.broadcastRight(mean_dev.shape())).convert(x.dtype());
return mean_dev.mul(rsqrt).convert(x.dtype());
}
// ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// Implementation equivalent to `nn.functional.normalize(tensor, dim=-1)` call
pub fn normalizeL2(input: Tensor, eps: f32) Tensor {
const inv_norm = input.pow(Tensor.scalar(2, input.dtype())).sum(-1).addConstant(eps).rsqrt();
return input.mul(inv_norm.broad(input.shape()));
const inv_norm = input.powByConst(2).sum(-1).addConstant(eps).rsqrt();
return input.mul(inv_norm);
}
test normalizeL2 {
@ -165,12 +165,22 @@ pub const RopeOpts = struct {
default: void,
custom: []const f32,
llama3: Llama3,
yarn: Yarn,
/// Llama3-style RoPE frequency-scaling parameters,
/// parsed from the HF config.json `rope_scaling` section (see fromJson below).
pub const Llama3 = struct {
factor: f32,
high_freq_factor: f32,
low_freq_factor: f32,
original_max_position_embeddings: u32,
// Not part of the HF config; presumably controls rounding of the
// frequency-band boundaries — TODO confirm against _invFreq.
truncate: bool = true,
};
/// YaRN RoPE scaling parameters, parsed from the HF config.json
/// `rope_scaling` section. `factor` has no default and must be provided.
pub const Yarn = struct {
beta_fast: f32 = 32.0,
beta_slow: f32 = 1.0,
factor: f32,
// When true, the n_low/n_high band boundaries are floored/ceiled in
// _invFreq before interpolating.
truncate: bool = true,
original_max_position_embeddings: u32,
};
/// Read a Rope scaling config from HF config.json format.
@ -185,12 +195,21 @@ pub const RopeOpts = struct {
if (impl != .string) return error.InvalidEnumTag;
if (std.mem.eql(u8, impl.string, "llama3")) {
// Note: leaky is fine here because the Llama3 struct doesn't need to allocate memory.
return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, undefined, content, .{ .ignore_unknown_fields = true }) };
return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, stdx.noalloc, content, .{ .ignore_unknown_fields = true }) };
} else if (std.mem.eql(u8, impl.string, "yarn")) {
return .{ .yarn = try std.json.parseFromValueLeaky(Yarn, stdx.noalloc, content, .{ .ignore_unknown_fields = true }) };
} else {
log.warn("Unsupported Rope implementation: {s}, will use the default one which will produce altered results", .{impl.string});
return .{ .default = {} };
}
}
/// Extra multiplier applied to the RoPE cos/sin tables (see `rope`).
/// YaRN prescribes `0.1 * ln(factor) + 1.0`; every other scaling mode
/// leaves the tables untouched (multiplier of 1.0).
pub fn attentionScaling(scaling: Scaling) f32 {
if (scaling == .yarn) {
const yarn_cfg = scaling.yarn;
return 0.1 * @log(yarn_cfg.factor) + 1.0;
}
return 1.0;
}
};
};
@ -218,8 +237,9 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
// compute sin and cos in f32 before downcasting to x type.
const inv_freq = invFreq(x.dim(.hd), opts).withTags(.{.hd});
const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
const scaling = opts.scaling.attentionScaling();
const cos = inv_freq_pos.cos().scale(scaling).convert(x.dtype()).broad(x_real.shape());
const sin = inv_freq_pos.sin().scale(scaling).convert(x.dtype()).broad(x_real.shape());
// apply rotation
const y_real = x_real.mul(cos).sub(x_imag.mul(sin));
@ -250,7 +270,7 @@ pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, layout: RopeOpts.Layout) Ten
.interleaved => Tensor.concatenate(&.{
x_real.appendAxes(.{.interleaved_real_img}),
x_imag.appendAxes(.{.interleaved_real_img}),
}, -1).flatten(-2),
}, -1).reshape(x_imag.shape().setDim(-1, -1)),
};
}
@ -259,6 +279,8 @@ pub fn invFreq(N: i64, opts: RopeOpts) Tensor {
const allocator = zml.module.CompilationContext.current().allocator();
const N_half: usize = @intCast(@divExact(N, 2));
const inv_freq = allocator.alloc(f32, N_half) catch @panic("OOM");
defer allocator.free(inv_freq);
_invFreq(opts, inv_freq);
return zml.Tensor.constantTensor(.fromSlice(.{@divExact(N, 2)}, inv_freq));
}
@ -299,10 +321,38 @@ fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
}
}
},
.yarn => |s| {
const N_f: f64 = @floatFromInt(inv_freq.len);
const M: f64 = @floatFromInt(s.original_max_position_embeddings);
const f_high = s.beta_fast * (2 * std.math.pi) / M;
const f_low = s.beta_slow * (2 * std.math.pi) / M;
const downscaling = 1.0 / s.factor;
// This isn't a typo: low values of n have a high frequency; high values of n have a low frequency.
var n_low: f64 = -@log(f_high) / @log(opts.freq_base) * N_f;
var n_high: f64 = -@log(f_low) / @log(opts.freq_base) * N_f;
if (s.truncate) {
n_high = std.math.ceil(n_high);
n_low = std.math.floor(n_low);
}
std.debug.assert(n_high > n_low);
for (0..N, inv_freq) |n, f| {
if (f > f_high) {
// High frequencies: keep the default implementation's value unchanged.
} else if (f < f_low) {
// Downscaling for low freq
inv_freq[n] *= downscaling;
} else {
// YaRN uses lerp too, but in the time space rather than in the frequency space.
const lerp: f64 = (n_high - @as(f64, @floatFromInt(n))) / (n_high - n_low);
inv_freq[n] *= @floatCast(lerp + (1 - lerp) * downscaling);
}
}
},
}
}
test invFreq {
test "invFreq Llama3" {
// Llama 3.2-1B config
const llama_conf: RopeOpts = .{
.freq_base = 500_000,
@ -323,6 +373,28 @@ test invFreq {
}
}
test "invFreq Yarn" {
const yarn_conf: RopeOpts = .{
.freq_base = 150_000,
.scaling = .{ .yarn = .{
.factor = 32.0,
.beta_fast = 32.0,
.beta_slow = 1.0,
.original_max_position_embeddings = 4096,
.truncate = true,
} },
};
// Golden values — presumably generated from a reference YaRN
// implementation; TODO confirm provenance.
const yarn_freq = [_]f32{ 1.000000000000e+00, 6.890442967415e-01, 4.747820496559e-01, 3.271458745003e-01, 2.254180014133e-01, 1.553229838610e-01, 1.070244237781e-01, 7.374456524849e-02, 5.081327259541e-02, 3.162075206637e-02, 1.945096626878e-02, 1.179219130427e-02, 7.015713956207e-03, 4.069554619491e-03, 2.277272054926e-03, 1.206130953506e-03, 5.809474969283e-04, 2.279478358105e-04, 3.830881178146e-05, 2.639646845637e-05, 1.818833698053e-05, 1.253256959899e-05, 8.635495760245e-06, 5.950239483354e-06, 4.099978468730e-06, 2.825066758305e-06, 1.946596285052e-06, 1.341290953860e-06, 9.242089618056e-07, 6.368209142238e-07, 4.387978549403e-07, 3.023511396805e-07 };
var inv_freq: @TypeOf(yarn_freq) = undefined;
_invFreq(yarn_conf, &inv_freq);
for (yarn_freq, inv_freq, 0..) |expected, actual, i| {
// Dump both full slices on mismatch to ease debugging.
errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, stdx.fmt.slice(&yarn_freq), stdx.fmt.slice(&inv_freq) });
try std.testing.expectApproxEqRel(expected, actual, 1e-5);
}
// attentionScaling for yarn is 0.1 * ln(factor) + 1.0 = 0.1 * ln(32) + 1.0.
try std.testing.expectApproxEqRel(1.3465735902799727, yarn_conf.scaling.attentionScaling(), 1e-5);
}
test "real/img" {
const platform = zml.testing.env();
@ -332,8 +404,8 @@ test "real/img" {
const real, const imag = splitRealImg(x, layout);
const y = mergeRealImg(real, imag, layout);
const real2, const imag2 = splitRealImg(y, layout);
return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add(
imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1),
return real.cmp(.EQ, real2).flatten().convert(.i32).sum(-1).add(
imag.cmp(.EQ, imag2).flatten().convert(.i32).sum(-1),
);
}
@ -349,8 +421,8 @@ test "real/img" {
Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
}, 1);
return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add(
imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1),
return real.cmp(.EQ, x_real).flatten().convert(.i32).sum(-1).add(
imag.cmp(.EQ, x_imag).flatten().convert(.i32).sum(-1),
);
}
@ -366,8 +438,8 @@ test "real/img" {
Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
}, 1);
return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add(
imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1),
return real.cmp(.EQ, x_real).flatten().convert(.i32).sum(-1).add(
imag.cmp(.EQ, x_imag).flatten().convert(.i32).sum(-1),
);
}
@ -377,8 +449,8 @@ test "real/img" {
const x_real = Tensor.arange(.{ .start = 0, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 });
const x_imag = Tensor.arange(.{ .start = 1, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 });
return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add(
imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1),
return real.cmp(.EQ, x_real).flatten().convert(.i32).sum(-1).add(
imag.cmp(.EQ, x_imag).flatten().convert(.i32).sum(-1),
);
}
};
@ -484,21 +556,21 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
out_shape._dims.set(i + 2, @intFromFloat(@floor(@as(f64, @floatFromInt(out_shape.dim(i + 2))) * sf)));
}
// TODO(james): remove this implicit two batching dims
var sd: [3]usize = undefined;
var sd: [3]u3 = undefined;
var len_sd: usize = 0;
for (2..input.rank()) |i| {
if (input.dim(i) != out_shape.dim(i)) {
sd[len_sd] = i;
sd[len_sd] = @intCast(i);
len_sd += 1;
}
}
const spatial_dims = sd[0..len_sd];
const spatial_axes = sd[0..len_sd];
var res = input;
for (spatial_dims) |d| {
const n = out_shape.dim(d);
const ratio = stdx.math.divFloat(f32, input.dim(d), n);
for (spatial_axes) |ax| {
const n = out_shape.dim(ax);
const ratio = stdx.math.divFloat(f32, input.dim(ax), n);
const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
res = res.gather_(&.{ax}, &.{offsets}, .{ .indices_are_sorted = true });
}
return res;
}
@ -670,10 +742,11 @@ test resizeBilinear {
}
pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
const res_shape = image.shape().set(axis, new_len);
const ax = image.axis(axis);
const res_shape = image.shape().set(ax, new_len);
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
const og_len = opt.original_len orelse Tensor.scalar(image.dim(ax), dtype);
const ratio = og_len.convert(dtype).scale(stdx.math.divFloat(f32, 1, new_len));
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
const left = scaled.floor();
@ -682,11 +755,11 @@ pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Te
// TODO: check that two gather isn't too bad perf wise.
// Normally we should use gatherSlices to collect the values 2 by 2,
// but gatherSlices messes up with the order of axes.
const left_val = image.gatherValues(axis, left.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
const right_val = image.gatherValues(axis, right.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
const left_val = image.gather_(&.{ax}, &.{left.convert(.i32)}, .{ .indices_are_sorted = true }).convert(dtype);
const right_val = image.gather_(&.{ax}, &.{right.convert(.i32)}, .{ .indices_are_sorted = true }).convert(dtype);
const left_weight = right.sub(scaled).broadcast(res_shape, &.{axis});
const right_weight = scaled.sub(left).broadcast(res_shape, &.{axis});
const left_weight = right.sub(scaled).broadcast(res_shape, &.{ax});
const right_weight = scaled.sub(left).broadcast(res_shape, &.{ax});
const res = left_val.mul(left_weight).add(right_val.mul(right_weight));
return res.convert(image.dtype()).withTags(image.shape().tags());
@ -853,13 +926,10 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
return cuda.sdpa(q, k, v, opts);
}
if (q.dim(.h) != k.dim(.h)) {
stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args);
// Note: we don't try to repeat queries.
// Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) };
}
// Handle different numbers of head by splitting q heads.
// This is a bit error prone in the sense that it depends of the layout of q heads.
// This is the Llama convention though.
q = q.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto });
const attn_mask = if (opts.attn_mask) |m| m else null;
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
@ -875,339 +945,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
var attn = attn_weights.dot(v, .{.k});
return attn.transpose(q.shape());
}
/// Chunk sizes, along the .q and .k axes respectively, used by `sdpaMemEfficient`.
pub const SdpaChunks = struct { q_chunk_size: u32, k_chunk_size: u32 };
/// Memory-efficient scaled-dot-product attention: processes queries and
/// keys/values in chunks (see `SdpaMemEfficient.forward`) instead of
/// materializing the full attention matrix at once.
/// Chunk sizes are clamped to the actual .q/.k dimensions, so callers may
/// pass sizes larger than the sequence lengths.
pub fn sdpaMemEfficient(
q: Tensor,
k: Tensor,
v: Tensor,
sdpa_opts: SdpaOpts,
chunking: SdpaChunks,
) Tensor {
const sdpa_mem_efficient: SdpaMemEfficient = .{
.q = q,
.k = k,
.v = v,
.sdpa_opts = sdpa_opts,
.chunking = .{
.q_chunk_size = @intCast(@min(q.dim(.q), chunking.q_chunk_size)),
.k_chunk_size = @intCast(@min(k.dim(.k), chunking.k_chunk_size)),
},
};
return sdpa_mem_efficient.forward();
}
/// Implementation state for `sdpaMemEfficient`: the full q/k/v tensors plus
/// the chunking configuration. `forward` splits queries into chunks, and
/// each query chunk scans over key/value chunks, merging partial softmax
/// results (`PartialSoftmax`) as it goes.
const SdpaMemEfficient = struct {
q: Tensor,
k: Tensor,
v: Tensor,
sdpa_opts: SdpaOpts,
chunking: SdpaChunks,
/// Runs attention chunk by chunk and reassembles the result along .q.
/// Requires the chunk sizes to divide the sequence lengths exactly.
fn forward(self: SdpaMemEfficient) Tensor {
stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.q, self.chunking });
stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.k, self.chunking });
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
const ctx = zml.module.CompilationContext.current();
const q_chunks = ctx.allocator().alloc(zml.Tensor, n_q_chunks) catch unreachable;
defer ctx.allocator().free(q_chunks);
for (0..n_q_chunks) |i| {
const idx: u32 = @intCast(i);
// Slice out this query chunk (and the matching rows of the mask, if any).
const q_slice: zml.Tensor.DynSlice = .{
.start = Tensor.scalar(idx * self.chunking.q_chunk_size, .i32),
.len = self.chunking.q_chunk_size,
};
const q_chunk = self.q.dynamicSlice(.{ .q = q_slice });
const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null;
var chunk: SdpaMemEfficient = self;
chunk.q = q_chunk;
chunk.sdpa_opts.attn_mask = attn_chunk;
q_chunks[i] = chunk.scanKeyVal();
}
// Stitch the per-chunk results back together along .q, then restore
// the caller's axis order.
const res = zml.Tensor.concatenate(q_chunks, .q);
return res.transpose(self.q.shape());
}
/// Same per-query-chunk slicing as the loop body of `forward`, but driven
/// by a runtime index tensor.
/// NOTE(review): not referenced elsewhere in this struct — forward()
/// inlines the equivalent logic; possibly kept for external use. Confirm.
fn nextQueriesChunk(self: SdpaMemEfficient, idx: Tensor) Tensor {
const q_slice: zml.Tensor.DynSlice = .{
.start = idx.scale(self.chunking.q_chunk_size),
.len = self.chunking.q_chunk_size,
};
const q_chunk = self.q.dynamicSlice(.{ .q = q_slice });
const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null;
var chunk: SdpaMemEfficient = self;
chunk.q = q_chunk;
chunk.sdpa_opts.attn_mask = attn_chunk;
return chunk.scanKeyVal();
}
/// Scans the key/value chunks for the current query chunk, merging
/// PartialSoftmax results. Small chunk counts (<= 4) are unrolled at
/// compile time; larger counts emit a stablehlo.while loop.
fn scanKeyVal(self: SdpaMemEfficient) Tensor {
const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size);
return if (n_chunks <= 4) {
// Unrolled version
var partial_softmax: ?PartialSoftmax = null;
for (0..@intCast(n_chunks)) |idx| {
const next = self.nextKeyValChunk(Tensor.scalar(idx, .i32));
partial_softmax = if (partial_softmax) |prev| prev.merge(next) else next;
}
return partial_softmax.?.finalize();
} else {
// stablehlo.while version
const partial_softmax, _ = zml.ops.while_(hasNextKeyValChunk, nextKeyValChunkMerge, self, .{ PartialSoftmax.zeros(self.q.shape(), .f32), Tensor.scalar(0, .i32) });
return partial_softmax.finalize();
};
}
/// Loop body for the while_ version: computes the next chunk, merges it
/// into the accumulator and increments the chunk index.
fn nextKeyValChunkMerge(self: SdpaMemEfficient, prev: PartialSoftmax, idx: Tensor) struct { PartialSoftmax, Tensor } {
const next = self.nextKeyValChunk(idx);
return .{ prev.merge(next), idx.addConstant(1) };
}
/// Computes the partial attention of the current query chunk against key/value
/// chunk `idx` (slicing k, v, and the mask columns along .k).
fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialSoftmax {
const k_slice: zml.Tensor.DynSlice = .{
.start = idx.scale(self.chunking.k_chunk_size),
.len = self.chunking.k_chunk_size,
};
const k_chunk = self.k.dynamicSlice(.{ .k = k_slice });
const v_chunk = self.v.dynamicSlice(.{ .k = k_slice });
const attn_chunk = if (self.sdpa_opts.attn_mask) |mask| mask.dynamicSlice(.{ .k = k_slice }) else null;
return sdpaChunk(self.q, k_chunk, v_chunk, .{ .attn_mask = attn_chunk });
}
/// Loop condition for the while_ version: keep going while idx < n_chunks.
pub fn hasNextKeyValChunk(self: SdpaMemEfficient, _: PartialSoftmax, idx: Tensor) zml.Tensor {
const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size);
return idx.cmp(.LT, Tensor.scalar(n_chunks, idx.dtype()));
}
};
/// Partial softmax accumulator: unnormalized values plus the running
/// exp-sum and running max needed to merge results computed over disjoint
/// key chunks. Two partials over different chunks can be `merge`d, and
/// `finalize` produces the properly normalized attention output.
pub const PartialSoftmax = struct {
// Unnormalized (exp-weighted) values.
values: Tensor,
// Running sum of exponentials, per query position (.hd collapsed to 1).
exp_sum: Tensor,
// Running maximum logit, per query position (.hd collapsed to 1).
max_value: Tensor,
/// Neutral element for `merge`: zero values, zero exp_sum, and a max set
/// to the dtype's minimum so any real chunk's max wins.
pub fn zeros(q_shape: Shape, exp_sum_precision: DataType) PartialSoftmax {
return .{
.values = Tensor.constant(q_shape, q_shape.dtype().zero()),
.exp_sum = Tensor.constant(q_shape.setDim(.hd, 1), exp_sum_precision.zero()),
.max_value = Tensor.constant(q_shape.setDim(.hd, 1), q_shape.dtype().minValue()),
};
}
pub fn merge(self: PartialSoftmax, other: PartialSoftmax) PartialSoftmax {
// Rescale self and other using the new global_max.
const global_max = self.max_value.maximum(other.max_value);
const new_self = self.rescale(global_max);
const new_other = other.rescale(global_max);
// Now that self and other are using the same scale, we can just add them:
return .{
.max_value = global_max,
.values = new_self.values.add(new_other.values),
.exp_sum = new_self.exp_sum.add(new_other.exp_sum),
};
}
/// Update max_value and rescale attn and exp_sum accordingly.
pub fn rescale(self: PartialSoftmax, max_value: Tensor) PartialSoftmax {
// exp(old_max - new_max) is the factor that re-expresses the stored
// exponentials relative to the new maximum.
const max_diff_exp = self.max_value.sub(max_value).exp();
const sum_dtype = self.exp_sum.dtype();
return .{
.max_value = max_value,
.values = self.values.mul(max_diff_exp.broad(self.values.shape())),
.exp_sum = self.exp_sum.mul(max_diff_exp.convert(sum_dtype)),
};
}
/// Divides the intermediary results by the exp_sum to get the proper attention values.
pub fn finalize(self: PartialSoftmax) Tensor {
return self.values.div(self.exp_sum.broad(self.values.shape()).convert(self.values.dtype()));
}
};
/// Compute softmax over a chunk.
/// Returns intermediary results to allow aggregating later.
/// Compute softmax over a chunk.
/// Returns intermediary results to allow aggregating later.
/// The exp-sum is accumulated in f32 regardless of the input dtype.
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax {
const a = self.axis(axis);
// Clamp the max at -1e16 so a fully-masked chunk doesn't produce exp(0)=1
// everywhere from a -inf max — presumably; TODO confirm intent.
const max_val = self.max(a).maximum(Tensor.scalar(-1e16, self.dtype()));
const out = self.sub(max_val.broad(self.shape())).exp();
return .{
.values = out,
.exp_sum = out.convert(.f32).sum(a),
.max_value = max_val,
};
}
/// Compute sdpa on a chunk, and computes a partial softmax.
/// q: (B, H, Sq, H_dim) k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk)
/// Compute sdpa on a chunk, and computes a partial softmax.
/// q: (B, H, Sq, H_dim) k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk)
/// Requires q tagged {.h, .q, .hd} and k/v tagged {.h, .k, .hd};
/// repeats k/v heads when q has more heads (GQA-style).
pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoftmax {
// this is a dupe of sdpa, but return the PartialSoftmax instead of true Attn.
// Consider implementing sdpa from sdpaChunk.
var q, var k, var v = .{ q_, k_, v_ };
const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! ";
const err_args = .{ q, k, v, opts.attn_mask };
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
if (q.dim(.h) != k.dim(.h)) {
stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args);
// Note: we don't try to repeat queries.
// Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) };
}
const attn_mask = if (opts.attn_mask) |m| m else null;
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
};
// Scale keys by 1/sqrt(head_dim) (or by the caller-provided scale)
// before the dot product.
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
k = k.mul(head_scaling.convert(k.dtype()));
var attn_weights = q.dot(k, .{.hd});
// log.debug("attn_weights : {}", .{attn_weights});
// log.debug("attn_mask : {?}", .{attn_mask});
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
const partial = partialSoftmax(attn_weights, .k);
const attn = partial.values.dot(v, .{.k}).transpose(q.shape());
return .{
.values = attn,
// The renaming is because the above dot projected values.k into .hd,
// do the same thing on the other tensors.
// This work because dot is a linear operation, and commutes with `PartialSoftmax.finalize`
.exp_sum = partial.exp_sum.rename(.{ .k = .hd }).transpose(attn.shape()),
.max_value = partial.max_value.rename(.{ .k = .hd }).transpose(attn.shape()),
};
}
// Checks sdpaMemEfficient against the reference sdpa on random inputs,
// with two different k-chunk granularities (4 and 16 chunks), covering
// both the unrolled path and the stablehlo.while path of scanKeyVal.
test sdpaMemEfficient {
const platform = zml.testing.env();
const allocator = std.testing.allocator;
// Note we use small input vectors to have the tests run reasonably fast,
// but don't expect speed ups with this small sizes.
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng.deinit();
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng_mask.deinit();
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
const q = rng.call({}).withTags(.{ .b, .h, .q, .hd });
const k = rng.call({}).withTags(.{ .b, .h, .k, .hd });
const v = rng.call({}).withTags(.{ .b, .h, .k, .hd });
const mask = rng_mask.call({}).withTags(.{ .q, .k });
// Reference result from the plain (non-chunked) sdpa.
const ref_res = try zml.testing.compileAndCall(
platform,
sdpa,
.{ q, k, v, .{ .attn_mask = mask, .scale = null } },
);
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
{
// 4 k_chunks
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null },
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) },
},
);
try zml.testing.expectClose(ref_res, res, 2e-3);
}
{
// 16 k_chunks
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null },
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) },
},
);
try zml.testing.expectClose(ref_res, res, 2e-3);
}
}
test "sdpaMemEfficient transposed" {
const platform = zml.testing.env();
const allocator = std.testing.allocator;
// Note we use small input vectors to have the tests run reasonably fast,
// but don't expect speed ups with this small sizes.
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 512, 10, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng.deinit();
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng_mask.deinit();
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
const q = rng.call({}).withTags(.{ .b, .q, .h, .hd });
const k = rng.call({}).withTags(.{ .b, .k, .h, .hd });
const v = rng.call({}).withTags(.{ .b, .k, .h, .hd });
const mask = rng_mask.call({}).withTags(.{ .q, .k });
const ref_res = try zml.testing.compileAndCall(
platform,
sdpa,
.{ q, k, v, .{ .attn_mask = mask, .scale = null } },
);
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
{
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null },
.{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) },
},
);
try zml.testing.expectClose(ref_res, res, 1e-3);
}
{
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null },
.{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) },
},
);
try zml.testing.expectClose(ref_res, res, 1e-3);
}
return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } });
}
/// Options controlling generation. The default values correspond to greedy decoding.
@ -1225,9 +963,9 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
return .{ next_tokens, rng };
}
const topk = activations.topK(opts.topk, .voc, .{});
// After the topk, we don't have .voc values, anymore, only topk.
var x = topk.values.rename(.{ .voc = .topk });
const topk = activations.topK(.{ .topk = .voc }, opts.topk, .{});
// After the topk, we don't have .voc values, anymore, only .topk.
var x = topk.values;
if (opts.temperature != 1.0) {
x = x.scale(1 / opts.temperature);
}
@ -1242,7 +980,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
// topk_idx is indices into topk.values ! so in the range [0, topk]
// Convert for the original indices from the full [0, voc] range.
const next_tokens = topk.indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
const next_tokens = topk.indices.gather(.{ .topk = topk_idx.squeeze(.topk) }, .{});
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
return .{ next_tokens, next_rng };
}
@ -1330,7 +1068,7 @@ pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: T
x = x.add(gumbel_noise);
const topk_idx = x.argMax(.topk).indices;
const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
const next_tokens = topk_indices.gather(.{ .voc = topk_idx.squeeze(.topk) }, .{});
return .{ next_tokens, next_rng };
}
@ -1339,7 +1077,7 @@ fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
// First reduce the vocab size to a reasonable sub set of candidate.
const full_topk = if (opts.max_top_k > 0)
logits.topK(opts.max_top_k, .voc, .{ .descending = true })
logits.topK(.{ .voc = .voc }, opts.max_top_k, .{ .descending = true })
else
logits.sort(.voc, .{ .descending = true });

View File

@ -5,25 +5,30 @@ const dialects = @import("mlir/dialects");
const mlir = @import("mlir");
const pjrt = @import("pjrt");
pub const ffi = pjrt.ffi;
pub const ApiError = pjrt.ApiError;
pub const ErrorCode = pjrt.ErrorCode;
pub const ExecuteContext = pjrt.ExecuteContext;
pub const BufferType = pjrt.BufferType;
pub const Device = pjrt.Device;
pub const MemoryStats = pjrt.MemoryStats;
pub const DeviceDescription = pjrt.DeviceDescription;
pub const Api = pjrt.Api;
pub const NamedValue = pjrt.NamedValue;
pub const ApiError = pjrt.ApiError;
pub const BufferType = pjrt.BufferType;
pub const ClientInitError = pjrt.ClientInitError;
pub const Error = pjrt.Error;
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
pub const SerializeResult = pjrt.SerializeResult;
pub const Executable = pjrt.Executable;
pub const CompiledMemoryStats = pjrt.CompiledMemoryStats;
pub const Device = pjrt.Device;
pub const DeviceDescription = pjrt.DeviceDescription;
pub const Error = pjrt.Error;
pub const ErrorCode = pjrt.ErrorCode;
pub const Executable = pjrt.Executable;
pub const ExecuteContext = pjrt.ExecuteContext;
pub const ExecuteError = ApiError;
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
pub const Memory = pjrt.Memory;
pub const MemoryStats = pjrt.MemoryStats;
pub const NamedValue = pjrt.NamedValue;
pub const Profiler = pjrt.Profiler;
pub const SerializeResult = pjrt.SerializeResult;
pub const Stream = pjrt.Stream;
const zml = struct {
pub const DataType = @import("dtype.zig").DataType;
};
const log = std.log.scoped(.zml);
pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError;
@ -165,8 +170,8 @@ pub const Buffer = opaque {
return @ptrCast(try self.inner().toHostBuffer(api, dst));
}
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
return self.inner().getElementType(api);
pub fn getElementType(self: *const Buffer, api: *const Api) zml.DataType {
return dtypeFromBufferType(self.inner().getElementType(api));
}
pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
@ -329,3 +334,29 @@ pub const AsyncHostToDeviceTransferManager = opaque {
return self.inner().addMetadata(api, transfer_metadata);
}
};
/// Maps a zml `DataType` to the equivalent `pjrt.BufferType`.
/// Works by name: the two enums share tag names, so the switch resolves
/// each tag at comptime via @field (round-trip checked by the test below).
pub fn bufferTypeFromDtype(dt: zml.DataType) pjrt.BufferType {
return switch (dt) {
inline else => |tag| @field(pjrt.BufferType, @tagName(tag)),
};
}
/// Maps a `pjrt.BufferType` back to the zml `DataType` with the same tag name.
/// Panics on `.invalid`, which has no DataType counterpart.
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) zml.DataType {
return switch (pjrt_type) {
.invalid => @panic("Found an invalid pjrt buffer"),
inline else => |tag| @field(zml.DataType, @tagName(tag)),
};
}
// Round-trip check in both directions: every DataType survives
// DataType -> BufferType -> DataType, and every BufferType (except the
// unmapped `.invalid` sentinel) survives the reverse trip.
test bufferTypeFromDtype {
inline for (@typeInfo(zml.DataType).@"enum".fields) |field| {
const dt: zml.DataType = @enumFromInt(field.value);
try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
}
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
const dt: pjrt.BufferType = @enumFromInt(field.value);
if (dt == .invalid) continue;
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
}
}

View File

@ -1,9 +1,11 @@
const std = @import("std");
const zml = @import("zml.zig");
const Allocator = std.mem.Allocator;
const zml = @import("zml.zig");
const module = zml.module;
// TODO add tests, use modern zml
pub fn Q4_0(comptime dtype: zml.DataType) type {
return struct {
const Self = @This();
@ -38,38 +40,12 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
const scales = extractScales(block_count, input);
const weights = extractWeights(block_count, input);
return scales.reshape(.{ block_count, 1 })
.broadcastLeft(zml.Shape.init(.{ block_count, 32 }, .f32))
return scales.broadcast(weights.shape(), &.{0})
.mul(weights)
.convert(dtype)
.reshape(.{block_count * 32})
.reshape(shape);
}
pub fn scaleIndices(block_count: u63) zml.Tensor {
// indices1 is the offsets of the scale bytes, repeated block_count times.
const indices1 = zml.Tensor.arange(.{ .start = 0, .end = 2 }, .i32).repeat1d(0, block_count);
// indices2 is the offsets of the blocks, repeated for each scale byte, repeated block_count times.
const indices2 = zml.Tensor.arange(.{ .start = 0, .end = block_stride * block_count, .step = block_stride }, .i32)
.reshape(.{ block_count, 1 }).broadcastLeft(zml.Shape.init(.{ block_count, 2 }, .i32)).reshape(.{2 * block_count});
// indices is the sum of the two, which is the offsets to all the bytes we are interested in.
return indices1.add(indices2);
}
pub fn weightIndices(block_count: u63) zml.Tensor {
// indices1 is the offsets of the data bytes, repeated block_count times.
const indices1 = zml.Tensor.arange(.{ .start = 2, .end = 18 }, .i32).repeat1d(0, block_count);
// indices2 is the offsets of the blocks, repeated for each data byte, repeated block_count times.
const indices2 = zml.Tensor.arange(.{ .start = 0, .end = block_stride * block_count, .step = block_stride }, .i32)
.reshape(.{ block_count, 1 }).broadcastLeft(zml.Shape.init(.{ block_count, 16 }, .i32)).reshape(.{16 * block_count});
// indices is the sum of the two, which is the offsets to all the bytes we are interested in.
return indices1.add(indices2);
}
pub fn extractScales(block_count: u63, input: zml.Tensor) zml.Tensor {
// The goal here is to get the first two bytes of every 18-bytes block in the input. For that,
// we generate a list of indices that we will use to gather from the input.
@ -85,7 +61,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
const indices = indices1.add(indices2);
// We select the values we are interested in with the indices, group them by pair and bitcast them to f16, then convert them to f32.
const scales = input.gatherValues(0, indices, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32);
const scales = input.gather_(&.{0}, &.{indices}, .{ .indices_are_sorted = true }).reshape(.{ block_count, 2 }).bitCast(.f16).convert(.f32);
return scales;
}
@ -107,7 +83,7 @@ pub fn Q4_0(comptime dtype: zml.DataType) type {
// NOTE(Corendos): i4 is not supported by bitcast convert, so we need the following workaround.
// We select the values we are interested in with the indices, these are our quantized_weights.
const quantized_weights = input.gatherValues(0, indices, .{ .indices_are_sorted = true });
const quantized_weights = input.gather_(&.{0}, &.{indices}, .{ .indices_are_sorted = true });
const lb_weights = quantized_weights
.logical(.And, zml.Tensor.constant(.{16 * block_count}, zml.Data.init(.u8, 0xf)))
.bitCast(.i8);

View File

@ -413,6 +413,13 @@ pub const Shape = struct {
// Already the right shape
if (std.mem.eql(i64, self.dims(), other.dims())) return true;
if (std.mem.eql(Tag, self.tags(), other.tags())) {
for (0..self.rank()) |i| {
if (self.dim(i) != 1 and self.dim(i) != other.dim(i)) return false;
}
return true;
}
// Non ambiguous broadcasting
// TODO: broad is error prone because of this:
// it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 }
@ -424,6 +431,7 @@ pub const Shape = struct {
}
for (self.dims(), self.tags()) |d, t| {
// TODO this is also wrong when the axes are in different order in the two shapes.
const other_ax = other.hasTag(t) orelse return false;
if (d != 1 and d != other.dim(other_ax)) return false;
}

View File

@ -311,7 +311,7 @@ pub const Tensor = struct {
/// Returns the given tensor as one contiguous buffer of bytes.
pub fn bytes(self: Tensor) Tensor {
return self.bitCast(.u8).flattenAll().withTags(.{.bytes});
return self.bitCast(.u8).flatten().withTags(.{.bytes});
}
/// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor.
@ -700,7 +700,7 @@ pub const Tensor = struct {
pub fn gumbelStats(rand: Rng, target_dist: Tensor) struct { Rng, Stats } {
const s = Shape.init(.{ .n = 1024, .d = 4 }, .f32);
const rng, const data = rand.gumbel(s);
const flat = data.flattenAll();
const flat = data.flatten();
const mean_ = flat.mean(0);
const variance = flat.sub(mean_.broad(flat.shape())).pow(Tensor.scalar(2, .f32)).mean(0);
@ -719,7 +719,7 @@ pub const Tensor = struct {
break :blk powers;
};
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16);
const counts = values.gather(.{ .d = samples }, .{}).sum(.n).bitCast(.u16);
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
}
@ -993,17 +993,22 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise addition of the input Tensor with a constant.
pub fn addConstant(self: Tensor, b: anytype) Tensor {
return self.add(Tensor.scalar(b, self.dtype()));
return self.add(.scalar(b, self.dtype()));
}
/// Returns a Tensor containing the element-wise division of the input Tensor by a constant.
pub fn divByConst(self: Tensor, b: anytype) Tensor {
return self.div(Tensor.scalar(b, self.dtype()));
return self.div(.scalar(b, self.dtype()));
}
/// Returns a Tensor containing the element-wise power of the input Tensor by a constant.
pub fn powByConst(self: Tensor, b: anytype) Tensor {
    const exponent = Tensor.scalar(b, self.dtype());
    return self.pow(exponent);
}
/// Returns a Tensor containing the element-wise multiplication of the input Tensor by a constant.
pub inline fn scale(self: Tensor, val: anytype) Tensor {
return self.mul(Tensor.scalar(val, self.dtype()));
return self.mul(.scalar(val, self.dtype()));
}
pub const LogicalOp = enum { OR, XOR, AND };
@ -1325,7 +1330,7 @@ pub const Tensor = struct {
const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
const res = exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
const res = exp_diff_max.div(exp_diff_max.sum(a));
// If a row is full -inf return full 0 instead of full nan,
// this fix attention when mask hides a full row.
@ -1519,27 +1524,7 @@ pub const Tensor = struct {
return self.transpose(perm.constSlice());
}
/// Flattens the given axis and the next one, into one new axis.
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
// TODO: move to torch.zig, this is equivalent to merge
const old_shape = self._shape;
const a = self.axis(axis_);
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
const loc = self.getContext().location(@src(), "flatten({f},{})", .{ self, axis_ });
const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(),
self.value(),
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc,
);
// log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] });
return _result(new_shape, reshaped_val.result(0));
}
pub inline fn flattenAll(self: Tensor) Tensor {
// TODO: rename to just flatten, once flatten is moved to torch
pub inline fn flatten(self: Tensor) Tensor {
return self.reshape(.{self.count()});
}
@ -1727,14 +1712,12 @@ pub const Tensor = struct {
}
const a = self.axis(axis_);
const broadshape = self._shape.insert(a + 1, .{n_rep});
const repeat_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1);
const res_shape = self._shape.setDim(a, self.dim(a) * n_rep);
var res = self.broadcast(broadshape, repeat_dims.dims()).flatten(a);
// Restor the tag that has been lost by flatten.
res._shape._tags.set(a, self._shape.tag(a));
const broadshape = self._shape.insert(a, .{n_rep});
const repeat_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a);
return res;
return self.broadcast(broadshape, repeat_dims.dims()).reshape(res_shape);
}
/// Repeats a Tensor several times along the given axes.
@ -1751,15 +1734,47 @@ pub const Tensor = struct {
return res;
}
test repeat1d {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    // Thin wrapper so repeat1d can be compiled and called through the test harness.
    const Local = struct {
        fn repeat1d(x: Tensor, axis_: u3, n_reps: u32) Tensor {
            return x.repeat1d(axis_, n_reps);
        }
    };
    {
        // 1-D: repeating [1, 2, 3] twice along axis 0 concatenates the whole sequence.
        const inputs: [3]u8 = .{ 1, 2, 3 };
        const expectations: [6]u8 = .{ 1, 2, 3, 1, 2, 3 };
        const input = try zml.Buffer.fromArray(platform, inputs);
        const output = try zml.testing.compileAndCall(platform, Local.repeat1d, .{ input, 0, 2 });
        try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations)));
    }
    {
        // 2-D: repeating along axis 1 duplicates each row's contents, not the rows themselves.
        const inputs: [2][3]u8 = .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } };
        const expectations: [2][6]u8 = .{ .{ 1, 2, 3, 1, 2, 3 }, .{ 4, 5, 6, 4, 5, 6 } };
        const input = try zml.Buffer.fromArray(platform, inputs);
        const output = try zml.testing.compileAndCall(platform, Local.repeat1d, .{ input, 1, 2 });
        try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations)));
    }
}
}
/// Repeats in line each value along the given axis.
///
/// * stutter1d([0, 1, 2, 3], 0, 2) = [0, 0, 1, 1, 2, 2, 3, 3]
pub fn stutter1d(self: Tensor, axis_: i64, n_rep: u63) Tensor {
/// * stutter1d([0, 1, 2, 3], -1, 2) = [0, 0, 1, 1, 2, 2, 3, 3]
/// This is equivalent to repeat(ax+1) unless ax is the last axis.
pub fn stutter1d(self: Tensor, axis_: i8, n_rep: u63) Tensor {
const a = self.axis(axis_);
const broadshape = self._shape.insert(a + 1, .{n_rep});
const stutter_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1);
const res_shape = self._shape.setDim(a, self.dim(a) * n_rep);
return self.broadcast(broadshape, stutter_dims.dims()).flatten(a);
const stutter_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1);
return self.broadcast(broadshape, stutter_dims.dims()).reshape(res_shape);
}
/// Repeats in line each value along the given axes.
@ -1775,6 +1790,45 @@ pub const Tensor = struct {
return res;
}
test stutter1d {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    // Thin wrapper so stutter1d can be compiled and called through the test harness.
    const Local = struct {
        fn stutter1d(x: Tensor, axis_: u3, n_reps: u32) Tensor {
            return x.stutter1d(axis_, n_reps);
        }
    };
    {
        // 1-D: stutter duplicates each element in place: [1,2,3] -> [1,1,2,2,3,3].
        const inputs: [3]u8 = .{ 1, 2, 3 };
        const expectations: [6]u8 = .{ 1, 1, 2, 2, 3, 3 };
        const input = try zml.Buffer.fromArray(platform, inputs);
        const output = try zml.testing.compileAndCall(platform, Local.stutter1d, .{ input, 0, 2 });
        try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations)));
    }
    {
        // 2-D, last axis: each value within a row is doubled in place.
        const inputs: [2][3]u8 = .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } };
        const expectations: [2][6]u8 = .{ .{ 1, 1, 2, 2, 3, 3 }, .{ 4, 4, 5, 5, 6, 6 } };
        const input = try zml.Buffer.fromArray(platform, inputs);
        const output = try zml.testing.compileAndCall(platform, Local.stutter1d, .{ input, 1, 2 });
        try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations)));
    }
    {
        // 2-D, axis 0: stuttering rows yields shape [4][3] (each row repeated).
        // NOTE(review): the expectation is typed [2][6] — this comparison appears to
        // rely on getValue reading the raw host bytes regardless of logical shape,
        // since [4][3] and [2][6] share the same flat layout here. Confirm getValue
        // semantics if this test starts failing on shape checks.
        const inputs: [2][3]u8 = .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } };
        const expectations: [2][6]u8 = .{ .{ 1, 2, 3, 1, 2, 3 }, .{ 4, 5, 6, 4, 5, 6 } };
        const input = try zml.Buffer.fromArray(platform, inputs);
        const output = try zml.testing.compileAndCall(platform, Local.stutter1d, .{ input, 0, 2 });
        try std.testing.expectEqual(expectations, output.getValue(@TypeOf(expectations)));
    }
}
/// Returns a Tensor containing the element-wise negation of the input Tensor.
pub fn negate(self: Tensor) Tensor {
const loc = self.getContext().mlirCtx().location(@src());
@ -2160,88 +2214,98 @@ pub const Tensor = struct {
pub const GatherOpts = struct { indices_are_sorted: bool = false };
/// For each coordinate in `indices`,
/// `gatherValues` extracts a single value of the given tensor.
/// `gather` extracts slices from the given tensor at the specified offsets.
/// example: `values.gather(.{ .a = idx }, .{})`
///
/// * indices is a named list of integer tensors: eg `.{ .a = idx }`.
/// Each names specify a gathering axis, it must refer to an axis of self,
/// and the corresponding idx Tensor must contains valid indices into axis .a.
/// All indices must have the same shape or broadcast to the same shape.
///
/// * axes_ is a single axis, or a tuple of axis: .b, or .{ .b, .c }
/// * indices is an integer tensor
/// * result is a tensor whose shape is similar to the input shape
/// where the gathered axes have been replaced by axes from 'indices'.
///
/// Some example input for the base case where we work on one axis:
/// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]]
/// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b]
/// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c]
/// - gather(f:[a], .{ .a = idx:[n]})[n] == f[idx[n]]
/// - gather(f:[a, b], .a, idx:[n])[n, b] == f[idx[n], b]
/// - gather(f:[a,b,c], .{.b = idx:[n,m]})[a, n, m, c] == f[a, idx[n, m], c]
///
/// If an axis in common between `self` and `indices`,
/// it is treated as a "batching" axis, meaning that semantically
/// the operator is doing a gatherValues one time per dimension of this axis:
/// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]]
/// the operator is doing a gather one time per dimension of this axis:
/// - gather(f: [a,b,c], .{.b=idx: [a,n]})[a, n] == f[a, idx[a, n]]
///
/// It is an error to have an axis present in `self`, `axes_` and `indices`.
/// It's possible to pass several indices:
/// - gather(f: [a,b,c], .{.b=idx_b[n], .c=idx_c[n]})[a, n] == f[a, idx_b[n], idx_c[n]]
/// - gather(f: [a,b,c,d], .{.b=idx_b[a,n], .c=idx_c[a, n]})[a, n, d] == f[a, idx_b[a, n], idx_c[a, n], d]
///
/// If several axes are passed, then the last axis of indices is treated as coordinates:
/// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]]
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
/// If `self` isn't tagged, you can use `gather_` to specify gathered axis by their position but batching won't be available.
///
/// It is possible to use gatherValues without tags, but batching won't be available.
pub fn gatherValues(self: Tensor, coord_axes: anytype, indices: Tensor, opts: GatherOpts) Tensor {
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
/// For performance it's better to have batching and gathering axes of `self` be the first one,
/// so that gather can
pub fn gather(self: Tensor, _indices: anytype, opts: GatherOpts) Tensor {
const idx_per_axis, const idx_tags = Shape.parseStruct(Tensor, _indices);
var idx_axes: Shape.AxesArray = .{};
for (idx_tags.slice()) |t| {
idx_axes.appendAssumeCapacity(self.axis(t));
}
stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
for (coord_axes_.constSlice(), 0..) |a, i| {
// TODO: sort indices following self.shape instead of asking the user to do it.
return self.gather_(idx_axes.slice(), idx_per_axis.slice(), opts);
}
pub fn gather_(self: Tensor, idx_axes: []const u3, idx_per_axis: []const Tensor, opts: GatherOpts) Tensor {
stdx.debug.assert(idx_axes.len > 0, "gather expects 1 or more axes to operate one, received none. Example: `x.gather(.a, indices, .{{}})`", .{});
for (idx_axes, 0..) |a, i| {
if (i > 0) {
stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {f}", .{ coord_axes, self });
stdx.debug.assert(a == idx_axes[i - 1] + 1, "gather expects 'idx_axes' to be sequential. But {any} aren't sequential in {f}", .{ idx_axes, self });
}
}
var indices_shape = idx_per_axis[0].shape();
for (idx_per_axis[1..]) |idx| {
if (idx.rank() > indices_shape.rank()) {
indices_shape = idx.shape();
}
}
for (idx_per_axis) |idx| {
stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "gather indices can't be broadcasted together {any}", .{idx_per_axis});
}
var idx_batch_axes: Shape.DimsArray = .{};
const AxisKind = enum { batching, offset, collapsed, indices };
var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{};
var indices_batch_axes: Shape.DimsArray = .{};
var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{ .buffer = @splat(.offset), .len = self.rank() };
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
if (indices._shape.hasTag(t)) |id_ax| {
const is_gather_axis = std.mem.containsAtLeastScalar(u3, idx_axes, 1, @intCast(self_ax));
if (indices_shape.hasTag(t)) |id_ax| {
// tag is both in self and indices -> it's a batching dim
// Note: tags are required for batching.
self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax);
stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={f}', in 'coord_axes_={any}' and in 'indices={f}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
} else if (maybe_coord_ax) |_| {
// for gatherValues we collapsed all gathered axes
// (contrary to gatherSlices where we collapse none)
self_kind.appendAssumeCapacity(.collapsed);
self_kind.buffer[self_ax] = .batching;
idx_batch_axes.appendAssumeCapacity(id_ax);
stdx.debug.assert(!is_gather_axis, "gather expects axes to appear at most twice. Axis {s} has been found both in 'self={f}', in 'idx_axes={any}' and in 'indices={f}'", .{ t, self, idx_axes, indices_shape });
} else if (is_gather_axis) {
// we collapsed all gathered axes
self_kind.buffer[self_ax] = .collapsed;
// idx_kind.buffer[id_ax] = .indices;
} else {
self_kind.appendAssumeCapacity(.offset);
self_kind.buffer[self_ax] = .offset;
}
}
// When we receive several coord_axes we need an extra dimension to store
// one index per axis, which makes the coordinates of one value.
// Otherwi se stablehlo uses the "indices.rank()" default value.
const index_coord_axis = if (single_coord)
indices.rank()
else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {f}", .{ coord_axes, coord_axes_.len, indices });
break :blk ax;
};
// compute res shape
var res_shape = Shape.init(.{}, self.dtype());
var res_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{};
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
for (self_kind.slice(), 0..) |kind, ax_usize| {
const ax: u3 = @intCast(ax_usize);
if (ax == coord_axes_.get(0)) {
if (ax == idx_axes[0]) {
// The first val_ax is special cause this is the place where we insert indices axes.
for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| {
if (id_ax == index_coord_axis) continue;
if (std.mem.indexOfScalar(i64, indices_batch_axes.constSlice(), @intCast(id_ax))) |_| {
// batching dim are already in res
continue;
}
for (0.., indices_shape.tags(), indices_shape.dims()) |id_axis_order, id_axis, id_inserted_dim| {
const is_batching_axis = std.mem.containsAtLeastScalar(i64, idx_batch_axes.constSlice(), 1, @intCast(id_axis_order));
// Batching axis is already in self.
if (is_batching_axis) continue;
res_shape = res_shape.appendDim(indices.dim(id_ax), t);
res_shape = res_shape.appendDim(id_inserted_dim, id_axis);
res_kind.appendAssumeCapacity(.indices);
}
}
@ -2257,12 +2321,12 @@ pub const Tensor = struct {
// This is not a gather, but a dynamicSlice.
// Sometimes the backend recognize this pattern, but not always.
// So let us handle that.
if (indices.count() == 1) {
return self.dynamicSlice1d(coord_axes_.get(0), .{ .start = indices.flattenAll().squeeze(0), .len = 1 }).reshape(res_shape);
if (indices_shape.count() == 1 and idx_axes.len == 1) {
return self.dynamicSlice1d(idx_axes[0], .{ .start = idx_per_axis[0].asScalar(), .len = 1 }).reshape(res_shape);
}
var slice_dims: Shape.DimsArray = .{};
for (self_kind.constSlice(), self.dims()) |k, d| {
for (self_kind.slice(), self.dims()) |k, d| {
slice_dims.appendAssumeCapacity(switch (k) {
.batching, .collapsed => 1,
.offset => d,
@ -2270,7 +2334,9 @@ pub const Tensor = struct {
});
}
// scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() });
// TODO: try changing .last by other axis and see the perf impact.
const indices = Tensor.stack(idx_per_axis, .last, .coord);
// scoped_log.debug("gather --> {} {any}", .{ res_shape, res_kind.constSlice() });
const loc = self.getContext().mlirCtx().location(@src());
const gather_op = dialect.stablehlo.gather(
self.getContext().mlirCtx(),
@ -2282,22 +2348,30 @@ pub const Tensor = struct {
.offset_dims = _collectAxes(AxisKind, res_kind, .offset).constSlice(),
.collapsed_slice_dims = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
.operand_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
.start_indices_batching_dims = indices_batch_axes.constSlice(),
.start_indices_batching_dims = idx_batch_axes.constSlice(),
.start_index_map = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
.index_vector_dim = index_coord_axis,
.index_vector_dim = indices.axis(.coord),
.indices_are_sorted = opts.indices_are_sorted,
},
);
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices });
stdx.debug.assert(mlir_shape.eql(res_shape), "gather expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices });
return _result(res_shape, gather_op.result(0));
}
test gatherValues {
test gather {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const Local = struct {
pub fn _idx(idx_shape: anytype) Tensor {
return Tensor.constant(idx_shape, .{ .i32 = 0 });
}
};
const idx = Local._idx;
{
// Only test shapes
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
@ -2306,32 +2380,38 @@ pub const Tensor = struct {
defer comp.deactivate();
inline for (.{
.{ .{ .a = 10 }, .a, .{}, .{} },
.{ .{ .a = 10 }, .a, .{ .n = 8 }, .{ .n = 8 } },
.{ .{ .a = 10, .b = 20 }, .a, .{}, .{ .b = 20 } },
.{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
.{ .{ .a = 10, .b = 20 }, 0, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
.{ .{ .a = 10 }, .{ .a = idx(.{}) }, .{} },
.{ .{ .a = 10 }, .{ .a = idx(.{ .n = 8 }) }, .{ .n = 8 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{}) }, .{ .b = 20 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }) }, .{ .n = 8, .b = 20 } },
// .{ .{ .a = 10, .b = 20 }, 0, idx(.{ .n = 8 }), .{ .n = 8, .b = 20 } },
// Favor val shape, instead of indices shape.
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
.{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .n = 8 }) }, .{ .a = 10, .n = 8 } },
.{ .{ .a = 10, .b = 20, .c = 30 }, .{ .b = idx(.{ .n = 8 }) }, .{ .a = 10, .n = 8, .c = 30 } },
// batching axes are implicits.
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
.{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10 }) }, .{ .a = 10 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .b = 20 }) }, .{ .b = 20 } },
.{ .{ .a = 10, .b = 20 }, .{ .b = idx(.{ .a = 10, .n = 8 }) }, .{ .a = 10, .n = 8 } },
// stablehlo.gather is biased toward indices shape (like gatherSlice).
// This make it awkward to use when you have both batching dimension and new indices dimensions.
// For now we reject those, and let user explicitly transpose self or indices if needed.
// .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } },
// .{ .{ .a = 10, .b = 20 }, .{.b = idx(.{ .n = 8, .a = 10 })}, .{ .a = 10, .n = 8 } },
// Also handle tuples
.{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .n = 8 } },
.{ .{ 10, 20 }, .{ -2, -1 }, .{ 8, 2 }, .{8} },
// and 1-tuple
.{ .{ .a = 10, .b = 20 }, .{.b}, .{ .n = 8, ._ = 1 }, .{ .a = 10, .n = 8 } },
.{ .{ .a = 10, .b = 20 }, .{ .a = idx(.{ .n = 8 }), .b = idx(.{ .n = 8 }) }, .{ .n = 8 } },
}) |testcase| {
const x_shape, const tag, const idx_shape, const res_shape = testcase;
const x_shape, const indices, const res_shape = testcase;
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
const y = gatherValues(x, tag, idx, .{});
const y = gather(x, indices, .{});
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
try std.testing.expect(y.value().owner().verify());
}
inline for (.{
.{ .{ 10, 20 }, &[_]u3{ 0, 1 }, &[_]Tensor{ idx(.{8}), idx(.{8}) }, .{8} },
}) |testcase| {
const x_shape, const idx_axes, const idx_per_axis, const res_shape = testcase;
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
const y = gather_(x, idx_axes, idx_per_axis, .{});
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
try std.testing.expect(y.value().owner().verify());
}
@ -2966,13 +3046,24 @@ pub const Tensor = struct {
}
/// Returns a Tensor representing the result of Top-K over the given axis.
pub fn topK(self: Tensor, k: u32, axis_: anytype, opts: struct { descending: bool = true }) SortRes {
const a = self.axis(axis_);
const result = self.sort(a, .{ .descending = opts.descending });
return .{
.values = result.values.slice1d(a, .{ .end = k }),
.indices = result.indices.slice1d(a, .{ .end = k }),
/// Returns the top-K values (and their indices) along the given axis.
/// The axis can be given either positionally (`x.topK(-1, 16)`) or as a
/// single-field struct that simultaneously renames the result axis
/// (`x.topK(.{ .best_token = .token }, 16)`).
pub fn topK(self: Tensor, named_axis_: anytype, k: u32, opts: struct { descending: bool = true }) SortRes {
    const err_msg = "topK named axis should be an integer or a named axis, eg `x.topK(.{{ .best_token = .token }}, 16)` or `x.topK(-1, 16)`";
    // Comptime dispatch on the axis argument:
    // - integer: plain axis lookup, no renaming (has_name == null);
    // - one-field struct: field name becomes the new tag for the sliced axis,
    //   field value designates the axis to sort along.
    const has_name: ?[:0]const u8, const a = switch (@typeInfo(@TypeOf(named_axis_))) {
        .int, .comptime_int => .{ null, self.axis(@as(i64, @intCast(named_axis_))) },
        .@"struct" => |info| blk: {
            stdx.debug.assertComptime(info.fields.len == 1, err_msg, .{});
            break :blk .{ info.fields[0].name, self.axis(@field(named_axis_, info.fields[0].name)) };
        },
        else => stdx.debug.compileError(err_msg, .{}),
    };
    // Full sort along the axis, then keep the first k entries of both the
    // sorted values and the permutation indices.
    var result = self.sort(a, .{ .descending = opts.descending });
    result.values = result.values.slice1d(a, .{ .end = k });
    result.indices = result.indices.slice1d(a, .{ .end = k });
    if (has_name) |new_name| {
        // Retag the sliced axis with the user-provided name. The field name of a
        // comptime struct is a null-terminated static string, so storing its
        // pointer is safe for the tensor's lifetime.
        result.values._shape._tags.set(a, new_name.ptr);
        result.indices._shape._tags.set(a, new_name.ptr);
    }
    return result;
}
pub const MaxPoolRes = ArgMaxRes;
@ -3827,11 +3918,20 @@ pub const Tensor = struct {
return binaryOpHelper(self, other.broad(self._shape));
}
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape });
var other_ = other;
var same_shape = self._shape.eql(other._shape);
if (!same_shape and std.mem.eql(Shape.Tag, self._shape.tags(), other._shape.tags()) and other._shape.canBroadcastTo(self._shape)) {
// Only a restrictive version of broadcasting is allowed here, where all the tags matches already.
// Typical use case: `x.div(x.sum(.a))`
same_shape = true;
other_ = other.broad(self._shape);
}
stdx.debug.assert(same_shape, "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape });
const ctx = self.getContext();
const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other });
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location });
const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other_ });
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other_.value(), location });
return _result(self._shape, ret.result(0));
}
}.binaryOpHelper;
@ -3904,21 +4004,14 @@ test "Tensor.maxPool1d" {
const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5 }, &data);
const result = try zml.testing.compileAndCall(platform, MaxPool._fwd, .{x});
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .f32), result.values.shape());
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .i32), result.indices.shape());
const buffer = result.values.getValue([2][2][2]f32);
try zml.testing.expectEqualShapes(.init(.{ 2, 2, 2 }, .f32), result.values.shape());
try zml.testing.expectEqualShapes(.init(.{ 2, 2, 2 }, .i32), result.indices.shape());
try std.testing.expectEqualDeep(
[2][2][2]f32{
[2][2]f32{
[2]f32{ 2, 4 },
[2]f32{ 7, 9 },
},
[2][2]f32{
[2]f32{ 12, 14 },
[2]f32{ 17, 19 },
},
.{ .{ 2, 4 }, .{ 7, 9 } },
.{ .{ 12, 14 }, .{ 17, 19 } },
},
buffer,
result.values.getValue([2][2][2]f32),
);
}

View File

@ -33,7 +33,7 @@ pub fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool
return closeRel or closeAbs;
}
/// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the
/// Testing utility. Accepts both zml.Buffer and zml.HostBuffer but zml.Buffer will be copied to the
/// host for comparison !
pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
const allocator = if (builtin.is_test) std.testing.allocator else std.heap.smp_allocator;

View File

@ -1,4 +1,5 @@
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_test")
load("@zml//bazel:zig_srcs.bzl", "zig_srcs")
zig_library(
name = "tokenizer",
@ -7,7 +8,6 @@ zig_library(
main = "tokenizer.zig",
visibility = ["//visibility:public"],
deps = [
"//async",
"//ffi:zig",
"//stdx",
"//zml/tokenizer/hftokenizers",
@ -36,3 +36,8 @@ zig_test(
"//stdx",
],
)
zig_srcs(
name = "sources",
zig_lib = ":tokenizer",
)

View File

@ -1,5 +1,6 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:swig.bzl", "swig_cc_library")
load("//bazel:zig_srcs.bzl", "zig_srcs")
swig_cc_library(
name = "sentencepiece_swig",
@ -22,3 +23,13 @@ zig_library(
"//stdx",
],
)
zig_srcs(
name = "sources",
zig_lib = ":sentencepiece",
)
# Bundles the SWIG-generated sentencepiece wrapper into a single static archive.
# Formatting fixed to match buildifier conventions used by the rest of the file
# (spaces around `=`, trailing comma).
cc_static_library(
    name = "sentencepiece_static",
    deps = [":sentencepiece_swig"],
)

View File

@ -1,6 +1,5 @@
const std = @import("std");
const async = @import("async");
const hftokenizers = @import("hftokenizers");
const sentencepiece = @import("sentencepiece");
@ -98,15 +97,15 @@ pub const Tokenizer = union(Tokenizers) {
pub fn fromFile(allocator: std.mem.Allocator, model: []const u8) !Tokenizer {
if (std.mem.endsWith(u8, model, ".pb")) {
return .{ .sentencepiece = try async.callBlocking(sentencepiece.SentencePieceProcessor.fromFile, .{model}) };
return .{ .sentencepiece = try sentencepiece.SentencePieceProcessor.fromFile(model) };
}
if (std.mem.endsWith(u8, model, ".json")) {
return .{ .hftokenizers = try async.callBlocking(hftokenizers.HFTokenizer.fromFile, .{model}) };
return .{ .hftokenizers = try hftokenizers.HFTokenizer.fromFile(model) };
}
if (std.mem.endsWith(u8, model, ".tinyllama")) {
const tokenizer = try allocator.create(homemade.Tokenizer);
tokenizer.* = try async.callBlocking(homemade.fromTinyLlamaFile, .{ allocator, model, 32000 });
tokenizer.* = try homemade.fromTinyLlamaFile(allocator, model, 32000);
return .{ .homemade = tokenizer };
}

View File

@ -169,7 +169,7 @@ test pixelShuffle {
///
/// Note: at the difference of Pytorch, shifts need to be explicitly repeated, even if they are the same for all axes.
/// ref: https://pytorch.org/docs/stable/generated/torch.roll.html
pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
pub fn roll(self: Tensor, shifts: []const i64, axes_: []const i8) Tensor {
// TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 })
stdx.debug.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len });
@ -180,11 +180,11 @@ pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
return roll(first_dim_rolled, tail_shifts, tail_dims);
}
const a = axes_[0];
const a = self.axis(axes_[0]);
const start = @mod(self.dim(a) - shifts[0], self.dim(a));
const idx = Tensor.arange(.{ .start = start, .end = start + self.dim(a) }, .f32);
const divisor: f32 = @floatFromInt(self.dim(a));
return self.gatherValues(a, idx.fmod(divisor).convert(.i32), .{});
return self.gather_(&.{a}, &.{idx.fmod(divisor).convert(.i32)}, .{});
}
test roll {
@ -194,7 +194,7 @@ test roll {
const res = try zml.testing.compileAndCall(
platform,
roll,
.{ input, &[_]i64{ 2, 1 }, &[_]u8{ 0, 1 } },
.{ input, &[_]i64{ 2, 1 }, &[_]i8{ 0, 1 } },
);
const expectation = zml.HostBuffer.fromSlice(.{ 4, 2 }, &[_]f32{ 6, 5, 8, 7, 2, 1, 4, 3 });
@ -274,3 +274,12 @@ test meshgrid {
);
}
}
/// Flattens the given axis and the next one, into one new axis.
/// The merged axis keeps the position of `axis_`; all other axes are untouched.
/// Asserts that `axis_` is not the last axis (there must be a following axis to merge with).
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
    // Resolve the (possibly named) axis to its positional index.
    const first = self.axis(axis_);
    stdx.debug.assert(first + 1 < self.rank(), "Can't flatten {f} on the last axis {}.", .{ self, axis_ });
    // Merge `first` and `first + 1` into a single axis, then reshape to the merged shape.
    return self.reshape(self._shape.mergeAxis(first, .{ first, first + 1 }));
}