Add Tensor.concatenate support, begin deprecating broadcastLeft, and compute transformer head scaling constant in f32 for higher precision.

This commit is contained in:
Tarry Singh 2023-04-21 15:55:07 +00:00
parent 11006ca08d
commit ed6444b775
2 changed files with 8 additions and 8 deletions

View File

@@ -91,9 +91,9 @@ pub const LayerNorm = struct {
     pub fn forward(self: LayerNorm, x: Tensor) Tensor {
         const normed = normalizeVariance(x, self.eps);
-        var out = normed.mul(self.weight.broadcastLeft(x.shape()));
-        if (self.bias) |bias| out = out.add(bias.broadcastLeft(x.shape()));
+        const ax = x.axis(-1);
+        var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax}));
+        if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax}));
         return out;
     }
@@ -760,8 +760,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
     const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
         meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
     };
-    const sqrtHeadDim: f16 = 1.0 / std.math.sqrt(@as(f16, @floatFromInt(dims.hd)));
-    const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, .f16);
+    const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
+    const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
     k = k.mul(scale_logit.convert(k.dtype()));
     var attn_weights = q.dot(k, .{.hd});

View File

@@ -1413,7 +1413,6 @@ pub const Tensor = struct {
     }
     pub const Slice = struct {
-        single: ?i64 = null,
         start: i64 = 0,
         end: ?i64 = null,
         step: i64 = 1,
@@ -1499,9 +1498,8 @@ pub const Tensor = struct {
     }
     /// Concatenates the input Tensors along the given axis.
-    pub fn concatenate(tensors: []const Tensor, axis_: i64) Tensor {
+    pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
         meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
-        // TODO taggedVal
         var buffer: [32]mlir.Value = undefined;
         std.debug.assert(tensors.len <= buffer.len);
         std.debug.assert(tensors.len > 0);
@@ -2971,6 +2969,8 @@ pub const Tensor = struct {
     }
     /// Slices the input Tensor along a specific axis, with a start offset known at runtime.
+    /// Note: this doesn't support tagging, if you have tags,
+    /// you should use `dynamicSlice` directly.
     pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
         meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});