Refactor rope implementation to compute only required offsets, eliminating full cos/sin matrix generation in module, nn, and tensor code.
This commit is contained in:
parent
06865f5876
commit
5122ca0203
@ -562,6 +562,7 @@ pub const CompilationContext = struct {
|
|||||||
self.extractValues(&args, values[function.n_model..]);
|
self.extractValues(&args, values[function.n_model..]);
|
||||||
|
|
||||||
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
|
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
|
||||||
|
// TODO: tags seem to be lost by `callFunc`.
|
||||||
var res: stdx.meta.FnResult(func) = undefined;
|
var res: stdx.meta.FnResult(func) = undefined;
|
||||||
assignResults(op, &res, function.res_shapes);
|
assignResults(op, &res, function.res_shapes);
|
||||||
return res;
|
return res;
|
||||||
|
|||||||
75
zml/nn.zig
75
zml/nn.zig
@ -151,44 +151,39 @@ pub const RopeOpts = struct {
|
|||||||
freq_base: f32 = 10_000,
|
freq_base: f32 = 10_000,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const CosSin = [2]Tensor;
|
|
||||||
|
|
||||||
/// Rotary position embedding modifies the query and key tensors before computing Q * K in self-attention.
|
/// Rotary position embedding modifies the query and key tensors before computing Q * K in self-attention.
|
||||||
/// This biases a token to attend to the tokens near it.
|
/// This biases a token to attend to the tokens near it.
|
||||||
/// The nice thing with this solution is that you can cache the modified queries and keys directly.
|
/// The nice thing with rope is that you can cache the modified queries and keys directly.
|
||||||
/// See: https://paperswithcode.com/method/rope
|
/// See: https://paperswithcode.com/method/rope
|
||||||
pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
|
///
|
||||||
const cos, const sin = cos_sin_cache;
|
/// Expected shapes of tensor:
|
||||||
stdx.debug.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin });
|
/// - x: .{ .s, .hd } where .s is the sequence length and .hd the head dimension
|
||||||
// broadcast cos / sin to .{ batch, .seq, .half_dim }
|
/// - pos_idx: optional tensor which indicates which positions are needed.
|
||||||
|
/// When not set, `rope` uses all positions from 0 to x.dim(.s), which is the max seq len.
|
||||||
|
pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
||||||
|
stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {}", .{x});
|
||||||
|
|
||||||
|
const idx = if (pos_idx) |idx| blk: {
|
||||||
|
stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={}, idx={})", .{ x, idx });
|
||||||
|
break :blk idx;
|
||||||
|
} else blk: {
|
||||||
|
stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
|
||||||
|
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
|
||||||
|
};
|
||||||
const x_real, const x_imag = splitRealImg(x, opts.impl);
|
const x_real, const x_imag = splitRealImg(x, opts.impl);
|
||||||
const has_tags = cos.shape().tag(0) != Shape.TagUnknown;
|
|
||||||
const b_cos = if (has_tags) cos.broad(x_real.shape()) else cos.broadcastLeft(x_real.shape());
|
|
||||||
const b_sin = if (has_tags) sin.broad(x_real.shape()) else sin.broadcastLeft(x_real.shape());
|
|
||||||
|
|
||||||
// apply rotation
|
|
||||||
const y_real = x_real.mul(b_cos).sub(x_imag.mul(b_sin));
|
|
||||||
const y_imag = x_real.mul(b_sin).add(x_imag.mul(b_cos));
|
|
||||||
|
|
||||||
// flatten last dimensions
|
|
||||||
const y = mergeRealImg(y_real, y_imag, opts.impl);
|
|
||||||
|
|
||||||
return y;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin {
|
|
||||||
const shape = Shape.init(sh, dtype);
|
|
||||||
stdx.debug.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape});
|
|
||||||
const seq_len, const head_dim = .{ shape.dim(0), shape.dim(1) };
|
|
||||||
stdx.debug.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim});
|
|
||||||
|
|
||||||
// compute sin and cos in f32 before downcasting to x type.
|
// compute sin and cos in f32 before downcasting to x type.
|
||||||
const inv_freq = invFreq(head_dim, opts.freq_base, .f32);
|
const inv_freq = invFreq(x.dim(.hd), opts.freq_base, .f32).withTags(.{.hd});
|
||||||
var inv_freq_pos = Tensor.outer(Tensor.arange(.{ .end = seq_len }, .f32), inv_freq).convert(shape.dtype());
|
const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
|
||||||
inv_freq_pos._shape._tags = shape._tags;
|
const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
|
||||||
const cos = inv_freq_pos.cos();
|
const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
|
||||||
const sin = inv_freq_pos.sin();
|
|
||||||
return .{ cos, sin };
|
// apply rotation
|
||||||
|
const y_real = x_real.mul(cos).sub(x_imag.mul(sin));
|
||||||
|
const y_imag = x_real.mul(sin).add(x_imag.mul(cos));
|
||||||
|
|
||||||
|
// flatten last dimensions
|
||||||
|
return mergeRealImg(y_real, y_imag, opts.impl);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
|
pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
|
||||||
@ -308,19 +303,18 @@ test "real/img" {
|
|||||||
try testing.expectEqual(20, d_split_interleaved.getValue(i32));
|
try testing.expectEqual(20, d_split_interleaved.getValue(i32));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "rope" {
|
test rope {
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
const TestRope = struct {
|
const Local = struct {
|
||||||
fn forward(x: Tensor, opts: RopeOpts) Tensor {
|
fn _fwd(x: Tensor, opts: RopeOpts) Tensor {
|
||||||
var input = x;
|
var input = x;
|
||||||
{
|
{
|
||||||
// Convert input to the requested format
|
// Convert input to the requested format
|
||||||
const real, const imag = splitRealImg(input, .sequential);
|
const real, const imag = splitRealImg(input, .sequential);
|
||||||
input = mergeRealImg(real, imag, opts.impl);
|
input = mergeRealImg(real, imag, opts.impl);
|
||||||
}
|
}
|
||||||
const cos_sin = ropeCosSin(.{ input.dim(-2), input.dim(-1) }, input.dtype(), opts);
|
var res = rope(input, null, opts).squeeze(0);
|
||||||
var res = rope(input, cos_sin, opts).squeeze(0);
|
|
||||||
|
|
||||||
{
|
{
|
||||||
// Convert back to sequential
|
// Convert back to sequential
|
||||||
@ -333,10 +327,9 @@ test "rope" {
|
|||||||
|
|
||||||
// x is built such that the interleaved and sequential representations are the same.
|
// x is built such that the interleaved and sequential representations are the same.
|
||||||
// So the two implementations should give the same results.
|
// So the two implementations should give the same results.
|
||||||
const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
||||||
const res1 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } });
|
const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .interleaved } });
|
||||||
const res2 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } });
|
const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .sequential } });
|
||||||
|
|
||||||
try zml.testing.expectClose(res1, res2, 1e-4);
|
try zml.testing.expectClose(res1, res2, 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1835,15 +1835,25 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the result of the outer product between the input Tensors.
|
/// Returns a Tensor containing the result of the outer product between the input Tensors.
|
||||||
pub fn outer(self: Tensor, other: Tensor) Tensor {
|
pub fn outer(self: Tensor, other: Tensor) Tensor {
|
||||||
stdx.debug.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() });
|
|
||||||
|
|
||||||
if (self.rank() + other.rank() == 1) {
|
if (self.rank() + other.rank() == 1) {
|
||||||
return self.mul(other);
|
return self.mul(other);
|
||||||
}
|
}
|
||||||
|
|
||||||
const dimz = .{ self.dim(0), other.dim(0) };
|
const other_shape = other.shape();
|
||||||
const left = self.broadcast(Shape.init(dimz, self.dtype()), &.{0});
|
var res_shape = self.shape();
|
||||||
const right = other.broadcast(Shape.init(dimz, other.dtype()), &.{1});
|
var batching_axes: u8 = 0;
|
||||||
|
for (0..other.rank()) |ax| {
|
||||||
|
if (other_shape.tag(ax) != Shape.TagUnknown) {
|
||||||
|
if (self.shape().hasTag(other_shape.tag(ax))) |batching_ax| {
|
||||||
|
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
|
||||||
|
batching_axes += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res_shape = res_shape.appendDim(other_shape.dim(ax), other_shape.tag(ax));
|
||||||
|
}
|
||||||
|
const left = self.broad(res_shape);
|
||||||
|
const right = other.broad(res_shape);
|
||||||
return left.mul(right);
|
return left.mul(right);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user