diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig
index 4416811..fed4063 100644
--- a/examples/llama/llama.zig
+++ b/examples/llama/llama.zig
@@ -295,18 +295,15 @@ pub const SelfAttn = struct {
         const kv_cache = kv_cache_ orelse initKvCache(k.shape());
         const seq_len = kv_cache.k.dim(.k);
         var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
-        var cos, var sin = zml.nn.ropeCosSin(.{ .s = seq_len, .hd = k.dim(.hd) }, x.dtype(), self.rope_opts);
         if (token_index) |idx| {
-            // Note: in Pytorch it would be very inefficient to generate the full ropeCosSin and attn_mask matrices, then slice into it,
-            // but XLA is able to optimize this correctly.
+            // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
+            // then slice into it, but XLA is able to optimize this correctly.
             attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
-            cos = cos.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
-            sin = sin.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
         }
 
         // In self-attention, .s axis is used both for keys and queries.
-        q = zml.nn.rope(q, .{ cos, sin }, self.rope_opts);
-        k = zml.nn.rope(k, .{ cos, sin }, self.rope_opts);
+        q = zml.nn.rope(q, token_index, self.rope_opts);
+        k = zml.nn.rope(k, token_index, self.rope_opts);
         q = q.rename(.{ .s = .q });
         k = k.rename(.{ .s = .k });
         v = v.rename(.{ .s = .k });
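
Note (illustrative sketch, not part of the patch): after this change, decode-time call sites pass token_index straight to zml.nn.rope, which presumably derives the cos/sin rotation for that position internally, so callers no longer precompute ropeCosSin and slice it per step. A minimal sketch of the resulting decode step, reusing only names visible in the hunk above (q, k, attn_mask, token_index, self.rope_opts):

    if (token_index) |idx| {
        // Build the full causal mask, then keep only the row for the token
        // being decoded. As the comment in the hunk notes, XLA fuses the
        // construction with the slice, so the full (seq_len, seq_len)
        // intermediate is never materialized at runtime.
        attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
    }
    // rope() receives the position directly; no per-step cos/sin slicing
    // remains at the call site.
    q = zml.nn.rope(q, token_index, self.rope_opts);
    k = zml.nn.rope(k, token_index, self.rope_opts);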