From 237a877a29b7a4330238bd0292708f906ce94dd5 Mon Sep 17 00:00:00 2001 From: Foke Singh Date: Wed, 1 Nov 2023 10:16:48 +0000 Subject: [PATCH] zml: Add support for Llama 3.2 text-only models. Use the transposed embed_tokens weight as the output projection when lm_head is missing, and make lm_head optional for compatibility. Add the model repositories and executable targets to Bazel and update README. --- examples/MODULE.bazel | 52 ++++++++++++++++++++++++++++++++++++++ examples/llama/BUILD.bazel | 33 ++++++++++++++++++++++++ examples/llama/llama.zig | 16 ++++++++---- examples/llama/main.zig | 27 ++++++++++++++------ 4 files changed, 115 insertions(+), 13 deletions(-) diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index a2ebcc5..0658ba3 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -86,6 +86,56 @@ http_file( url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin", ) +# Llama 3.2 +huggingface.model( + name = "Meta-Llama-3.2-1B-Instruct", + build_file_content = """\ +package(default_visibility = ["//visibility:public"]) +filegroup( + name = "model", + srcs = ["model.safetensors"], +) + +filegroup( + name = "tokenizer", + srcs = ["tokenizer.json"], +) +""", + commit = "9213176726f574b556790deb65791e0c5aa438b6", + includes = [ + "model.safetensors", + "tokenizer.json", + ], + model = "meta-llama/Llama-3.2-1B-Instruct", +) +use_repo(huggingface, "Meta-Llama-3.2-1B-Instruct") + +huggingface.model( + name = "Meta-Llama-3.2-3B-Instruct", + build_file_content = """\ +package(default_visibility = ["//visibility:public"]) +filegroup( + name = "model", + srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"], +) + +filegroup( + name = "tokenizer", + srcs = ["tokenizer.json"], +) +""", + commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95", + includes = [ + "*.safetensors", + "model.safetensors.index.json", + "tokenizer.json", + ], + model = "meta-llama/Llama-3.2-3B-Instruct", +) +use_repo(huggingface, 
"Meta-Llama-3.2-3B-Instruct") + + +# Llama 3.1 huggingface.model( name = "Meta-Llama-3.1-8B-Instruct", build_file_content = """\ @@ -155,6 +205,8 @@ filegroup( model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0") + +#OpenLLaMa huggingface.model( name = "OpenLM-Research-OpenLLaMA-3B", build_file_content = """\ diff --git a/examples/llama/BUILD.bazel b/examples/llama/BUILD.bazel index f9c4258..8444e78 100644 --- a/examples/llama/BUILD.bazel +++ b/examples/llama/BUILD.bazel @@ -51,6 +51,39 @@ cc_binary( deps = [":llama_lib"], ) +cc_binary( + name = "Llama-3.2-1B-Instruct", + args = [ + "--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)", + "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)", + "--num-heads=32", + "--num-kv-heads=8", + "--rope-freq-base=500000", + ], + data = [ + "@Meta-Llama-3.2-1B-Instruct//:model.safetensors", + "@Meta-Llama-3.2-1B-Instruct//:tokenizer", + ], + deps = [":llama_lib"], +) + +cc_binary( + name = "Llama-3.2-3B-Instruct", + args = [ + "--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)", + "--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)", + "--num-heads=24", + "--num-kv-heads=8", + "--rope-freq-base=500000", + ], + data = [ + "@Meta-Llama-3.2-3B-Instruct//:model", + "@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json", + "@Meta-Llama-3.2-3B-Instruct//:tokenizer", + ], + deps = [":llama_lib"], +) + cc_binary( name = "OpenLLaMA-3B", args = [ diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index 6fe74cc..d24324e 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -24,7 +24,7 @@ pub const LlamaOptions = struct { /// Llama architecture, using huggingface transformers naming. 
/// Dimensions of activations: {.b, .s, .d} pub const LlamaLM = struct { - lm_head: zml.nn.Linear, + lm_head: ?zml.nn.Linear = null, model: Llama, // Options controlling generation @@ -55,7 +55,9 @@ pub const LlamaLM = struct { // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled. // It currently crashes/compilation fails if (options.gen_opts.topk == 1) { - self.lm_head.weight = self.lm_head.weight.withSharding(.{0}); + if (self.lm_head) |lm_head| { + self.lm_head.?.weight = lm_head.weight.withSharding(.{0}); + } } } @@ -76,12 +78,12 @@ pub const LlamaLM = struct { var tokens = tokens_.withPartialTags(.{.s}); const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache }); - tokens, const new_rng = updateTokens(self.lm_head, tokens, token_index, out, rng, self.gen_opts); + tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts); return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng }; } pub fn updateTokens( - lm_head: zml.nn.Linear, + self: LlamaLM, tokens_: Tensor, token_index: Tensor, out_: Tensor, @@ -92,7 +94,11 @@ pub const LlamaLM = struct { const out = out_.withPartialTags(.{ .s, .d }); const next_token_pred = out.gatherValues(.s, token_index, .{}); - var logits = zml.call(lm_head, .forward, .{next_token_pred}); + var logits = if (self.lm_head) |lm_head| + zml.call(lm_head, .forward, .{next_token_pred}) + else + self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d}); + if (logits.shape().hasTag(.voc) == null) logits = logits.rename(.{ .d = .voc }); diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 3909b46..ccf5f28 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -17,6 +17,8 @@ const ShapeOf = zml.ShapeOf; const log = std.log.scoped(.llama); +const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 }; + // set this to false to disable the verbose logging const 
show_mlir = true; @@ -71,6 +73,7 @@ pub fn generateText( const start = std.time.microTimestamp(); const output_freq: u8 = 1; + var eos_index: ?usize = null; for (0..output_tokens_len) |i| { //_ = i; const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len })); @@ -84,26 +87,34 @@ pub fn generateText( decode_progress += output_freq; std.debug.print("{s}", .{output.items[n..]}); tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] })); + if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| { + // Handle strange scenarios when eos id isn't the very next token after decode_progress + eos_index = decode_progress - output_freq + index; + break; + } } else { tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len })); } } - std.debug.print("\n", .{}); - + var total_token_count: usize = max_seq_len; const n = output.items.len; - try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..]), .{}); + if (eos_index) |end_idx| { + // count = eos index + 1 + total_token_count = end_idx + 1; + } + const generated_token_count = total_token_count - prompt_tok.len; + try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{}); std.debug.print("{s}\n", .{output.items[n..]}); const end = std.time.microTimestamp(); const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s); - const speed = @as(f64, @floatFromInt(max_seq_len)) / duration; - log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ max_seq_len, duration, speed }); + const speed = @as(f64, @floatFromInt(generated_token_count)) / duration; + log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed }); _ = try 
tokens.toHost(std.mem.sliceAsBytes(token_buffer)); - const end_index = std.mem.indexOfScalar(i32, token_buffer, 128001) orelse max_seq_len; output.clearRetainingCapacity(); - try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..end_index]), .{}); + try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{}); return output.toOwnedSlice(); } @@ -199,7 +210,7 @@ pub fn asyncMain() !void { defer tokenizer.deinit(); const dims = llama.model.shape(); - const dtype = llama.lm_head.weight.dtype(); + const dtype = llama.model.embed_tokens.weight.dtype(); // Note: we compile the model without a batching dimension. // To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.