Add Tensor.concatenate support, begin deprecating broadcastLeft, and compute transformer head scaling constant in f32 for higher precision.
This commit is contained in:
parent
11006ca08d
commit
ed6444b775
10
zml/nn.zig
10
zml/nn.zig
@ -91,9 +91,9 @@ pub const LayerNorm = struct {
|
|||||||
|
|
||||||
pub fn forward(self: LayerNorm, x: Tensor) Tensor {
|
pub fn forward(self: LayerNorm, x: Tensor) Tensor {
|
||||||
const normed = normalizeVariance(x, self.eps);
|
const normed = normalizeVariance(x, self.eps);
|
||||||
|
const ax = x.axis(-1);
|
||||||
var out = normed.mul(self.weight.broadcastLeft(x.shape()));
|
var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax}));
|
||||||
if (self.bias) |bias| out = out.add(bias.broadcastLeft(x.shape()));
|
if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax}));
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -760,8 +760,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
|
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
|
||||||
meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
|
meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
|
||||||
};
|
};
|
||||||
const sqrtHeadDim: f16 = 1.0 / std.math.sqrt(@as(f16, @floatFromInt(dims.hd)));
|
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
|
||||||
const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, .f16);
|
const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
|
||||||
k = k.mul(scale_logit.convert(k.dtype()));
|
k = k.mul(scale_logit.convert(k.dtype()));
|
||||||
|
|
||||||
var attn_weights = q.dot(k, .{.hd});
|
var attn_weights = q.dot(k, .{.hd});
|
||||||
|
|||||||
@ -1413,7 +1413,6 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const Slice = struct {
|
pub const Slice = struct {
|
||||||
single: ?i64 = null,
|
|
||||||
start: i64 = 0,
|
start: i64 = 0,
|
||||||
end: ?i64 = null,
|
end: ?i64 = null,
|
||||||
step: i64 = 1,
|
step: i64 = 1,
|
||||||
@ -1499,9 +1498,8 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenates the input Tensors along the given axis.
|
/// Concatenates the input Tensors along the given axis.
|
||||||
pub fn concatenate(tensors: []const Tensor, axis_: i64) Tensor {
|
pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
|
||||||
meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
|
meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
|
||||||
// TODO taggedVal
|
|
||||||
var buffer: [32]mlir.Value = undefined;
|
var buffer: [32]mlir.Value = undefined;
|
||||||
std.debug.assert(tensors.len <= buffer.len);
|
std.debug.assert(tensors.len <= buffer.len);
|
||||||
std.debug.assert(tensors.len > 0);
|
std.debug.assert(tensors.len > 0);
|
||||||
@ -2971,6 +2969,8 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
||||||
|
/// Note: this doesn't support tagging, if you have tags,
|
||||||
|
/// you should use `dynamicSlice` directly.
|
||||||
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
|
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
|
||||||
meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
|
meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user