diff --git a/zml/nn.zig b/zml/nn.zig index d609f3c..8ba6903 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -91,9 +91,9 @@ pub const LayerNorm = struct { pub fn forward(self: LayerNorm, x: Tensor) Tensor { const normed = normalizeVariance(x, self.eps); - - var out = normed.mul(self.weight.broadcastLeft(x.shape())); - if (self.bias) |bias| out = out.add(bias.broadcastLeft(x.shape())); + const ax = x.axis(-1); + var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax})); + if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax})); return out; } @@ -760,8 +760,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch { meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args); }; - const sqrtHeadDim: f16 = 1.0 / std.math.sqrt(@as(f16, @floatFromInt(dims.hd))); - const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, .f16); + const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd))); + const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); k = k.mul(scale_logit.convert(k.dtype())); var attn_weights = q.dot(k, .{.hd}); diff --git a/zml/tensor.zig b/zml/tensor.zig index 0d182fb..a8d71e8 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1413,7 +1413,6 @@ pub const Tensor = struct { } pub const Slice = struct { - single: ?i64 = null, start: i64 = 0, end: ?i64 = null, step: i64 = 1, @@ -1499,9 +1498,8 @@ pub const Tensor = struct { } /// Concatenates the input Tensors along the given axis. 
- pub fn concatenate(tensors: []const Tensor, axis_: i64) Tensor { + pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor { meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len}); - // TODO taggedVal var buffer: [32]mlir.Value = undefined; std.debug.assert(tensors.len <= buffer.len); std.debug.assert(tensors.len > 0); @@ -2971,6 +2969,8 @@ pub const Tensor = struct { } /// Slices the input Tensor along a specific axis, with a start offset known at runtime. + /// Note: this doesn't support tagging; if you have tags, + /// you should use `dynamicSlice` directly. pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor { meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});