Update Llama example to use the new direct rope IR implementation.

Foke Singh 2023-09-25 10:22:05 +00:00
parent b5c4fb7c58
commit 06865f5876


@@ -295,18 +295,15 @@ pub const SelfAttn = struct {
 const kv_cache = kv_cache_ orelse initKvCache(k.shape());
 const seq_len = kv_cache.k.dim(.k);
 var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
-var cos, var sin = zml.nn.ropeCosSin(.{ .s = seq_len, .hd = k.dim(.hd) }, x.dtype(), self.rope_opts);
 if (token_index) |idx| {
-    // Note: in Pytorch it would be very inefficient to generate the full ropeCosSin and attn_mask matrices, then slice into it,
-    // but XLA is able to optimize this correctly.
+    // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
+    // then slice into it, but XLA is able to optimize this correctly.
     attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
-    cos = cos.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
-    sin = sin.dynamicSlice(.{ .s = .{ .start = idx, .len = 1 } });
 }
 // In self-attention, .s axis is used both for keys and queries.
-q = zml.nn.rope(q, .{ cos, sin }, self.rope_opts);
-k = zml.nn.rope(k, .{ cos, sin }, self.rope_opts);
+q = zml.nn.rope(q, token_index, self.rope_opts);
+k = zml.nn.rope(k, token_index, self.rope_opts);
 q = q.rename(.{ .s = .q });
 k = k.rename(.{ .s = .k });
 v = v.rename(.{ .s = .k });
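
For reference, the rotation that rope applies at a token offset is the standard RoPE pairwise rotation. Below is a minimal CPU sketch in plain Zig, assuming the interleaved pair layout and a base frequency theta (commonly 10000.0 for Llama-style models). It is not the ZML implementation, which operates on tensors and takes rope_opts; the hypothetical helper ropeInPlace only illustrates the per-pair math the direct rope IR emits, with pos playing the role of token_index.

const std = @import("std");

/// Hypothetical reference helper, for illustration only: rotates one
/// head vector in place by the RoPE angles for position `pos`.
fn ropeInPlace(x: []f32, pos: usize, theta: f32) void {
    const hd: f32 = @floatFromInt(x.len); // head dim, assumed even
    var i: usize = 0;
    while (i < x.len / 2) : (i += 1) {
        // Frequency of pair i: theta^(-2i/hd), as in the RoPE paper.
        const exponent = @as(f32, @floatFromInt(2 * i)) / hd;
        const freq = 1.0 / std.math.pow(f32, theta, exponent);
        const angle = @as(f32, @floatFromInt(pos)) * freq;
        const c = @cos(angle);
        const s = @sin(angle);
        const a = x[2 * i];
        const b = x[2 * i + 1];
        x[2 * i] = a * c - b * s; // 2D rotation of the (even, odd) pair
        x[2 * i + 1] = a * s + b * c;
    }
}

test "rope at position 0 is the identity rotation" {
    var v = [_]f32{ 1.0, 2.0, 3.0, 4.0 };
    ropeInPlace(&v, 0, 10000.0); // angle = 0 for every pair
    for (v, [_]f32{ 1.0, 2.0, 3.0, 4.0 }) |got, want|
        try std.testing.expectApproxEqAbs(want, got, 1e-6);
}

With this change the cos/sin computation moves inside rope itself, so the call site no longer precomputes full cos/sin tables and slices them per token; only attn_mask still needs the explicit dynamicSlice.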