Add llama example demonstrating the new gatherValues functionality.
This commit is contained in:
parent
48b671f100
commit
16e066ec69
@ -77,7 +77,7 @@ pub const LlamaLM = struct {
|
||||
const tokens = tokens_.withPartialTags(.{.s});
|
||||
const out = out_.withPartialTags(.{ .s, .d });
|
||||
|
||||
const next_token_pred = out.dynamicSlice(.{ .s = .{ .start = token_index, .len = 1 } });
|
||||
const next_token_pred = out.gatherValues(.s, token_index, .{});
|
||||
var logits = zml.call(lm_head, .forward, .{next_token_pred});
|
||||
if (logits.shape().hasTag(.voc) == null)
|
||||
logits = logits.rename(.{ .d = .voc });
|
||||
|
||||
Loading…
Reference in New Issue
Block a user