diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index 0112c01..f01ab91 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -77,7 +77,7 @@ pub const LlamaLM = struct { const tokens = tokens_.withPartialTags(.{.s}); const out = out_.withPartialTags(.{ .s, .d }); - const next_token_pred = out.dynamicSlice(.{ .s = .{ .start = token_index, .len = 1 } }); + const next_token_pred = out.gatherValues(.s, token_index, .{}); var logits = zml.call(lm_head, .forward, .{next_token_pred}); if (logits.shape().hasTag(.voc) == null) logits = logits.rename(.{ .d = .voc });