Update Llama example to use the simplified transpose implementation and increase default profiler size to 1,000,000 events.
This commit is contained in:
parent
145e60b4dd
commit
8a031bd4c8
@ -194,7 +194,7 @@ pub const Llama = struct {
|
|||||||
|
|
||||||
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor, token_index: ?Tensor) Tensor {
|
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor, token_index: ?Tensor) Tensor {
|
||||||
const tokens = if (token_index) |idx|
|
const tokens = if (token_index) |idx|
|
||||||
tokens_.dynamicSlice1d(-1, 1, idx)
|
tokens_.dynamicSlice1d(-1, .{ .start = idx, .len = 1 })
|
||||||
else
|
else
|
||||||
tokens_;
|
tokens_;
|
||||||
return zml.call(embed_tokens_, .forward, .{tokens}).withPartialTags(.{ .s, .d });
|
return zml.call(embed_tokens_, .forward, .{tokens}).withPartialTags(.{ .s, .d });
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user