From 27aabf9bebaf6df4e35bc543378eaf0c17827700 Mon Sep 17 00:00:00 2001 From: Foke Singh Date: Thu, 23 May 2024 15:52:34 +0000 Subject: [PATCH] Add Bazel build rules and a test for the benchmark, llama, mnist, and simple_layer examples. --- examples/benchmark/BUILD.bazel | 2 +- examples/benchmark/main.zig | 3 +- examples/llama/BUILD.bazel | 55 ++++++++++++++++++++++++++----- examples/llama/test.zig | 42 ++++++++++++----------- examples/mnist/BUILD.bazel | 15 +++++++++ examples/simple_layer/BUILD.bazel | 12 +++++++ 6 files changed, 99 insertions(+), 30 deletions(-) diff --git a/examples/benchmark/BUILD.bazel b/examples/benchmark/BUILD.bazel index 7ca74ad..7088dae 100644 --- a/examples/benchmark/BUILD.bazel +++ b/examples/benchmark/BUILD.bazel @@ -5,7 +5,7 @@ zig_cc_binary( main = "main.zig", deps = [ "@zml//async", + "@zml//stdx", "@zml//zml", - "//third_party/tigerbeetle:flags", ], ) diff --git a/examples/benchmark/main.zig b/examples/benchmark/main.zig index 28410c6..a2de99d 100644 --- a/examples/benchmark/main.zig +++ b/examples/benchmark/main.zig @@ -1,7 +1,8 @@ const std = @import("std"); const zml = @import("zml"); +const stdx = @import("stdx"); const asynk = @import("async"); -const flags = @import("tigerbeetle/flags"); +const flags = stdx.flags; // set log level to debug to print the generated IR pub const std_options = .{ diff --git a/examples/llama/BUILD.bazel b/examples/llama/BUILD.bazel index 5efbc68..da0f20f 100644 --- a/examples/llama/BUILD.bazel +++ b/examples/llama/BUILD.bazel @@ -6,7 +6,6 @@ load("@bazel_skylib//rules:write_file.bzl", "write_file") load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push") load("@zml//bazel:zig.bzl", "zig_cc_binary") - zig_cc_binary( name = "llama", srcs = [ @@ -34,6 +33,9 @@ cc_binary( "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-8B-Instruct//:tokenizer.json", ], + tags = [ + "no_ci", + ], deps = [":llama_lib"], ) @@ -50,6 +52,9 @@ cc_binary( 
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-70B-Instruct//:tokenizer.json", ], + tags = [ + "no_ci", + ], deps = [":llama_lib"], ) @@ -66,6 +71,9 @@ cc_binary( "@Meta-Llama-3.2-1B-Instruct//:model.safetensors", "@Meta-Llama-3.2-1B-Instruct//:tokenizer.json", ], + tags = [ + "no_ci", + ], deps = [":llama_lib"], ) @@ -82,6 +90,9 @@ cc_binary( "@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.2-3B-Instruct//:tokenizer.json", ], + tags = [ + "no_ci", + ], deps = [":llama_lib"], ) @@ -92,8 +103,8 @@ cc_binary( "--weights=$(location @Karpathy-TinyLlama-Stories15M//file)", "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)", "--prompt='Once upon a time, there was a little girl named Lily.'", - "--no-llama3=1", # don't do template prompt encoding, I'm a simple model - "--sharding=false", # don't shard me, I'm so small + "--no-llama3=1", # don't do template prompt encoding, I'm a simple model + "--sharding=false", # don't shard me, I'm so small ], data = [ ":tinyllama_stories15M_json", @@ -109,7 +120,6 @@ write_file( content = ['{"bos_token_id":1,"eos_token_id":2,"hidden_act":"silu","hidden_size":288,"intermediate_size":768,"max_position_embeddings":256,"model_type":"llama","num_attention_heads":6,"num_hidden_layers":6,"num_key_value_heads":6,"rms_norm_eps":1e-05,"hf_rope_impl":false,"rope_scaling":null,"rope_theta":10000.0}'], ) - zig_cc_binary( name = "test-implementation", srcs = ["llama.zig"], @@ -118,10 +128,13 @@ zig_cc_binary( "--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)", ], data = [ - "@Meta-Llama-3.1-8B-Instruct//:model", + "@Meta-Llama-3.1-8B-Instruct//:config.json", "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json", ], main = "test.zig", + tags = [ + "no_ci", + ], deps = [ "@zml//async", "@zml//stdx", @@ -135,8 +148,8 @@ native_test( # Note: all Llama-3.x tokenizers are the same, # but using the 3.2-1B version because downloading the tokenizer triggers 
downloading the model. args = [ - "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)", - """--prompt='Examples of titles: + "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)", + """--prompt='Examples of titles: 📉 Stock Market Trends 🍪 Perfect Chocolate Chip Recipe Evolution of Music Streaming @@ -144,15 +157,21 @@ Remote Work Productivity Tips Artificial Intelligence in Healthcare 🎮 Video Game Development Insights '""", - # this correspond to encoding with HF tokenizers, with bos=False - "--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198", + # this corresponds to encoding with HF tokenizers, with bos=False + "--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198", ], data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer.json"], + tags = [ + "no_ci", + ], ) mtree_spec( name = "mtree", srcs = [":Llama-3.2-1B-Instruct"], + tags = [ + "no_ci", + ], ) tar( @@ -164,6 +183,9 @@ tar( ], compress = "zstd", mtree = ":mtree", + tags = [ + "no_ci", + ], ) expand_template( @@ -180,6 +202,9 @@ expand_template( ":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)", ":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)", }, + tags = [ + "no_ci", + ], template = [ "./{}/Llama-3.2-1B-Instruct".format(package_name()), "--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()), @@ -193,6 +218,9 @@ oci_image( base = "@distroless_cc_debian12_debug", # entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())], entrypoint = ":entrypoint", + tags = [ + "no_ci", + ], tars = [ "@zml//runtimes:layers", ":archive", @@ 
-202,6 +230,9 @@ oci_image( platform_transition_filegroup( name = "image", srcs = [":image_"], + tags = [ + "no_ci", + ], target_platform = "@zml//platforms:linux_amd64", ) @@ -211,6 +242,9 @@ oci_load( repo_tags = [ "distroless/llama-3.2-1b-instruct:latest", ], + tags = [ + "no_ci", + ], ) oci_push( @@ -218,4 +252,7 @@ oci_push( image = ":image", remote_tags = ["latest"], repository = "index.docker.io/steeve/llama-3.2-1b-instruct", + tags = [ + "no_ci", + ], ) diff --git a/examples/llama/test.zig b/examples/llama/test.zig index 97aec3f..42c0b49 100644 --- a/examples/llama/test.zig +++ b/examples/llama/test.zig @@ -1,8 +1,8 @@ const asynk = @import("async"); -const flags = @import("tigerbeetle/flags"); const std = @import("std"); const stdx = @import("stdx"); const zml = @import("zml"); +const flags = stdx.flags; const llama_mod = @import("./llama.zig"); const LlamaLM = llama_mod.LlamaLM; @@ -10,7 +10,7 @@ const LlamaLM = llama_mod.LlamaLM; const Tensor = zml.Tensor; pub fn main() !void { - try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain, .{}); + try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); } pub fn asyncMain() !void { @@ -54,13 +54,8 @@ pub fn asyncMain() !void { // Create the model and configure it. var llama = try zml.aio.populateModel(LlamaLM, model_arena, buffer_store); - const num_heads: i64 = cli_args.num_heads orelse buffer_store.metadata("num_heads", .int64) orelse @panic("--num_heads is required for this model"); - const num_kv_heads: i64 = cli_args.num_kv_heads orelse buffer_store.metadata("num_kv_heads", .int64) orelse num_heads; - - const rope_impl = if (buffer_store.metadata("rope_impl", .string)) |val| - std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).? 
- else - .sequential; + const num_heads: i64 = cli_args.num_heads orelse buffer_store.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model"); + const num_kv_heads: i64 = cli_args.num_kv_heads orelse buffer_store.metadata("num_kv_heads", .int) orelse num_heads; const config = blk: { var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only }); @@ -72,22 +67,31 @@ pub fn asyncMain() !void { }; std.log.info("Parsed llama config: {}", .{config}); - const llama_options: llama_mod.LlamaOptions = .{ + const llama_config: LlamaLM.Config = .{ + .eos_token_id = config.eos_token_id, + .bos_token_id = config.bos_token_id, + .num_key_value_heads = @intCast(num_kv_heads), + .num_hidden_layers = @intCast(config.num_hidden_layers), + .num_attention_heads = @intCast(num_heads), + .max_position_embeddings = config.max_position_embeddings, + .rope_theta = config.rope_theta, + .rms_norm_eps = @floatCast(buffer_store.metadata("rms_norm_eps", .float) orelse 1e-5), + .hf_rope_impl = true, + }; + + const llama_options: LlamaLM.Options = .{ .max_seq_len = 256, - .num_kv_heads = num_kv_heads, - .num_heads = num_heads, - .gen_opts = .{}, - .rms_norm_eps = @floatCast(buffer_store.metadata("rms_norm_eps", .float64) orelse 1e-5), - .rope_opts = .{ - .impl = rope_impl, - .freq_base = @floatCast(buffer_store.metadata("rope_freq_base", .float64) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))), + .sampling_strategy = .{ + .topk = 1, + .temperature = 1.0, }, }; + std.log.info("Parsed llama config: {}", .{llama_options}); - llama.init(llama_options); + llama.init(llama_config, llama_options); // Load the weights. - var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, buffer_store, model_arena, platform); + var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{ llama_config, llama_options }, buffer_store, model_arena, platform); defer zml.aio.unloadBuffers(&llama_weights); // Load the activations. 
diff --git a/examples/mnist/BUILD.bazel b/examples/mnist/BUILD.bazel index 2f77b40..be8b8e9 100644 --- a/examples/mnist/BUILD.bazel +++ b/examples/mnist/BUILD.bazel @@ -60,12 +60,18 @@ oci_image( name = "image_", base = "@distroless_cc_debian12", entrypoint = ":entrypoint", + target_compatible_with = [ + "@platforms//os:linux", + ], tars = [":archive"], ) platform_transition_filegroup( name = "image", srcs = [":image_"], + target_compatible_with = [ + "@platforms//os:linux", + ], target_platform = "@zml//platforms:linux_amd64", ) @@ -75,6 +81,9 @@ oci_load( repo_tags = [ "distroless/mnist:latest", ], + target_compatible_with = [ + "@platforms//os:linux", + ], ) oci_push( @@ -82,6 +91,9 @@ oci_push( image = ":image", remote_tags = ["latest"], repository = "index.docker.io/steeve/mnist", + target_compatible_with = [ + "@platforms//os:linux", + ], ) oci_load( @@ -90,4 +102,7 @@ oci_load( repo_tags = [ "distroless/mnist:latest", ], + target_compatible_with = [ + "@platforms//os:linux", + ], ) diff --git a/examples/simple_layer/BUILD.bazel b/examples/simple_layer/BUILD.bazel index 17ad673..8d74767 100644 --- a/examples/simple_layer/BUILD.bazel +++ b/examples/simple_layer/BUILD.bazel @@ -35,6 +35,9 @@ oci_image( name = "image_", base = "@distroless_cc_debian12", entrypoint = ["./{}/simple_layer".format(package_name())], + target_compatible_with = [ + "@platforms//os:linux", + ], tars = [":archive"], ) @@ -42,6 +45,9 @@ oci_image( platform_transition_filegroup( name = "image", srcs = [":image_"], + target_compatible_with = [ + "@platforms//os:linux", + ], target_platform = "@zml//platforms:linux_amd64", ) @@ -52,6 +58,9 @@ oci_load( repo_tags = [ "distroless/simple_layer:latest", ], + target_compatible_with = [ + "@platforms//os:linux", + ], ) # Bazel target for pushing the Linux image to the docker registry @@ -61,4 +70,7 @@ oci_push( remote_tags = ["latest"], # override with -- --repository foo.bar/org/image repository = "index.docker.io/renerocksai/simple_layer", + 
target_compatible_with = [ + "@platforms//os:linux", + ], )