diff --git a/stdx/fmt.zig b/stdx/fmt.zig index 4e38488..e9df913 100644 --- a/stdx/fmt.zig +++ b/stdx/fmt.zig @@ -10,20 +10,20 @@ fn FmtSlice(T: type) type { slice: []const T, pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void { - return try formatSliceAny(f.slice, .{}, writer); + return try formatSliceAny(f.slice, .{}, 1, writer); } pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { return switch (@typeInfo(T)) { - .comptime_float, .float => try formatFloatSlice(f.slice, n, writer), - .comptime_int, .int => try formatIntSlice(f.slice, n, writer), - .bool => try formatBoolSlice(f.slice, n, writer), + .comptime_float, .float => try formatFloatSlice(f.slice, n, 1, writer), + .comptime_int, .int => try formatIntSlice(f.slice, n, 1, writer), + .bool => try formatBoolSlice(f.slice, n, 1, writer), .@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) { - try formatComplexSlice(f.slice, n, writer); + try formatComplexSlice(f.slice, n, 1, writer); } else if (@hasDecl(T, "toF32")) { - try formatFloatSlice(f.slice, n, writer); + try formatFloatSlice(f.slice, n, 1, writer); } else { - try formatSliceAny(f.slice, n, writer); + try formatSliceAny(f.slice, n, 1, writer); }, else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)), }; @@ -72,53 +72,55 @@ pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) ! return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill }); } -pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { +pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void { // use the format "width" for the number of columns instead of individual width. 
const num_cols: usize = spec.width orelse 12; var my_options = spec; my_options.width = null; - const n: usize = values.len; + // TODO: handle negative strides + const strd: usize = @intCast(stride); + const n: usize = @divTrunc(values.len, strd); _ = try writer.write("{"); if (n <= num_cols) { - for (values, 0..) |v, i| { + for (0..n) |i| { // Force inlining so that the switch and the buffer can be done once. - try @call(.always_inline, fmt_func, .{ v, my_options, writer }); + try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer }); if (i < n - 1) _ = try writer.write(","); } } else { const half = @divFloor(num_cols, 2); - for (values[0..half]) |v| { - try @call(.always_inline, fmt_func, .{ v, my_options, writer }); + for (0..half) |i| { + try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer }); _ = try writer.write(","); } _ = try writer.write(" ..., "); - for (values[n - half ..], 0..) |v, i| { - try @call(.always_inline, fmt_func, .{ v, my_options, writer }); - if (i < half - 1) _ = try writer.write(","); + for (n - half..n) |i| { + try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer }); + if (i < n - 1) _ = try writer.write(","); } } _ = try writer.write("}"); } -pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { - return try formatSliceCustom(formatAny, values, spec, writer); +pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatAny, values, spec, stride, writer); } -pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { - return try formatSliceCustom(formatFloat, values, spec, writer); +pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatFloat, values, spec, stride, writer); } -pub fn formatIntSlice(values: anytype, spec: 
std.fmt.Number, writer: *std.Io.Writer) !void { - return try formatSliceCustom(formatInt, values, spec, writer); +pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatInt, values, spec, stride, writer); } -pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { - return try formatSliceCustom(formatComplex, values, spec, writer); +pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatComplex, values, spec, stride, writer); } -pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { - return try formatSliceCustom(formatBool, values, spec, writer); +pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatBool, values, spec, stride, writer); } /// Format a struct using `format` method of subfields when possible. diff --git a/tools/zml_utils.py b/tools/zml_utils.py index a0053e7..9bddc02 100644 --- a/tools/zml_utils.py +++ b/tools/zml_utils.py @@ -21,16 +21,6 @@ import torch log = logging.getLogger(__name__) -builtin_open = builtins.open - - -def log_and_open(file, *args, **kwargs): - print("open:", file, args, kwargs) - return builtin_open(file, *args, **kwargs) - - -builtins.open = log_and_open - class ActivationCollector: """Wrap a given torch.nn.Module and collect all its intermediary activations. 
diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index b49928a..de18f44 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -23,7 +23,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore } else { try loadFile(arena, &res, &files, path); } - res.files = try files.toOwnedSlice(allocator); + res.files = try files.toOwnedSlice(arena); return res; } diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 3185348..2052970 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -355,12 +355,17 @@ pub const HostBuffer = struct { try writer.splatByteAll(' ', indent_level); switch (self.dtype()) { inline else => |dt| { - const values = self.items(dt.toZigType()); + const T = dt.toZigType(); + const n: i64 = self._shape.dim(0); + // TODO: handle negative strides + const byte_stride: i64 = self._strides[0]; + const elem_strides: i64 = @divExact(byte_stride, @sizeOf(T)); + const values: []const T = @ptrCast(@alignCast(self._data[0..@intCast(n * byte_stride)])); switch (comptime dt.class()) { - .float => try stdx.fmt.formatFloatSlice(values, options, writer), - .integer => try stdx.fmt.formatIntSlice(values, options, writer), - .complex => try stdx.fmt.formatComplexSlice(values, options, writer), - .bool => try stdx.fmt.formatBoolSlice(values, options, writer), + .float => try stdx.fmt.formatFloatSlice(values, options, elem_strides, writer), + .integer => try stdx.fmt.formatIntSlice(values, options, elem_strides, writer), + .complex => try stdx.fmt.formatComplexSlice(values, options, elem_strides, writer), + .bool => try stdx.fmt.formatBoolSlice(values, options, elem_strides, writer), } }, } diff --git a/zml/nn.zig b/zml/nn.zig index 4fbb17c..8bfd2f2 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -110,9 +110,8 @@ pub const LayerNorm = struct { pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor { const ax = x.axis(axis); // upcast to improve precision - const xf32 = x.convert(.f32); - const mean = 
xf32.mul(xf32).mean(ax); - const rsqrt = Tensor.rsqrt(mean.addConstant(eps)).convert(x.dtype()); + const variance = x.convert(.f32).powByConst(2).mean(ax); + const rsqrt = Tensor.rsqrt(variance.addConstant(eps)).convert(x.dtype()); return x.mul(rsqrt.broad(x.shape())); } @@ -816,7 +815,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten Tensor.constant(t.shape(), dtype.one()), t, t.mul(t), - t.pow(Tensor.scalar(3, dtype)), + t.powByConst(3), }, .last, ._interpolated); std.debug.assert(pos.dim(0) == new_len); @@ -894,6 +893,7 @@ pub fn causalAttnMask( pub const SdpaOpts = struct { attn_mask: ?Tensor = null, scale: ?Tensor = null, + softmax_bias: ?Tensor = null, allow_cudnn: bool = true, // TODO: put a callback instead of all this field, // so that @@ -922,7 +922,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); - if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape())) { + if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape()) and opts.softmax_bias == null) { return cuda.sdpa(q, k, v, opts); } @@ -940,9 +940,15 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { k = k.mul(head_scaling.convert(k.dtype())); var attn_weights = q.dot(k, .{.hd}); - // log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask }); + // log.debug("attn_weights : {f}, attn_mask : {?f}", .{ attn_weights, attn_mask }); if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); - attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype()); + attn_weights = attn_weights.convert(.f32); + attn_weights = if (opts.softmax_bias) |softmax_bias| attn: { + // The split is needed because we also split q ourselves. 
+ // TODO: consider letting the user do that. + const bias = softmax_bias.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto }); + break :attn attn_weights.convert(.f32).softmaxBiased(.k, bias).convert(q.dtype()); + } else attn_weights.convert(.f32).softmax(.k).convert(q.dtype()); var attn = attn_weights.dot(v, .{.k}); return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } }); diff --git a/zml/tensor.zig b/zml/tensor.zig index 78219cd..58798a0 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1105,13 +1105,16 @@ pub const Tensor = struct { /// Axes with the same tag on both sides, and which aren't contracting, /// are considered "batching axes". pub fn dot(lhs: Tensor, rhs: Tensor, comptime contracting: anytype) Tensor { - var contracting_axes: [contracting.len][2]i8 = undefined; - inline for (contracting, 0..) |c, i| { - contracting_axes[i] = .{ lhs.axis(c), rhs.axis(c) }; + var contracting_axes: stdx.BoundedArray([2]i8, MAX_RANK) = .{}; + if (@TypeOf(contracting) == EnumLiteral) { + contracting_axes.appendAssumeCapacity(.{ lhs.axis(contracting), rhs.axis(contracting) }); + } else { + inline for (contracting) |c| { + contracting_axes.appendAssumeCapacity(.{ lhs.axis(c), rhs.axis(c) }); + } } - var batching_axes: [MAX_RANK][2]i8 = undefined; - var n_batching: u8 = 0; + var batching_axes: stdx.BoundedArray([2]i8, MAX_RANK) = .{}; for (lhs._shape.tags(), 0..) |l, li| { stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs }); @@ -1119,20 +1122,19 @@ pub const Tensor = struct { stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs }); if (l == r) { - for (contracting_axes) |ct| { + for (contracting_axes.slice()) |ct| { if (l == lhs._shape.tag(ct[0])) { break; } } else { // tag is both in lhs and rhs but not in contracting -> it's a batching dim. 
- batching_axes[n_batching] = .{ @intCast(li), @intCast(ri) }; - n_batching += 1; + batching_axes.appendAssumeCapacity(.{ @intCast(li), @intCast(ri) }); } } } } - return dotGeneral(lhs, rhs, contracting_axes[0..], batching_axes[0..n_batching]); + return dotGeneral(lhs, rhs, contracting_axes.slice(), batching_axes.slice()); } test dot { @@ -1318,23 +1320,39 @@ pub const Tensor = struct { /// Returns a Tensor containing the Sigmoid Linear Unit (SiLU) activation function applied to each element of the input Tensor. /// /// silu(x) = x σ(x) - /// https://paperswithcode.com/method/silu pub fn silu(x: Tensor) Tensor { - return x.mul(x.sigmoid()); + return x.mul(.sigmoid(x)); } - /// Returns a Tensor containing the softmax function applied to each element of the input Tensor. + /// Computes softmax along the given axis. + /// y[i] = exp(x[i]) / Σ_k exp(x[k]) pub fn softmax(self: Tensor, axis_: anytype) Tensor { const a = self.axis(axis_); const max_val = self.max(a); - const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype())); + const row_mask = max_val.cmp(.GT, .scalar(-std.math.inf(f64), self.dtype())); - const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp(); - const res = exp_diff_max.div(exp_diff_max.sum(a)); + const exp_diff_max = self.sub(max_val).convert(.f32).exp(); + const res = exp_diff_max.div(exp_diff_max.sum(a)).convert(self.dtype()); // If a row is full -inf return full 0 instead of full nan, // this fix attention when mask hides a full row. - return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype())); + return row_mask.broad(self.shape()).select(res, .scalar(0, self.dtype())); + } + + /// Computes softmax, but adds a bias to the sum of exponentials. 
+ /// y[i] = exp(x[i]) / ( Σ_k exp(x[k]) + bias ) + pub fn softmaxBiased(self: Tensor, axis_: anytype, bias: ?Tensor) Tensor { + const a = self.axis(axis_); + + if (bias == null) return self.softmax(axis_); + const b = bias.?.convert(self.dtype()).broad(self.shape().setDim(a, 1)); + const max_val: Tensor = maximum(self.max(a), b); + const exp_diff_max = self.sub(max_val).exp(); + const bias_diff_max = b.sub(max_val).exp(); + const res = exp_diff_max.div(exp_diff_max.sum(a).add(bias_diff_max)); + + // The bias means that denominator won't be 0, we don't need to handle that case. + return res; } /// Returns a Tensor containing the log of the sum of exponential over the given axis. @@ -1534,21 +1552,31 @@ pub const Tensor = struct { step: i32 = 1, singleton: bool = false, + const full = .{ .start = 0, .end = to_the_end, .step = 1 }; + pub fn single(offset: i64) Slice { return .{ .start = offset, .end = offset + 1, .singleton = true }; } + pub fn absolute(s: Slice, d: i64) Slice { + const start = if (s.start < 0) d + s.start else s.start; + const end = if (s.end == to_the_end) d else if (s.end < 0) d + s.end else s.end; + const res: Slice = .{ .start = start, .end = end, .step = s.step, .singleton = s.singleton }; + stdx.debug.assert(start < end, "Slice {f} is invalid for axis of dimension {d} (resolved to {f}", .{ s, d, res }); + return res; + } + const to_the_end = std.math.maxInt(i64); pub fn format(self: Slice, writer: *std.Io.Writer) !void { if (self.singleton) { - try writer.print("[{}]", .{self.start}); + try writer.print("[{d}]", .{self.start}); } else if (self.end == to_the_end and self.step == 1) { - try writer.print("[{}..]", .{self.start}); + try writer.print("[{d}..]", .{self.start}); } else if (self.step == 1) { - try writer.print("[{}..{}]", .{ self.start, self.end }); + try writer.print("[{d}..{d}]", .{ self.start, self.end }); } else { - try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step }); + try writer.print("[{d}..{d}:{d}]", .{ 
self.start, self.end, self.step }); } } }; @@ -1570,11 +1598,7 @@ pub const Tensor = struct { for (slices, 0..) |s, a| { stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a }); - const args: Slice = .{ - .start = self.wrapIndex(a, s.start), - .end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end), - .step = s.step, - }; + const args: Slice = s.absolute(self.dim(a)); start_indices[a] = args.start; limit_indices[a] = args.end; strides[a] = args.step; @@ -2016,6 +2040,10 @@ pub const Tensor = struct { return _result(sh, constant_op.result(0)).convert(val.dtype()); } + pub fn zeroes(sh: Shape) Tensor { + return .constant(sh, sh.dtype().zero()); + } + /// Embeds a buffer with concrete values into an Mlir program. pub fn constantTensor(val: HostBuffer) Tensor { const ctx = CompilationContext.current().mlirCtx(); @@ -3047,6 +3075,7 @@ pub const Tensor = struct { /// Returns a Tensor representing the result of Top-K over the given axis. pub fn topK(self: Tensor, named_axis_: anytype, k: u32, opts: struct { descending: bool = true }) SortRes { + stdx.debug.assert(k > 0, "topK expects a k > 0, got 0", .{}); const err_msg = "topK named axis should be an integer or a named axis, eg `x.topK(.{{ .best_token = .token }}, 16)` or `x.topK(-1, 16)`"; const has_name: ?[:0]const u8, const a = switch (@typeInfo(@TypeOf(named_axis_))) { .int, .comptime_int => .{ null, self.axis(@as(i64, @intCast(named_axis_))) },