Update Llama example to use the new direct rope IR implementation.
This commit is contained in:
parent
b5c4fb7c58
commit
06865f5876
@ -295,18 +295,15 @@ pub const SelfAttn = struct {
|
|||||||
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
||||||
const seq_len = kv_cache.k.dim(.k);
|
const seq_len = kv_cache.k.dim(.k);
|
||||||
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
|
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
|
||||||
var cos, var sin = zml.nn.ropeCosSin(.{ .s = seq_len, .hd = k.dim(.hd) }, x.dtype(), self.rope_opts);
|
|
||||||
if (token_index) |idx| {
|
if (token_index) |idx| {
|
||||||
// Note: in Pytorch it would be very inefficient to generate the full ropeCosSin and attn_mask matrices, then slice into it,
|
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
||||||
// but XLA is able to optimize this correctly.
|
// then slice into it, but XLA is able to optimize this correctly.
|
||||||
attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
|
attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
|
||||||
cos = cos.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
|
|
||||||
sin = sin.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// In self-attention, .s axis is used both for keys and queries.
|
// In self-attention, .s axis is used both for keys and queries.
|
||||||
q = zml.nn.rope(q, .{ cos, sin }, self.rope_opts);
|
q = zml.nn.rope(q, token_index, self.rope_opts);
|
||||||
k = zml.nn.rope(k, .{ cos, sin }, self.rope_opts);
|
k = zml.nn.rope(k, token_index, self.rope_opts);
|
||||||
q = q.rename(.{ .s = .q });
|
q = q.rename(.{ .s = .q });
|
||||||
k = k.rename(.{ .s = .k });
|
k = k.rename(.{ .s = .k });
|
||||||
v = v.rename(.{ .s = .k });
|
v = v.rename(.{ .s = .k });
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user