Enable buffer donation in the Llama example, donating all buffers except the token_index buffer.

This commit is contained in:
Foke Singh 2023-10-03 16:32:40 +00:00
parent 5122ca0203
commit 474f76cd75
2 changed files with 8 additions and 4 deletions

View File

@ -104,7 +104,7 @@ pub const LlamaLM = struct {
}
pub fn increment(_: u8, token_index: Tensor) Tensor {
+        return token_index.addConstant(1).reuseBuffer(token_index);
-        return token_index.addConstant(1);
}
/// Run the generation entirely within pjrt.

View File

@ -57,10 +57,12 @@ pub fn generateText(
defer output.deinit();
var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer);
-    var token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
+    var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
+    defer prefill_token_index.deinit();
var rng = try zml.Tensor.Rng.init(mod.platform(), seed);
-    tokens, token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, token_index, null, rng });
+    tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng });
+    defer token_index.deinit();
defer kv_cache.k.deinit();
defer kv_cache.v.deinit();
defer kv_cache.layer_index.deinit();
@ -74,7 +76,9 @@ pub fn generateText(
for (0..output_tokens_len) |i| {
//_ = i;
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
-        tokens, token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
+        tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
+        token_index.deinit();
+        token_index = new_token_index;
if ((i + 1) % output_freq == 0) {
const n = output.items.len;
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));