Fix llama example to correctly handle token output and avoid re‑feeding the last prompt token.
This commit is contained in:
parent
5a2171793d
commit
394e63e273
@ -87,16 +87,15 @@ pub const LlamaLM = struct {
|
|||||||
rng: Tensor.Rng,
|
rng: Tensor.Rng,
|
||||||
) struct { Tensor, KvCache, Tensor.Rng } {
|
) struct { Tensor, KvCache, Tensor.Rng } {
|
||||||
stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
||||||
var tokens = tokens_.withPartialTags(.{.s});
|
const tokens = tokens_.withPartialTags(.{.s});
|
||||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
||||||
tokens, const new_rng = self.sampleTokens(self.lm_head, tokens, out, rng, self.gen_opts);
|
const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);
|
||||||
return .{ tokens, updated_kv_cache, new_rng };
|
return .{ new_tokens.convert(tokens.dtype()).reuseBuffer(tokens), updated_kv_cache, new_rng };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sampleTokens(
|
pub fn sampleTokens(
|
||||||
self: LlamaLM,
|
self: LlamaLM,
|
||||||
lm_head_: ?zml.nn.Linear,
|
lm_head_: ?zml.nn.Linear,
|
||||||
tokens_: Tensor,
|
|
||||||
out_: Tensor,
|
out_: Tensor,
|
||||||
rng: Tensor.Rng,
|
rng: Tensor.Rng,
|
||||||
opts: zml.nn.SamplingStrategy,
|
opts: zml.nn.SamplingStrategy,
|
||||||
@ -115,7 +114,7 @@ pub const LlamaLM = struct {
|
|||||||
logits = logits.rename(.{ .d = .voc });
|
logits = logits.rename(.{ .d = .voc });
|
||||||
|
|
||||||
const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
||||||
return .{ next_tokens.convert(tokens_.dtype()).reuseBuffer(tokens_), new_rng };
|
return .{ next_tokens, new_rng };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment(_: u8, token_index: Tensor) Tensor {
|
pub fn increment(_: u8, token_index: Tensor) Tensor {
|
||||||
|
|||||||
@ -22,11 +22,17 @@ pub const std_options = .{
|
|||||||
.logFn = asynk.logFn(std.log.defaultLog),
|
.logFn = asynk.logFn(std.log.defaultLog),
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn tokenizePromptLlama3(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8) ![]u32 {
|
pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
|
||||||
var tokens = std.ArrayList(u32).init(allocator);
|
var tokens = std.ArrayList(u32).init(allocator);
|
||||||
var encoder = try tokenizer.encoder();
|
var encoder = try tokenizer.encoder();
|
||||||
defer encoder.deinit();
|
defer encoder.deinit();
|
||||||
|
|
||||||
|
if (skip_llama3_encoding) {
|
||||||
|
// Copy to the arraylist so the ownership is the same in both branches.
|
||||||
|
try tokens.appendSlice(try encoder.encode(prompt));
|
||||||
|
return tokens.toOwnedSlice();
|
||||||
|
}
|
||||||
|
|
||||||
const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken;
|
const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken;
|
||||||
const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken;
|
const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken;
|
||||||
const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken;
|
const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken;
|
||||||
@ -59,84 +65,61 @@ pub fn generateText(
|
|||||||
prompt: []const u8,
|
prompt: []const u8,
|
||||||
skip_llama3_encoding: bool,
|
skip_llama3_encoding: bool,
|
||||||
) ![]const u8 {
|
) ![]const u8 {
|
||||||
var tokenizer_encoder = try tokenizer.encoder();
|
const prompt_tok: []const u32 = try tokenizePrompt(allocator, tokenizer, config, prompt, skip_llama3_encoding);
|
||||||
defer tokenizer_encoder.deinit();
|
defer allocator.free(prompt_tok);
|
||||||
|
|
||||||
var tokenizer_decoder = try tokenizer.decoder();
|
var tokenizer_decoder = try tokenizer.decoder();
|
||||||
defer tokenizer_decoder.deinit();
|
defer tokenizer_decoder.deinit();
|
||||||
|
|
||||||
const prompt_tok: []const u32 = if (skip_llama3_encoding) try tokenizer_encoder.encode(prompt) else try tokenizePromptLlama3(allocator, tokenizer, config, prompt);
|
|
||||||
defer allocator.free(prompt_tok);
|
|
||||||
|
|
||||||
const dims = llama_.model.shape();
|
|
||||||
const max_seq_len = dims.s;
|
|
||||||
|
|
||||||
// Prefill
|
|
||||||
// initialize a 0..max_seq_len buffer with the tokenized prompt
|
|
||||||
const prefill_buffer = try allocator.alloc(u32, @intCast(max_seq_len));
|
|
||||||
@memset(prefill_buffer, 0);
|
|
||||||
for (0..prompt_tok.len) |i| {
|
|
||||||
prefill_buffer[i] = @intCast(prompt_tok[i]);
|
|
||||||
}
|
|
||||||
defer allocator.free(prefill_buffer);
|
|
||||||
|
|
||||||
const platform = mod_generate.platform();
|
const platform = mod_generate.platform();
|
||||||
|
const max_seq_len = llama_.model.shape().s;
|
||||||
|
|
||||||
|
// init RNG and buffers
|
||||||
|
var rng = try zml.Tensor.Rng.init(platform, seed);
|
||||||
|
var generated_token_buffer = [_]u32{undefined};
|
||||||
|
|
||||||
|
var kv_cache = prefill: {
|
||||||
|
// prepare device buffers for the prefill tokens and their positions
|
||||||
|
const prefill_buffer = try allocator.alloc(u32, max_seq_len);
|
||||||
|
@memcpy(prefill_buffer[0..prompt_tok.len], prompt_tok);
|
||||||
|
|
||||||
// prepare device buffers for the prefill tokens and the index
|
|
||||||
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
||||||
defer prefill_tokens.deinit();
|
defer prefill_tokens.deinit();
|
||||||
var prefill_token_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0);
|
var prefill_token_pos = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0);
|
||||||
|
defer prefill_token_pos.deinit();
|
||||||
|
|
||||||
defer prefill_token_index.deinit();
|
const prefilled_tokens, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_pos, kv_cache_, rng });
|
||||||
|
_ = try prefilled_tokens.toHost(std.mem.sliceAsBytes(prefill_buffer));
|
||||||
|
generated_token_buffer[0] = prefill_buffer[prompt_tok.len - 1];
|
||||||
|
break :prefill kv_cache;
|
||||||
|
};
|
||||||
|
defer zml.aio.unloadBuffers(&kv_cache);
|
||||||
|
|
||||||
// init RNG and prefill
|
// Prepare for token-by-token generation,
|
||||||
var rng = try zml.Tensor.Rng.init(platform, seed);
|
// start with the token generated based on the full prompt.
|
||||||
prefill_tokens, var kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_index, kv_cache_, rng });
|
var current_token = try zml.Buffer.fromSlice(platform, .{1}, &generated_token_buffer);
|
||||||
defer kv_cache.k.deinit();
|
|
||||||
defer kv_cache.v.deinit();
|
|
||||||
defer kv_cache.layer_index.deinit();
|
|
||||||
|
|
||||||
// Prepare for token-by-token generation
|
|
||||||
var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token
|
|
||||||
var current_token = try zml.Buffer.fromSlice(platform, .{1}, &first_token_hostbuffer);
|
|
||||||
defer current_token.deinit();
|
defer current_token.deinit();
|
||||||
|
|
||||||
// Here we will copy the generated token from device
|
|
||||||
var generated_token_buffer = [_]u32{0};
|
|
||||||
|
|
||||||
// Here we collect the generated text
|
// Here we collect the generated text
|
||||||
var output = std.ArrayList(u8).init(allocator);
|
var output = std.ArrayList(u8).init(allocator);
|
||||||
defer output.deinit();
|
defer output.deinit();
|
||||||
|
|
||||||
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
|
|
||||||
defer allocator.free(tracer_buffer);
|
|
||||||
const tracer = zml.tools.Tracer.init("ai.zml.models.llama");
|
|
||||||
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
||||||
const start = std.time.microTimestamp();
|
const start = std.time.microTimestamp();
|
||||||
|
|
||||||
var num_tokens_generated: usize = 0;
|
// One token has already been generated by the prefill.
|
||||||
|
var num_tokens_generated: usize = 1;
|
||||||
|
|
||||||
generation: for (0..output_tokens_len) |i| {
|
generation: for (0..output_tokens_len + 1) |i| {
|
||||||
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
// collect and print generated sequence
|
||||||
|
num_tokens_generated += 1;
|
||||||
// current token index needs to go into a zml.Buffer
|
|
||||||
const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
|
|
||||||
const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer);
|
|
||||||
|
|
||||||
defer token_index.deinit();
|
|
||||||
|
|
||||||
// call to generate the next token
|
|
||||||
current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_index, kv_cache, rng });
|
|
||||||
|
|
||||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
|
||||||
|
|
||||||
// extract the generated token from the buffer
|
|
||||||
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
|
|
||||||
const generated_token = generated_token_buffer[0];
|
const generated_token = generated_token_buffer[0];
|
||||||
// de-tokenize generated token into a string
|
const chunk = try tokenizer_decoder.next(generated_token) orelse unreachable;
|
||||||
const chunk = try tokenizer_decoder.next(@intCast(generated_token)) orelse unreachable;
|
try output.appendSlice(chunk);
|
||||||
num_tokens_generated = i;
|
std.debug.print("{s}", .{chunk});
|
||||||
|
|
||||||
// check for eos
|
// check for eos
|
||||||
|
if (i == output_tokens_len) break :generation;
|
||||||
switch (config.eos_token_id.value) {
|
switch (config.eos_token_id.value) {
|
||||||
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
|
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
|
||||||
.ints => |eos_list| {
|
.ints => |eos_list| {
|
||||||
@ -146,9 +129,16 @@ pub fn generateText(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// collect and print generated sequence
|
// current token pos needs to go into a zml.Buffer
|
||||||
try output.appendSlice(chunk);
|
const token_pos_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
|
||||||
std.debug.print("{s}", .{chunk});
|
const token_pos = try zml.Buffer.fromSlice(platform, .{}, token_pos_buffer);
|
||||||
|
defer token_pos.deinit();
|
||||||
|
|
||||||
|
// call to generate the next token
|
||||||
|
current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_pos, kv_cache, rng });
|
||||||
|
|
||||||
|
// extract the generated token from the buffer
|
||||||
|
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
|
||||||
}
|
}
|
||||||
const end = std.time.microTimestamp();
|
const end = std.time.microTimestamp();
|
||||||
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user