Add Bazel build rules and a test for the benchmark, llama, mnist, and simple_layer examples.
This commit is contained in:
parent
3aac788544
commit
27aabf9beb
@ -5,7 +5,7 @@ zig_cc_binary(
|
||||
main = "main.zig",
|
||||
deps = [
|
||||
"@zml//async",
|
||||
"@zml//stdx",
|
||||
"@zml//zml",
|
||||
"//third_party/tigerbeetle:flags",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
const std = @import("std");
|
||||
const zml = @import("zml");
|
||||
const stdx = @import("stdx");
|
||||
const asynk = @import("async");
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const flags = stdx.flags;
|
||||
|
||||
// set log level to debug to print the generated IR
|
||||
pub const std_options = .{
|
||||
|
||||
@ -6,7 +6,6 @@ load("@bazel_skylib//rules:write_file.bzl", "write_file")
|
||||
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
|
||||
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||
|
||||
|
||||
zig_cc_binary(
|
||||
name = "llama",
|
||||
srcs = [
|
||||
@ -34,6 +33,9 @@ cc_binary(
|
||||
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:tokenizer.json",
|
||||
],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
@ -50,6 +52,9 @@ cc_binary(
|
||||
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.1-70B-Instruct//:tokenizer.json",
|
||||
],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
@ -66,6 +71,9 @@ cc_binary(
|
||||
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
|
||||
],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
@ -82,6 +90,9 @@ cc_binary(
|
||||
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:tokenizer.json",
|
||||
],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
@ -92,8 +103,8 @@ cc_binary(
|
||||
"--weights=$(location @Karpathy-TinyLlama-Stories15M//file)",
|
||||
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
|
||||
"--prompt='Once upon a time, there was a little girl named Lily.'",
|
||||
"--no-llama3=1", # don't do template prompt encoding, I'm a simple model
|
||||
"--sharding=false", # don't shard me, I'm so small
|
||||
"--no-llama3=1", # don't do template prompt encoding, I'm a simple model
|
||||
"--sharding=false", # don't shard me, I'm so small
|
||||
],
|
||||
data = [
|
||||
":tinyllama_stories15M_json",
|
||||
@ -109,7 +120,6 @@ write_file(
|
||||
content = ['{"bos_token_id":1,"eos_token_id":2,"hidden_act":"silu","hidden_size":288,"intermediate_size":768,"max_position_embeddings":256,"model_type":"llama","num_attention_heads":6,"num_hidden_layers":6,"num_key_value_heads":6,"rms_norm_eps":1e-05,"hf_rope_impl":false,"rope_scaling":null,"rope_theta":10000.0}'],
|
||||
)
|
||||
|
||||
|
||||
zig_cc_binary(
|
||||
name = "test-implementation",
|
||||
srcs = ["llama.zig"],
|
||||
@ -118,10 +128,13 @@ zig_cc_binary(
|
||||
"--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.1-8B-Instruct//:model",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:config.json",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
|
||||
],
|
||||
main = "test.zig",
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
deps = [
|
||||
"@zml//async",
|
||||
"@zml//stdx",
|
||||
@ -135,8 +148,8 @@ native_test(
|
||||
# Note: all Llama-3.x tokenizers are the same,
|
||||
# but using the 3.2-1B version because downloading the tokenizer triggers downloading the model.
|
||||
args = [
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||
"""--prompt='Examples of titles:
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||
"""--prompt='Examples of titles:
|
||||
📉 Stock Market Trends
|
||||
🍪 Perfect Chocolate Chip Recipe
|
||||
Evolution of Music Streaming
|
||||
@ -144,15 +157,21 @@ Remote Work Productivity Tips
|
||||
Artificial Intelligence in Healthcare
|
||||
🎮 Video Game Development Insights
|
||||
'""",
|
||||
# this corresponds to encoding with HF tokenizers, with bos=False
|
||||
"--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
|
||||
# this corresponds to encoding with HF tokenizers, with bos=False
|
||||
"--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
|
||||
],
|
||||
data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer.json"],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
)
|
||||
|
||||
mtree_spec(
|
||||
name = "mtree",
|
||||
srcs = [":Llama-3.2-1B-Instruct"],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
)
|
||||
|
||||
tar(
|
||||
@ -164,6 +183,9 @@ tar(
|
||||
],
|
||||
compress = "zstd",
|
||||
mtree = ":mtree",
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
)
|
||||
|
||||
expand_template(
|
||||
@ -180,6 +202,9 @@ expand_template(
|
||||
":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||
":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||
},
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
template = [
|
||||
"./{}/Llama-3.2-1B-Instruct".format(package_name()),
|
||||
"--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()),
|
||||
@ -193,6 +218,9 @@ oci_image(
|
||||
base = "@distroless_cc_debian12_debug",
|
||||
# entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())],
|
||||
entrypoint = ":entrypoint",
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
tars = [
|
||||
"@zml//runtimes:layers",
|
||||
":archive",
|
||||
@ -202,6 +230,9 @@ oci_image(
|
||||
platform_transition_filegroup(
|
||||
name = "image",
|
||||
srcs = [":image_"],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
target_platform = "@zml//platforms:linux_amd64",
|
||||
)
|
||||
|
||||
@ -211,6 +242,9 @@ oci_load(
|
||||
repo_tags = [
|
||||
"distroless/llama-3.2-1b-instruct:latest",
|
||||
],
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
)
|
||||
|
||||
oci_push(
|
||||
@ -218,4 +252,7 @@ oci_push(
|
||||
image = ":image",
|
||||
remote_tags = ["latest"],
|
||||
repository = "index.docker.io/steeve/llama-3.2-1b-instruct",
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
const asynk = @import("async");
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
const flags = stdx.flags;
|
||||
|
||||
const llama_mod = @import("./llama.zig");
|
||||
const LlamaLM = llama_mod.LlamaLM;
|
||||
@ -10,7 +10,7 @@ const LlamaLM = llama_mod.LlamaLM;
|
||||
const Tensor = zml.Tensor;
|
||||
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain, .{});
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
@ -54,13 +54,8 @@ pub fn asyncMain() !void {
|
||||
|
||||
// Create the model and configure it.
|
||||
var llama = try zml.aio.populateModel(LlamaLM, model_arena, buffer_store);
|
||||
const num_heads: i64 = cli_args.num_heads orelse buffer_store.metadata("num_heads", .int64) orelse @panic("--num_heads is required for this model");
|
||||
const num_kv_heads: i64 = cli_args.num_kv_heads orelse buffer_store.metadata("num_kv_heads", .int64) orelse num_heads;
|
||||
|
||||
const rope_impl = if (buffer_store.metadata("rope_impl", .string)) |val|
|
||||
std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
|
||||
else
|
||||
.sequential;
|
||||
const num_heads: i64 = cli_args.num_heads orelse buffer_store.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
|
||||
const num_kv_heads: i64 = cli_args.num_kv_heads orelse buffer_store.metadata("num_kv_heads", .int) orelse num_heads;
|
||||
|
||||
const config = blk: {
|
||||
var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only });
|
||||
@ -72,22 +67,31 @@ pub fn asyncMain() !void {
|
||||
};
|
||||
std.log.info("Parsed llama config: {}", .{config});
|
||||
|
||||
const llama_options: llama_mod.LlamaOptions = .{
|
||||
const llama_config: LlamaLM.Config = .{
|
||||
.eos_token_id = config.eos_token_id,
|
||||
.bos_token_id = config.bos_token_id,
|
||||
.num_key_value_heads = @intCast(num_kv_heads),
|
||||
.num_hidden_layers = @intCast(config.num_hidden_layers),
|
||||
.num_attention_heads = @intCast(num_heads),
|
||||
.max_position_embeddings = config.max_position_embeddings,
|
||||
.rope_theta = config.rope_theta,
|
||||
.rms_norm_eps = @floatCast(buffer_store.metadata("rms_norm_eps", .float) orelse 1e-5),
|
||||
.hf_rope_impl = true,
|
||||
};
|
||||
|
||||
const llama_options: LlamaLM.Options = .{
|
||||
.max_seq_len = 256,
|
||||
.num_kv_heads = num_kv_heads,
|
||||
.num_heads = num_heads,
|
||||
.gen_opts = .{},
|
||||
.rms_norm_eps = @floatCast(buffer_store.metadata("rms_norm_eps", .float64) orelse 1e-5),
|
||||
.rope_opts = .{
|
||||
.impl = rope_impl,
|
||||
.freq_base = @floatCast(buffer_store.metadata("rope_freq_base", .float64) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
|
||||
.sampling_strategy = .{
|
||||
.topk = 1,
|
||||
.temperature = 1.0,
|
||||
},
|
||||
};
|
||||
|
||||
std.log.info("Parsed llama config: {}", .{llama_options});
|
||||
llama.init(llama_options);
|
||||
llama.init(llama_config, llama_options);
|
||||
|
||||
// Load the weights.
|
||||
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, buffer_store, model_arena, platform);
|
||||
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{ llama_config, llama_options }, buffer_store, model_arena, platform);
|
||||
defer zml.aio.unloadBuffers(&llama_weights);
|
||||
|
||||
// Load the activations.
|
||||
|
||||
@ -60,12 +60,18 @@ oci_image(
|
||||
name = "image_",
|
||||
base = "@distroless_cc_debian12",
|
||||
entrypoint = ":entrypoint",
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
tars = [":archive"],
|
||||
)
|
||||
|
||||
platform_transition_filegroup(
|
||||
name = "image",
|
||||
srcs = [":image_"],
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
target_platform = "@zml//platforms:linux_amd64",
|
||||
)
|
||||
|
||||
@ -75,6 +81,9 @@ oci_load(
|
||||
repo_tags = [
|
||||
"distroless/mnist:latest",
|
||||
],
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
)
|
||||
|
||||
oci_push(
|
||||
@ -82,6 +91,9 @@ oci_push(
|
||||
image = ":image",
|
||||
remote_tags = ["latest"],
|
||||
repository = "index.docker.io/steeve/mnist",
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
)
|
||||
|
||||
oci_load(
|
||||
@ -90,4 +102,7 @@ oci_load(
|
||||
repo_tags = [
|
||||
"distroless/mnist:latest",
|
||||
],
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
)
|
||||
|
||||
@ -35,6 +35,9 @@ oci_image(
|
||||
name = "image_",
|
||||
base = "@distroless_cc_debian12",
|
||||
entrypoint = ["./{}/simple_layer".format(package_name())],
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
tars = [":archive"],
|
||||
)
|
||||
|
||||
@ -42,6 +45,9 @@ oci_image(
|
||||
platform_transition_filegroup(
|
||||
name = "image",
|
||||
srcs = [":image_"],
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
target_platform = "@zml//platforms:linux_amd64",
|
||||
)
|
||||
|
||||
@ -52,6 +58,9 @@ oci_load(
|
||||
repo_tags = [
|
||||
"distroless/simple_layer:latest",
|
||||
],
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
)
|
||||
|
||||
# Bazel target for pushing the Linux image to the docker registry
|
||||
@ -61,4 +70,7 @@ oci_push(
|
||||
remote_tags = ["latest"],
|
||||
# override with -- --repository foo.bar/org/image
|
||||
repository = "index.docker.io/renerocksai/simple_layer",
|
||||
target_compatible_with = [
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user