diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index a162b43..2eaf8a8 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -87,16 +87,15 @@ pub const LlamaLM = struct { rng: Tensor.Rng, ) struct { Tensor, KvCache, Tensor.Rng } { stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index }); - var tokens = tokens_.withPartialTags(.{.s}); + const tokens = tokens_.withPartialTags(.{.s}); const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache }); - tokens, const new_rng = self.sampleTokens(self.lm_head, tokens, out, rng, self.gen_opts); - return .{ tokens, updated_kv_cache, new_rng }; + const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts); + return .{ new_tokens.convert(tokens.dtype()).reuseBuffer(tokens), updated_kv_cache, new_rng }; } pub fn sampleTokens( self: LlamaLM, lm_head_: ?zml.nn.Linear, - tokens_: Tensor, out_: Tensor, rng: Tensor.Rng, opts: zml.nn.SamplingStrategy, @@ -115,7 +114,7 @@ pub const LlamaLM = struct { logits = logits.rename(.{ .d = .voc }); const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng); - return .{ next_tokens.convert(tokens_.dtype()).reuseBuffer(tokens_), new_rng }; + return .{ next_tokens, new_rng }; } pub fn increment(_: u8, token_index: Tensor) Tensor { diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 8714837..9a17c40 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -22,11 +22,17 @@ pub const std_options = .{ .logFn = asynk.logFn(std.log.defaultLog), }; -pub fn tokenizePromptLlama3(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8) ![]u32 { +pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, 
config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 { var tokens = std.ArrayList(u32).init(allocator); var encoder = try tokenizer.encoder(); defer encoder.deinit(); + if (skip_llama3_encoding) { + // Copy to the arraylist so the ownership is the same in both branches. + try tokens.appendSlice(try encoder.encode(prompt)); + return tokens.toOwnedSlice(); + } + const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken; const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken; const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken; @@ -59,84 +65,61 @@ pub fn generateText( prompt: []const u8, skip_llama3_encoding: bool, ) ![]const u8 { - var tokenizer_encoder = try tokenizer.encoder(); - defer tokenizer_encoder.deinit(); + const prompt_tok: []const u32 = try tokenizePrompt(allocator, tokenizer, config, prompt, skip_llama3_encoding); + defer allocator.free(prompt_tok); + var tokenizer_decoder = try tokenizer.decoder(); defer tokenizer_decoder.deinit(); - const prompt_tok: []const u32 = if (skip_llama3_encoding) try tokenizer_encoder.encode(prompt) else try tokenizePromptLlama3(allocator, tokenizer, config, prompt); - defer allocator.free(prompt_tok); - - const dims = llama_.model.shape(); - const max_seq_len = dims.s; - - // Prefill - // initialize a 0..max_seq_len buffer with the tokenized prompt - const prefill_buffer = try allocator.alloc(u32, @intCast(max_seq_len)); - @memset(prefill_buffer, 0); - for (0..prompt_tok.len) |i| { - prefill_buffer[i] = @intCast(prompt_tok[i]); - } - defer allocator.free(prefill_buffer); - const platform = mod_generate.platform(); + const max_seq_len = llama_.model.shape().s; - // prepare device buffers for the prefill tokens and the index - var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer); - defer prefill_tokens.deinit(); - var prefill_token_index = try 
zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0); - - defer prefill_token_index.deinit(); - - // init RNG and prefill + // init RNG and buffers var rng = try zml.Tensor.Rng.init(platform, seed); - prefill_tokens, var kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_index, kv_cache_, rng }); - defer kv_cache.k.deinit(); - defer kv_cache.v.deinit(); - defer kv_cache.layer_index.deinit(); + var generated_token_buffer = [_]u32{undefined}; - // Prepare for token-by-token generation - var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token - var current_token = try zml.Buffer.fromSlice(platform, .{1}, &first_token_hostbuffer); + var kv_cache = prefill: { + // prepare device buffers for the prefill tokens and their positions + const prefill_buffer = try allocator.alloc(u32, max_seq_len); + @memcpy(prefill_buffer[0..prompt_tok.len], prompt_tok); + + var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer); + defer prefill_tokens.deinit(); + var prefill_token_pos = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0); + defer prefill_token_pos.deinit(); + + const prefilled_tokens, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_pos, kv_cache_, rng }); + _ = try prefilled_tokens.toHost(std.mem.sliceAsBytes(prefill_buffer)); + generated_token_buffer[0] = prefill_buffer[prompt_tok.len - 1]; + break :prefill kv_cache; + }; + defer zml.aio.unloadBuffers(&kv_cache); + + // Prepare for token-by-token generation, + // start with the token generated based on the full prompt. 
+ var current_token = try zml.Buffer.fromSlice(platform, .{1}, &generated_token_buffer); defer current_token.deinit(); - // Here we will copy the generated token from device - var generated_token_buffer = [_]u32{0}; - // Here we collect the generated text var output = std.ArrayList(u8).init(allocator); defer output.deinit(); - const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len)); - defer allocator.free(tracer_buffer); - const tracer = zml.tools.Tracer.init("ai.zml.models.llama"); const output_tokens_len = max_seq_len - prompt_tok.len - 1; const start = std.time.microTimestamp(); - var num_tokens_generated: usize = 0; + // One token has already been generated by the prefill. + var num_tokens_generated: usize = 1; - generation: for (0..output_tokens_len) |i| { - const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len })); - - // current token index needs to go into a zml.Buffer - const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)}; - const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer); - - defer token_index.deinit(); - - // call to generate the next token - current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_index, kv_cache, rng }); - - tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len })); - - // extract the generated token from the buffer - _ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer)); + generation: for (0..output_tokens_len + 1) |i| { + // collect and print generated sequence + num_tokens_generated += 1; const generated_token = generated_token_buffer[0]; - // de-tokenize generated token into a string - const chunk = try tokenizer_decoder.next(@intCast(generated_token)) orelse unreachable; - num_tokens_generated = i; + const chunk = try tokenizer_decoder.next(generated_token) orelse unreachable; + try 
output.appendSlice(chunk); + std.debug.print("{s}", .{chunk}); // check for eos + if (i == output_tokens_len) break :generation; switch (config.eos_token_id.value) { .int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation, .ints => |eos_list| { @@ -146,9 +129,16 @@ pub fn generateText( }, } - // collect and print generated sequence - try output.appendSlice(chunk); - std.debug.print("{s}", .{chunk}); + // current token pos needs to go into a zml.Buffer + const token_pos_buffer = &[_]u32{@intCast(prompt_tok.len + i)}; + const token_pos = try zml.Buffer.fromSlice(platform, .{}, token_pos_buffer); + defer token_pos.deinit(); + + // call to generate the next token + current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_pos, kv_cache, rng }); + + // extract the generated token from the buffer + _ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer)); } const end = std.time.microTimestamp(); const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);