diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index ea66f7a..bb53bcb 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -126,25 +126,19 @@ filegroup( ) use_repo(huggingface, "Meta-Llama-3.1-70B-Instruct") - -huggingface.model( - name = "TinyLlama-120M-scratch", - build_file_content = """\ -package(default_visibility = ["//visibility:public"]) -filegroup( - name = "TinyLlama-120M-scratch", - srcs = glob(["*.json", "*.safetensors"]), +http_file( + name = "Karpathy-TinyLlama-Stories15M", + downloaded_file_path = "stories15M.tinyllama", + sha256 = "cd590644d963867a2b6e5a1107f51fad663c41d79c149fbecbbb1f95fa81f49a", + url = "https://huggingface.co/karpathy/tinyllamas/resolve/0bd21da7698eaf29a0d7de3992de8a46ef624add/stories15M.bin?download=true", ) -""", - commit = "89c1bb4ea00861ddaa26c55f102ccb25e161feee", - includes = [ - "*.safetensors", - "*.json", - ], - model = "Hoyeon/TinyLlama-120M-scratch", -) -use_repo(huggingface, "TinyLlama-120M-scratch") +http_file( + name = "Karpathy-TinyLlama-Tokenizer", + downloaded_file_path = "stories260K.tinyllama", + sha256 = "50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361", + url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin", +) bazel_dep(name = "rules_rust", version = "0.57.0") rust = use_extension("@rules_rust//rust:extensions.bzl", "rust") diff --git a/examples/llama/BUILD.bazel b/examples/llama/BUILD.bazel index 8d03023..5efbc68 100644 --- a/examples/llama/BUILD.bazel +++ b/examples/llama/BUILD.bazel @@ -1,7 +1,8 @@ load("@aspect_bazel_lib//lib:expand_template.bzl", "expand_template") load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup") -load("@bazel_skylib//rules:native_binary.bzl", "native_binary") +load("@bazel_skylib//rules:native_binary.bzl", "native_test") +load("@bazel_skylib//rules:write_file.bzl", "write_file") 
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push") load("@zml//bazel:zig.bzl", "zig_cc_binary") @@ -20,24 +21,6 @@ zig_cc_binary( ], ) -cc_binary( - name = "TinyLlama-120M-scratch", - args = [ - "--config=$(location @TinyLlama-120M-scratch//:config.json)", - "--weights=$(location @TinyLlama-120M-scratch//:model.safetensors)", - "--tokenizer=$(location @TinyLlama-120M-scratch//:tokenizer.json)", - "--no-llama3=true", # don't do llama3 template prompt encoding - "--sharding=false", # don't shard this - ], - data = [ - "@TinyLlama-120M-scratch", - "@TinyLlama-120M-scratch//:config.json", - "@TinyLlama-120M-scratch//:model.safetensors", - "@TinyLlama-120M-scratch//:tokenizer.json", - ], - deps = [":llama_lib"], -) - cc_binary( name = "Llama-3.1-8B-Instruct", args = [ @@ -70,7 +53,6 @@ cc_binary( deps = [":llama_lib"], ) - cc_binary( name = "Llama-3.2-1B-Instruct", args = [ @@ -102,7 +84,31 @@ cc_binary( ], deps = [":llama_lib"], ) -# + +cc_binary( + name = "TinyLlama-Stories-15M", + args = [ + "--config=$(location :tinyllama_stories15M_json)", + "--weights=$(location @Karpathy-TinyLlama-Stories15M//file)", + "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)", + "--prompt='Once upon a time, there was a little girl named Lily.'", + "--no-llama3=1", # don't do template prompt encoding, I'm a simple model + "--sharding=false", # don't shard me, I'm so small + ], + data = [ + ":tinyllama_stories15M_json", + "@Karpathy-TinyLlama-Stories15M//file", + "@Karpathy-TinyLlama-Tokenizer//file", + ], + deps = [":llama_lib"], +) + +write_file( + name = "tinyllama_stories15M_json", + out = "config.json", + content = ['{"bos_token_id":1,"eos_token_id":2,"hidden_act":"silu","hidden_size":288,"intermediate_size":768,"max_position_embeddings":256,"model_type":"llama","num_attention_heads":6,"num_hidden_layers":6,"num_key_value_heads":6,"rms_norm_eps":1e-05,"hf_rope_impl":false,"rope_scaling":null,"rope_theta":10000.0}'], +) + zig_cc_binary( name = 
"test-implementation", @@ -117,27 +123,31 @@ zig_cc_binary( ], main = "test.zig", deps = [ - "//third_party/tigerbeetle:flags", "@zml//async", - "@zml//metax", + "@zml//stdx", "@zml//zml", ], ) -zig_cc_binary( +native_test( name = "test_tokenizer", - main = "test_tokenizer.zig", - deps = [ - "//third_party/tigerbeetle:flags", - "@zml//stdx", - "@zml//zml", - ], + src = "@zml//zml/tokenizer:main", # Note: all Llama-3.x tokenizers are the same, # but using the 3.2-1B version because downloading the tokenizer triggers downloading the model. args = [ - "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)", + "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)", + """--prompt='Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights +'""", + # this corresponds to encoding with HF tokenizers, with bos=False + "--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198", ], - data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer"], + data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer.json"], ) mtree_spec( diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index 6f6594f..a162b43 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -26,6 +26,7 @@ pub const LlamaLM = struct { rope_theta: f32, max_position_embeddings: usize, rms_norm_eps: f32, + hf_rope_impl: bool = true, }; pub const Options = struct { @@ -47,7 +48,7 @@ pub const LlamaLM = struct { self.model.num_heads = @intCast(config.num_attention_heads); self.model.num_kv_heads = @intCast(config.num_key_value_heads); self.model.rope_opts = .{ - .impl = .sequential, + .impl = if (config.hf_rope_impl) .sequential else .interleaved, .freq_base 
= config.rope_theta, }; for (self.model.layers) |*layer| { diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 3cf2670..80df26f 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -27,9 +27,9 @@ pub fn tokenizePromptLlama3(allocator: std.mem.Allocator, tokenizer: zml.tokeniz var encoder = try tokenizer.encoder(); defer encoder.deinit(); - const start_header_id = tokenizer.token_to_id("<|start_header_id|>") orelse return error.NoSuchToken; - const end_header_id = tokenizer.token_to_id("<|end_header_id|>") orelse return error.NoSuchToken; - const eot_id = tokenizer.token_to_id("<|eot_id|>") orelse return error.NoSuchToken; + const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken; + const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken; + const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken; const newline_id = (try encoder.encode("\n"))[0]; try tokens.append(config.bos_token_id); @@ -312,7 +312,7 @@ pub fn asyncMain() !void { var timer = try stdx.time.Timer.start(); defer log.info("Loaded tokenizer from {s} [{}]", .{ tok, timer.read() }); - break :blk try zml.tokenizer.Tokenizer.from_file(model_arena.allocator(), tok); + break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), tok); } else { log.err("Missing --tokenizer", .{}); return;