From 5122ca020361902411a4d8447fc1e8a802461e33 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 27 Sep 2023 11:45:33 +0000 Subject: [PATCH] Refactor rope implementation to compute only required offsets, eliminating full cos/sin matrix generation in module, nn, and tensor code. --- zml/module.zig | 1 + zml/nn.zig | 75 +++++++++++++++++++++++--------------------------- zml/tensor.zig | 20 ++++++++++---- 3 files changed, 50 insertions(+), 46 deletions(-) diff --git a/zml/module.zig b/zml/module.zig index 8c9df46..5ced279 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -562,6 +562,7 @@ pub const CompilationContext = struct { self.extractValues(&args, values[function.n_model..]); const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); + // TODO: tags seem to be lost by `callFunc`. var res: stdx.meta.FnResult(func) = undefined; assignResults(op, &res, function.res_shapes); return res; diff --git a/zml/nn.zig b/zml/nn.zig index 5b33a50..b8d7d93 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -151,44 +151,39 @@ pub const RopeOpts = struct { freq_base: f32 = 10_000, }; -pub const CosSin = [2]Tensor; - /// Rotary position embedding modify queries and keys tensor before compute Q * K in self attention. /// This biases a token to look at token near him. -/// The nice thing with this solution is that you can cache the modified queries and keys directly. +/// The nice thing with rope is that you can cache the modified queries and keys directly. 
/// See: https://paperswithcode.com/method/rope -pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor { - const cos, const sin = cos_sin_cache; - stdx.debug.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin }); - // broadcast cos / sin to .{ batch, .seq, .half_dim } +/// +/// Expected shapes of tensor: +/// - x: .{ .s, .hd } where .s is the sequence length and .hd the head dimension +/// - pos_idx: optional tensor which indicates which positions are needed. +/// When not set, `rope` uses all positions from 0 to x.dim(.s), which is the max seq len. +pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor { + stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {}", .{x}); + + const idx = if (pos_idx) |idx| blk: { + stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={}, idx={})", .{ x, idx }); + break :blk idx; + } else blk: { + stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x}); + break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s}); + }; const x_real, const x_imag = splitRealImg(x, opts.impl); - const has_tags = cos.shape().tag(0) != Shape.TagUnknown; - const b_cos = if (has_tags) cos.broad(x_real.shape()) else cos.broadcastLeft(x_real.shape()); - const b_sin = if (has_tags) sin.broad(x_real.shape()) else sin.broadcastLeft(x_real.shape()); - - // apply rotation - const y_real = x_real.mul(b_cos).sub(x_imag.mul(b_sin)); - const y_imag = x_real.mul(b_sin).add(x_imag.mul(b_cos)); - - // flatten last dimensions - const y = mergeRealImg(y_real, y_imag, opts.impl); - - return y; -} - -pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin { - const shape = Shape.init(sh, dtype); - stdx.debug.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape}); - const seq_len, const head_dim = .{ 
shape.dim(0), shape.dim(1) }; - stdx.debug.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim}); // compute sin and cos in f32 before downcasting to x type. - const inv_freq = invFreq(head_dim, opts.freq_base, .f32); - var inv_freq_pos = Tensor.outer(Tensor.arange(.{ .end = seq_len }, .f32), inv_freq).convert(shape.dtype()); - inv_freq_pos._shape._tags = shape._tags; - const cos = inv_freq_pos.cos(); - const sin = inv_freq_pos.sin(); - return .{ cos, sin }; + const inv_freq = invFreq(x.dim(.hd), opts.freq_base, .f32).withTags(.{.hd}); + const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq); + const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape()); + const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape()); + + // apply rotation + const y_real = x_real.mul(cos).sub(x_imag.mul(sin)); + const y_imag = x_real.mul(sin).add(x_imag.mul(cos)); + + // flatten last dimensions + return mergeRealImg(y_real, y_imag, opts.impl); } pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor { @@ -308,19 +303,18 @@ test "real/img" { try testing.expectEqual(20, d_split_interleaved.getValue(i32)); } -test "rope" { +test rope { const platform = zml.testing.env(); - const TestRope = struct { - fn forward(x: Tensor, opts: RopeOpts) Tensor { + const Local = struct { + fn _fwd(x: Tensor, opts: RopeOpts) Tensor { var input = x; { // Convert input to the requested format const real, const imag = splitRealImg(input, .sequential); input = mergeRealImg(real, imag, opts.impl); } - const cos_sin = ropeCosSin(.{ input.dim(-2), input.dim(-1) }, input.dtype(), opts); - var res = rope(input, cos_sin, opts).squeeze(0); + var res = rope(input, null, opts).squeeze(0); { // Convert back to sequential @@ -333,10 +327,9 @@ test "rope" { // x is made such as the interleaved and sequential reps are the same. // So the two implementations should give the same results. 
- const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5); - const res1 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } }); - const res2 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } }); - + const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5); + const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .interleaved } }); + const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .sequential } }); try zml.testing.expectClose(res1, res2, 1e-4); } diff --git a/zml/tensor.zig b/zml/tensor.zig index 584dd12..a44958b 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1835,15 +1835,25 @@ pub const Tensor = struct { /// Returns a Tensor containing the result of the outer product between the input Tensors. pub fn outer(self: Tensor, other: Tensor) Tensor { - stdx.debug.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() }); - if (self.rank() + other.rank() == 1) { return self.mul(other); } - const dimz = .{ self.dim(0), other.dim(0) }; - const left = self.broadcast(Shape.init(dimz, self.dtype()), &.{0}); - const right = other.broadcast(Shape.init(dimz, other.dtype()), &.{1}); + const other_shape = other.shape(); + var res_shape = self.shape(); + var batching_axes: u8 = 0; + for (0..other.rank()) |ax| { + if (other_shape.tag(ax) != Shape.TagUnknown) { + if (self.shape().hasTag(other_shape.tag(ax))) |batching_ax| { + stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other }); + batching_axes += 1; + } + } + + res_shape = res_shape.appendDim(other_shape.dim(ax), 
other_shape.tag(ax)); + } + const left = self.broad(res_shape); + const right = other.broad(res_shape); return left.mul(right); }