Add Bazel build rules and a test for the benchmark, llama, mnist, and simple_layer examples.

This commit is contained in:
Foke Singh 2024-05-23 15:52:34 +00:00
parent 3aac788544
commit 27aabf9beb
6 changed files with 99 additions and 30 deletions

View File

@@ -5,7 +5,7 @@ zig_cc_binary(
main = "main.zig", main = "main.zig",
deps = [ deps = [
"@zml//async", "@zml//async",
"@zml//stdx",
"@zml//zml", "@zml//zml",
"//third_party/tigerbeetle:flags",
], ],
) )

View File

@@ -1,7 +1,8 @@
const std = @import("std"); const std = @import("std");
const zml = @import("zml"); const zml = @import("zml");
const stdx = @import("stdx");
const asynk = @import("async"); const asynk = @import("async");
const flags = @import("tigerbeetle/flags"); const flags = stdx.flags;
// set log level to debug to print the generated IR // set log level to debug to print the generated IR
pub const std_options = .{ pub const std_options = .{

View File

@@ -6,7 +6,6 @@ load("@bazel_skylib//rules:write_file.bzl", "write_file")
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push") load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
load("@zml//bazel:zig.bzl", "zig_cc_binary") load("@zml//bazel:zig.bzl", "zig_cc_binary")
zig_cc_binary( zig_cc_binary(
name = "llama", name = "llama",
srcs = [ srcs = [
@@ -34,6 +33,9 @@ cc_binary(
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.1-8B-Instruct//:tokenizer.json", "@Meta-Llama-3.1-8B-Instruct//:tokenizer.json",
], ],
tags = [
"no_ci",
],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
@@ -50,6 +52,9 @@ cc_binary(
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.1-70B-Instruct//:tokenizer.json", "@Meta-Llama-3.1-70B-Instruct//:tokenizer.json",
], ],
tags = [
"no_ci",
],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
@@ -66,6 +71,9 @@ cc_binary(
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors", "@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json", "@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
], ],
tags = [
"no_ci",
],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
@@ -82,6 +90,9 @@ cc_binary(
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.2-3B-Instruct//:tokenizer.json", "@Meta-Llama-3.2-3B-Instruct//:tokenizer.json",
], ],
tags = [
"no_ci",
],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
@@ -109,7 +120,6 @@ write_file(
content = ['{"bos_token_id":1,"eos_token_id":2,"hidden_act":"silu","hidden_size":288,"intermediate_size":768,"max_position_embeddings":256,"model_type":"llama","num_attention_heads":6,"num_hidden_layers":6,"num_key_value_heads":6,"rms_norm_eps":1e-05,"hf_rope_impl":false,"rope_scaling":null,"rope_theta":10000.0}'], content = ['{"bos_token_id":1,"eos_token_id":2,"hidden_act":"silu","hidden_size":288,"intermediate_size":768,"max_position_embeddings":256,"model_type":"llama","num_attention_heads":6,"num_hidden_layers":6,"num_key_value_heads":6,"rms_norm_eps":1e-05,"hf_rope_impl":false,"rope_scaling":null,"rope_theta":10000.0}'],
) )
zig_cc_binary( zig_cc_binary(
name = "test-implementation", name = "test-implementation",
srcs = ["llama.zig"], srcs = ["llama.zig"],
@@ -118,10 +128,13 @@ zig_cc_binary(
"--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)", "--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
], ],
data = [ data = [
"@Meta-Llama-3.1-8B-Instruct//:model", "@Meta-Llama-3.1-8B-Instruct//:config.json",
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
], ],
main = "test.zig", main = "test.zig",
tags = [
"no_ci",
],
deps = [ deps = [
"@zml//async", "@zml//async",
"@zml//stdx", "@zml//stdx",
@@ -148,11 +161,17 @@ Artificial Intelligence in Healthcare
"--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198", "--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
], ],
data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer.json"], data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer.json"],
tags = [
"no_ci",
],
) )
mtree_spec( mtree_spec(
name = "mtree", name = "mtree",
srcs = [":Llama-3.2-1B-Instruct"], srcs = [":Llama-3.2-1B-Instruct"],
tags = [
"no_ci",
],
) )
tar( tar(
@@ -164,6 +183,9 @@ tar(
], ],
compress = "zstd", compress = "zstd",
mtree = ":mtree", mtree = ":mtree",
tags = [
"no_ci",
],
) )
expand_template( expand_template(
@@ -180,6 +202,9 @@ expand_template(
":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)", ":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)", ":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
}, },
tags = [
"no_ci",
],
template = [ template = [
"./{}/Llama-3.2-1B-Instruct".format(package_name()), "./{}/Llama-3.2-1B-Instruct".format(package_name()),
"--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()), "--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()),
@@ -193,6 +218,9 @@ oci_image(
base = "@distroless_cc_debian12_debug", base = "@distroless_cc_debian12_debug",
# entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())], # entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())],
entrypoint = ":entrypoint", entrypoint = ":entrypoint",
tags = [
"no_ci",
],
tars = [ tars = [
"@zml//runtimes:layers", "@zml//runtimes:layers",
":archive", ":archive",
@@ -202,6 +230,9 @@ oci_image(
platform_transition_filegroup( platform_transition_filegroup(
name = "image", name = "image",
srcs = [":image_"], srcs = [":image_"],
tags = [
"no_ci",
],
target_platform = "@zml//platforms:linux_amd64", target_platform = "@zml//platforms:linux_amd64",
) )
@@ -211,6 +242,9 @@ oci_load(
repo_tags = [ repo_tags = [
"distroless/llama-3.2-1b-instruct:latest", "distroless/llama-3.2-1b-instruct:latest",
], ],
tags = [
"no_ci",
],
) )
oci_push( oci_push(
@@ -218,4 +252,7 @@ oci_push(
image = ":image", image = ":image",
remote_tags = ["latest"], remote_tags = ["latest"],
repository = "index.docker.io/steeve/llama-3.2-1b-instruct", repository = "index.docker.io/steeve/llama-3.2-1b-instruct",
tags = [
"no_ci",
],
) )

View File

@@ -1,8 +1,8 @@
const asynk = @import("async"); const asynk = @import("async");
const flags = @import("tigerbeetle/flags");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
const zml = @import("zml"); const zml = @import("zml");
const flags = stdx.flags;
const llama_mod = @import("./llama.zig"); const llama_mod = @import("./llama.zig");
const LlamaLM = llama_mod.LlamaLM; const LlamaLM = llama_mod.LlamaLM;
@@ -10,7 +10,7 @@ const LlamaLM = llama_mod.LlamaLM;
const Tensor = zml.Tensor; const Tensor = zml.Tensor;
pub fn main() !void { pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain, .{}); try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
} }
pub fn asyncMain() !void { pub fn asyncMain() !void {
@@ -54,13 +54,8 @@ pub fn asyncMain() !void {
// Create the model and configure it. // Create the model and configure it.
var llama = try zml.aio.populateModel(LlamaLM, model_arena, buffer_store); var llama = try zml.aio.populateModel(LlamaLM, model_arena, buffer_store);
const num_heads: i64 = cli_args.num_heads orelse buffer_store.metadata("num_heads", .int64) orelse @panic("--num_heads is required for this model"); const num_heads: i64 = cli_args.num_heads orelse buffer_store.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
const num_kv_heads: i64 = cli_args.num_kv_heads orelse buffer_store.metadata("num_kv_heads", .int64) orelse num_heads; const num_kv_heads: i64 = cli_args.num_kv_heads orelse buffer_store.metadata("num_kv_heads", .int) orelse num_heads;
const rope_impl = if (buffer_store.metadata("rope_impl", .string)) |val|
std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
else
.sequential;
const config = blk: { const config = blk: {
var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only }); var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only });
@@ -72,22 +67,31 @@ pub fn asyncMain() !void {
}; };
std.log.info("Parsed llama config: {}", .{config}); std.log.info("Parsed llama config: {}", .{config});
const llama_options: llama_mod.LlamaOptions = .{ const llama_config: LlamaLM.Config = .{
.eos_token_id = config.eos_token_id,
.bos_token_id = config.bos_token_id,
.num_key_value_heads = @intCast(num_kv_heads),
.num_hidden_layers = @intCast(config.num_hidden_layers),
.num_attention_heads = @intCast(num_heads),
.max_position_embeddings = config.max_position_embeddings,
.rope_theta = config.rope_theta,
.rms_norm_eps = @floatCast(buffer_store.metadata("rms_norm_eps", .float) orelse 1e-5),
.hf_rope_impl = true,
};
const llama_options: LlamaLM.Options = .{
.max_seq_len = 256, .max_seq_len = 256,
.num_kv_heads = num_kv_heads, .sampling_strategy = .{
.num_heads = num_heads, .topk = 1,
.gen_opts = .{}, .temperature = 1.0,
.rms_norm_eps = @floatCast(buffer_store.metadata("rms_norm_eps", .float64) orelse 1e-5),
.rope_opts = .{
.impl = rope_impl,
.freq_base = @floatCast(buffer_store.metadata("rope_freq_base", .float64) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
}, },
}; };
std.log.info("Parsed llama config: {}", .{llama_options}); std.log.info("Parsed llama config: {}", .{llama_options});
llama.init(llama_options); llama.init(llama_config, llama_options);
// Load the weights. // Load the weights.
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, buffer_store, model_arena, platform); var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{ llama_config, llama_options }, buffer_store, model_arena, platform);
defer zml.aio.unloadBuffers(&llama_weights); defer zml.aio.unloadBuffers(&llama_weights);
// Load the activations. // Load the activations.

View File

@@ -60,12 +60,18 @@ oci_image(
name = "image_", name = "image_",
base = "@distroless_cc_debian12", base = "@distroless_cc_debian12",
entrypoint = ":entrypoint", entrypoint = ":entrypoint",
target_compatible_with = [
"@platforms//os:linux",
],
tars = [":archive"], tars = [":archive"],
) )
platform_transition_filegroup( platform_transition_filegroup(
name = "image", name = "image",
srcs = [":image_"], srcs = [":image_"],
target_compatible_with = [
"@platforms//os:linux",
],
target_platform = "@zml//platforms:linux_amd64", target_platform = "@zml//platforms:linux_amd64",
) )
@@ -75,6 +81,9 @@ oci_load(
repo_tags = [ repo_tags = [
"distroless/mnist:latest", "distroless/mnist:latest",
], ],
target_compatible_with = [
"@platforms//os:linux",
],
) )
oci_push( oci_push(
@@ -82,6 +91,9 @@ oci_push(
image = ":image", image = ":image",
remote_tags = ["latest"], remote_tags = ["latest"],
repository = "index.docker.io/steeve/mnist", repository = "index.docker.io/steeve/mnist",
target_compatible_with = [
"@platforms//os:linux",
],
) )
oci_load( oci_load(
@@ -90,4 +102,7 @@ oci_load(
repo_tags = [ repo_tags = [
"distroless/mnist:latest", "distroless/mnist:latest",
], ],
target_compatible_with = [
"@platforms//os:linux",
],
) )

View File

@@ -35,6 +35,9 @@ oci_image(
name = "image_", name = "image_",
base = "@distroless_cc_debian12", base = "@distroless_cc_debian12",
entrypoint = ["./{}/simple_layer".format(package_name())], entrypoint = ["./{}/simple_layer".format(package_name())],
target_compatible_with = [
"@platforms//os:linux",
],
tars = [":archive"], tars = [":archive"],
) )
@@ -42,6 +45,9 @@ oci_image(
platform_transition_filegroup( platform_transition_filegroup(
name = "image", name = "image",
srcs = [":image_"], srcs = [":image_"],
target_compatible_with = [
"@platforms//os:linux",
],
target_platform = "@zml//platforms:linux_amd64", target_platform = "@zml//platforms:linux_amd64",
) )
@@ -52,6 +58,9 @@ oci_load(
repo_tags = [ repo_tags = [
"distroless/simple_layer:latest", "distroless/simple_layer:latest",
], ],
target_compatible_with = [
"@platforms//os:linux",
],
) )
# Bazel target for pushing the Linux image to the docker registry # Bazel target for pushing the Linux image to the docker registry
@@ -61,4 +70,7 @@ oci_push(
remote_tags = ["latest"], remote_tags = ["latest"],
# override with -- --repository foo.bar/org/image # override with -- --repository foo.bar/org/image
repository = "index.docker.io/renerocksai/simple_layer", repository = "index.docker.io/renerocksai/simple_layer",
target_compatible_with = [
"@platforms//os:linux",
],
) )