Enable buffer donation in the Llama example, donating all buffers except the token_index buffer.

This commit is contained in:
Foke Singh 2023-10-03 16:32:40 +00:00
parent 5122ca0203
commit 474f76cd75
2 changed files with 8 additions and 4 deletions

View File

@ -104,7 +104,7 @@ pub const LlamaLM = struct {
} }
pub fn increment(_: u8, token_index: Tensor) Tensor { pub fn increment(_: u8, token_index: Tensor) Tensor {
return token_index.addConstant(1).reuseBuffer(token_index); return token_index.addConstant(1);
} }
/// Run the generation entirely within pjrt. /// Run the generation entirely within pjrt.

View File

@ -57,10 +57,12 @@ pub fn generateText(
defer output.deinit(); defer output.deinit();
var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer); var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer);
var token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)}); var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
defer prefill_token_index.deinit();
var rng = try zml.Tensor.Rng.init(mod.platform(), seed); var rng = try zml.Tensor.Rng.init(mod.platform(), seed);
tokens, token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, token_index, null, rng }); tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng });
defer token_index.deinit();
defer kv_cache.k.deinit(); defer kv_cache.k.deinit();
defer kv_cache.v.deinit(); defer kv_cache.v.deinit();
defer kv_cache.layer_index.deinit(); defer kv_cache.layer_index.deinit();
@ -74,7 +76,9 @@ pub fn generateText(
for (0..output_tokens_len) |i| { for (0..output_tokens_len) |i| {
//_ = i; //_ = i;
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len })); const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
tokens, token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng }); tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
token_index.deinit();
token_index = new_token_index;
if ((i + 1) % output_freq == 0) { if ((i + 1) % output_freq == 0) {
const n = output.items.len; const n = output.items.len;
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer)); _ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));