Add llama example demonstrating the new gatherValues functionality.

This commit is contained in:
Foke Singh 2023-01-11 09:58:09 +00:00
parent 48b671f100
commit 16e066ec69

View File

@ -77,7 +77,7 @@ pub const LlamaLM = struct {
const tokens = tokens_.withPartialTags(.{.s});
const out = out_.withPartialTags(.{ .s, .d });
const next_token_pred = out.dynamicSlice(.{ .s = .{ .start = token_index, .len = 1 } });
const next_token_pred = out.gatherValues(.s, token_index, .{});
var logits = zml.call(lm_head, .forward, .{next_token_pred});
if (logits.shape().hasTag(.voc) == null)
logits = logits.rename(.{ .d = .voc });