Enable buffer donation in the Llama example, donating all buffers except the token_index buffer.
This commit is contained in:
parent
5122ca0203
commit
474f76cd75
@ -104,7 +104,7 @@ pub const LlamaLM = struct {
|
||||
}
|
||||
|
||||
pub fn increment(_: u8, token_index: Tensor) Tensor {
|
||||
return token_index.addConstant(1).reuseBuffer(token_index);
|
||||
return token_index.addConstant(1);
|
||||
}
|
||||
|
||||
/// Run the generation entirely within pjrt.
|
||||
|
||||
@ -57,10 +57,12 @@ pub fn generateText(
|
||||
defer output.deinit();
|
||||
|
||||
var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer);
|
||||
var token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
|
||||
var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
|
||||
defer prefill_token_index.deinit();
|
||||
|
||||
var rng = try zml.Tensor.Rng.init(mod.platform(), seed);
|
||||
tokens, token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, token_index, null, rng });
|
||||
tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng });
|
||||
defer token_index.deinit();
|
||||
defer kv_cache.k.deinit();
|
||||
defer kv_cache.v.deinit();
|
||||
defer kv_cache.layer_index.deinit();
|
||||
@ -74,7 +76,9 @@ pub fn generateText(
|
||||
for (0..output_tokens_len) |i| {
|
||||
//_ = i;
|
||||
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
||||
tokens, token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
|
||||
tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
|
||||
token_index.deinit();
|
||||
token_index = new_token_index;
|
||||
if ((i + 1) % output_freq == 0) {
|
||||
const n = output.items.len;
|
||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user