Add gpt-oss model support to core ZML components: formatting, utility functions, safetensors I/O, host buffer management, NN layer handling, and tensor operations.

This commit is contained in:
Tarry Singh 2025-10-01 14:20:32 +00:00
parent e1b7fc5781
commit 77cd21d2b2
6 changed files with 106 additions and 74 deletions

View File

@ -10,20 +10,20 @@ fn FmtSlice(T: type) type {
slice: []const T,
pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
return try formatSliceAny(f.slice, .{}, writer);
return try formatSliceAny(f.slice, .{}, 1, writer);
}
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
return switch (@typeInfo(T)) {
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
.comptime_int, .int => try formatIntSlice(f.slice, n, writer),
.bool => try formatBoolSlice(f.slice, n, writer),
.comptime_float, .float => try formatFloatSlice(f.slice, n, 1, writer),
.comptime_int, .int => try formatIntSlice(f.slice, n, 1, writer),
.bool => try formatBoolSlice(f.slice, n, 1, writer),
.@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) {
try formatComplexSlice(f.slice, n, writer);
try formatComplexSlice(f.slice, n, 1, writer);
} else if (@hasDecl(T, "toF32")) {
try formatFloatSlice(f.slice, n, writer);
try formatFloatSlice(f.slice, n, 1, writer);
} else {
try formatSliceAny(f.slice, n, writer);
try formatSliceAny(f.slice, n, 1, writer);
},
else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)),
};
@ -72,53 +72,55 @@ pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !
return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill });
}
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
// use the format "width" for the number of columns instead of individual width.
const num_cols: usize = spec.width orelse 12;
var my_options = spec;
my_options.width = null;
const n: usize = values.len;
// TODO: handle negative strides
const strd: usize = @intCast(stride);
const n: usize = @divTrunc(values.len, strd);
_ = try writer.write("{");
if (n <= num_cols) {
for (values, 0..) |v, i| {
for (0..n) |i| {
// Force inlining so that the switch and the buffer can be done once.
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer });
if (i < n - 1) _ = try writer.write(",");
}
} else {
const half = @divFloor(num_cols, 2);
for (values[0..half]) |v| {
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
for (0..half) |i| {
try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer });
_ = try writer.write(",");
}
_ = try writer.write(" ..., ");
for (values[n - half ..], 0..) |v, i| {
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
if (i < half - 1) _ = try writer.write(",");
for (n - half..n) |i| {
try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer });
if (i < n - 1) _ = try writer.write(",");
}
}
_ = try writer.write("}");
}
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatAny, values, spec, writer);
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatAny, values, spec, stride, writer);
}
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatFloat, values, spec, writer);
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatFloat, values, spec, stride, writer);
}
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatInt, values, spec, writer);
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatInt, values, spec, stride, writer);
}
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatComplex, values, spec, writer);
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatComplex, values, spec, stride, writer);
}
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatBool, values, spec, writer);
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatBool, values, spec, stride, writer);
}
/// Format a struct using `format` method of subfields when possible.

View File

@ -21,16 +21,6 @@ import torch
log = logging.getLogger(__name__)
builtin_open = builtins.open
def log_and_open(file, *args, **kwargs):
print("open:", file, args, kwargs)
return builtin_open(file, *args, **kwargs)
builtins.open = log_and_open
class ActivationCollector:
"""Wrap a given torch.nn.Module and collect all its intermediary activations.

View File

@ -23,7 +23,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
} else {
try loadFile(arena, &res, &files, path);
}
res.files = try files.toOwnedSlice(allocator);
res.files = try files.toOwnedSlice(arena);
return res;
}

View File

@ -355,12 +355,17 @@ pub const HostBuffer = struct {
try writer.splatByteAll(' ', indent_level);
switch (self.dtype()) {
inline else => |dt| {
const values = self.items(dt.toZigType());
const T = dt.toZigType();
const n: i64 = self._shape.dim(0);
// TODO: handle negative strides
const byte_stride: i64 = self._strides[0];
const elem_strides: i64 = @divExact(byte_stride, @sizeOf(T));
const values: []const T = @ptrCast(@alignCast(self._data[0..@intCast(n * byte_stride)]));
switch (comptime dt.class()) {
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
.complex => try stdx.fmt.formatComplexSlice(values, options, writer),
.bool => try stdx.fmt.formatBoolSlice(values, options, writer),
.float => try stdx.fmt.formatFloatSlice(values, options, elem_strides, writer),
.integer => try stdx.fmt.formatIntSlice(values, options, elem_strides, writer),
.complex => try stdx.fmt.formatComplexSlice(values, options, elem_strides, writer),
.bool => try stdx.fmt.formatBoolSlice(values, options, elem_strides, writer),
}
},
}

View File

@ -110,9 +110,8 @@ pub const LayerNorm = struct {
pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
const ax = x.axis(axis);
// upcast to improve precision
const xf32 = x.convert(.f32);
const mean = xf32.mul(xf32).mean(ax);
const rsqrt = Tensor.rsqrt(mean.addConstant(eps)).convert(x.dtype());
const variance = x.convert(.f32).powByConst(2).mean(ax);
const rsqrt = Tensor.rsqrt(variance.addConstant(eps)).convert(x.dtype());
return x.mul(rsqrt.broad(x.shape()));
}
@ -816,7 +815,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
Tensor.constant(t.shape(), dtype.one()),
t,
t.mul(t),
t.pow(Tensor.scalar(3, dtype)),
t.powByConst(3),
}, .last, ._interpolated);
std.debug.assert(pos.dim(0) == new_len);
@ -894,6 +893,7 @@ pub fn causalAttnMask(
pub const SdpaOpts = struct {
attn_mask: ?Tensor = null,
scale: ?Tensor = null,
softmax_bias: ?Tensor = null,
allow_cudnn: bool = true,
// TODO: put a callback instead of all this field,
// so that
@ -922,7 +922,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape())) {
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape()) and opts.softmax_bias == null) {
return cuda.sdpa(q, k, v, opts);
}
@ -940,9 +940,15 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
k = k.mul(head_scaling.convert(k.dtype()));
var attn_weights = q.dot(k, .{.hd});
// log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask });
// log.debug("attn_weights : {f}, attn_mask : {?f}", .{ attn_weights, attn_mask });
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
attn_weights = attn_weights.convert(.f32);
attn_weights = if (opts.softmax_bias) |softmax_bias| attn: {
// The split is needed because we also split q ourselves.
// TODO: consider letting the user do that.
const bias = softmax_bias.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto });
break :attn attn_weights.convert(.f32).softmaxBiased(.k, bias).convert(q.dtype());
} else attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
var attn = attn_weights.dot(v, .{.k});
return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } });

View File

@ -1105,13 +1105,16 @@ pub const Tensor = struct {
/// Axes with the same tag on both sides, and which aren't contracting,
/// are considered "batching axes".
pub fn dot(lhs: Tensor, rhs: Tensor, comptime contracting: anytype) Tensor {
var contracting_axes: [contracting.len][2]i8 = undefined;
inline for (contracting, 0..) |c, i| {
contracting_axes[i] = .{ lhs.axis(c), rhs.axis(c) };
var contracting_axes: stdx.BoundedArray([2]i8, MAX_RANK) = .{};
if (@TypeOf(contracting) == EnumLiteral) {
contracting_axes.appendAssumeCapacity(.{ lhs.axis(contracting), rhs.axis(contracting) });
} else {
inline for (contracting) |c| {
contracting_axes.appendAssumeCapacity(.{ lhs.axis(c), rhs.axis(c) });
}
}
var batching_axes: [MAX_RANK][2]i8 = undefined;
var n_batching: u8 = 0;
var batching_axes: stdx.BoundedArray([2]i8, MAX_RANK) = .{};
for (lhs._shape.tags(), 0..) |l, li| {
stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
@ -1119,20 +1122,19 @@ pub const Tensor = struct {
stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
if (l == r) {
for (contracting_axes) |ct| {
for (contracting_axes.slice()) |ct| {
if (l == lhs._shape.tag(ct[0])) {
break;
}
} else {
// tag is both in lhs and rhs but not in contracting -> it's a batching dim.
batching_axes[n_batching] = .{ @intCast(li), @intCast(ri) };
n_batching += 1;
batching_axes.appendAssumeCapacity(.{ @intCast(li), @intCast(ri) });
}
}
}
}
return dotGeneral(lhs, rhs, contracting_axes[0..], batching_axes[0..n_batching]);
return dotGeneral(lhs, rhs, contracting_axes.slice(), batching_axes.slice());
}
test dot {
@ -1318,23 +1320,39 @@ pub const Tensor = struct {
/// Returns a Tensor containing the Sigmoid Linear Unit (SiLU) activation function applied to each element of the input Tensor.
///
/// silu(x) = x σ(x)
/// https://paperswithcode.com/method/silu
pub fn silu(x: Tensor) Tensor {
return x.mul(x.sigmoid());
return x.mul(.sigmoid(x));
}
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
/// Computes softmax along the given axis.
/// y[i] = exp(x[i]) / Σ_k exp(x[k])
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
const a = self.axis(axis_);
const max_val = self.max(a);
const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
const row_mask = max_val.cmp(.GT, .scalar(-std.math.inf(f64), self.dtype()));
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
const res = exp_diff_max.div(exp_diff_max.sum(a));
const exp_diff_max = self.sub(max_val).convert(.f32).exp();
const res = exp_diff_max.div(exp_diff_max.sum(a)).convert(self.dtype());
// If a row is full -inf return full 0 instead of full nan,
// this fix attention when mask hides a full row.
return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype()));
return row_mask.broad(self.shape()).select(res, .scalar(0, self.dtype()));
}
/// Computes softmax with an extra bias term added to the sum of exponentials:
/// y[i] = exp(x[i]) / ( Σ_k exp(x[k]) + bias )
/// With a null `bias` this is exactly `softmax`.
pub fn softmaxBiased(self: Tensor, axis_: anytype, bias: ?Tensor) Tensor {
    const reduce_ax = self.axis(axis_);
    // No bias: fall back to the plain softmax implementation.
    const bias_raw = bias orelse return self.softmax(axis_);
    // Align the bias with the input: same shape, but a single entry along the softmax axis.
    const bias_bc = bias_raw.convert(self.dtype()).broad(self.shape().setDim(reduce_ax, 1));
    // Subtract the max of (inputs, bias) before exponentiating, for numerical stability.
    const stable_max: Tensor = maximum(self.max(reduce_ax), bias_bc);
    const exp_x = self.sub(stable_max).exp();
    const exp_bias = bias_bc.sub(stable_max).exp();
    // The (strictly positive) bias term keeps the denominator away from 0,
    // so unlike plain softmax there is no all--inf row to special-case.
    return exp_x.div(exp_x.sum(reduce_ax).add(exp_bias));
}
/// Returns a Tensor containing the log of the sum of exponentials over the given axis.
@ -1534,21 +1552,31 @@ pub const Tensor = struct {
step: i32 = 1,
singleton: bool = false,
const full = .{ .start = 0, .end = to_the_end, .step = 1 };
/// Builds a Slice selecting exactly one element, at `offset`,
/// with the `singleton` flag set.
pub fn single(offset: i64) Slice {
    return .{
        .start = offset,
        .end = offset + 1,
        .singleton = true,
    };
}
/// Resolves a Slice against an axis of dimension `d`:
/// negative `start`/`end` count from the end, and an open end
/// (`to_the_end`) becomes `d`. Asserts the resolved range is non-empty.
pub fn absolute(s: Slice, d: i64) Slice {
    const start = if (s.start < 0) d + s.start else s.start;
    // A singleton always covers exactly one element; deriving its end from
    // the resolved start also makes negative offsets work
    // (e.g. `Slice.single(-1)` would otherwise resolve to end == 0).
    const end = if (s.singleton)
        start + 1
    else if (s.end == to_the_end)
        d
    else if (s.end < 0)
        d + s.end
    else
        s.end;
    const res: Slice = .{ .start = start, .end = end, .step = s.step, .singleton = s.singleton };
    stdx.debug.assert(start < end, "Slice {f} is invalid for axis of dimension {d} (resolved to {f})", .{ s, d, res });
    return res;
}
const to_the_end = std.math.maxInt(i64);
pub fn format(self: Slice, writer: *std.Io.Writer) !void {
if (self.singleton) {
try writer.print("[{}]", .{self.start});
try writer.print("[{d}]", .{self.start});
} else if (self.end == to_the_end and self.step == 1) {
try writer.print("[{}..]", .{self.start});
try writer.print("[{d}..]", .{self.start});
} else if (self.step == 1) {
try writer.print("[{}..{}]", .{ self.start, self.end });
try writer.print("[{d}..{d}]", .{ self.start, self.end });
} else {
try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step });
try writer.print("[{d}..{d}:{d}]", .{ self.start, self.end, self.step });
}
}
};
@ -1570,11 +1598,7 @@ pub const Tensor = struct {
for (slices, 0..) |s, a| {
stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
const args: Slice = .{
.start = self.wrapIndex(a, s.start),
.end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end),
.step = s.step,
};
const args: Slice = s.absolute(self.dim(a));
start_indices[a] = args.start;
limit_indices[a] = args.end;
strides[a] = args.step;
@ -2016,6 +2040,10 @@ pub const Tensor = struct {
return _result(sh, constant_op.result(0)).convert(val.dtype());
}
/// Returns a tensor of shape `sh` filled with the zero value of `sh`'s dtype.
pub fn zeroes(sh: Shape) Tensor {
    const zero = sh.dtype().zero();
    return .constant(sh, zero);
}
/// Embeds a buffer with concrete values into an Mlir program.
pub fn constantTensor(val: HostBuffer) Tensor {
const ctx = CompilationContext.current().mlirCtx();
@ -3047,6 +3075,7 @@ pub const Tensor = struct {
/// Returns a Tensor representing the result of Top-K over the given axis.
pub fn topK(self: Tensor, named_axis_: anytype, k: u32, opts: struct { descending: bool = true }) SortRes {
stdx.debug.assert(k > 0, "topK expects a k > 0, got 0", .{});
const err_msg = "topK named axis should be an integer or a named axis, eg `x.topK(.{{ .best_token = .token }}, 16)` or `x.topK(-1, 16)`";
const has_name: ?[:0]const u8, const a = switch (@typeInfo(@TypeOf(named_axis_))) {
.int, .comptime_int => .{ null, self.axis(@as(i64, @intCast(named_axis_))) },