Update Llama example to use the new direct rope IR implementation.
This commit is contained in:
parent
b5c4fb7c58
commit
06865f5876
@ -295,18 +295,15 @@ pub const SelfAttn = struct {
|
||||
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
||||
const seq_len = kv_cache.k.dim(.k);
|
||||
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
|
||||
var cos, var sin = zml.nn.ropeCosSin(.{ .s = seq_len, .hd = k.dim(.hd) }, x.dtype(), self.rope_opts);
|
||||
if (token_index) |idx| {
|
||||
// Note: in Pytorch it would be very inefficient to generate the full ropeCosSin and attn_mask matrices, then slice into it,
|
||||
// but XLA is able to optimize this correctly.
|
||||
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
||||
// then slice into it, but XLA is able to optimize this correctly.
|
||||
attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
|
||||
cos = cos.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
|
||||
sin = sin.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
|
||||
}
|
||||
|
||||
// In self-attention, .s axis is used both for keys and queries.
|
||||
q = zml.nn.rope(q, .{ cos, sin }, self.rope_opts);
|
||||
k = zml.nn.rope(k, .{ cos, sin }, self.rope_opts);
|
||||
q = zml.nn.rope(q, token_index, self.rope_opts);
|
||||
k = zml.nn.rope(k, token_index, self.rope_opts);
|
||||
q = q.rename(.{ .s = .q });
|
||||
k = k.rename(.{ .s = .k });
|
||||
v = v.rename(.{ .s = .k });
|
||||
|
||||
Loading…
Reference in New Issue
Block a user