diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index fed4063..b28b86b 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -104,7 +104,7 @@ pub const LlamaLM = struct { } pub fn increment(_: u8, token_index: Tensor) Tensor { - return token_index.addConstant(1).reuseBuffer(token_index); + return token_index.addConstant(1); } /// Run the generation entirely within pjrt. diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 1eda838..38ae79c 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -57,10 +57,12 @@ pub fn generateText( defer output.deinit(); var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer); - var token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)}); + var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)}); + defer prefill_token_index.deinit(); var rng = try zml.Tensor.Rng.init(mod.platform(), seed); - tokens, token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, token_index, null, rng }); + tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng }); + defer token_index.deinit(); defer kv_cache.k.deinit(); defer kv_cache.v.deinit(); defer kv_cache.layer_index.deinit(); @@ -74,7 +76,9 @@ pub fn generateText( for (0..output_tokens_len) |i| { //_ = i; const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len })); - tokens, token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng }); + tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng }); + token_index.deinit(); + token_index = new_token_index; if ((i + 1) % output_freq == 0) { const n = output.items.len; _ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));