# Radix/examples/llama/BUILD.bazel
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
# Core Llama example binary, built from Zig sources via ZML's zig_cc_binary macro.
# NOTE(review): the ":llama_lib" target consumed by the per-model cc_binary
# wrappers below is presumably generated by this macro — confirm in
# @zml//bazel:zig.bzl.
zig_cc_binary(
    name = "llama",
    srcs = ["llama.zig"],
    main = "main.zig",
    deps = [
        "//third_party/tigerbeetle:flags",
        "@zml//async",
        "@zml//stdx",
        "@zml//zml",
    ],
)
# Runs the llama binary against the Meta-Llama-3.1-8B-Instruct weights
# (sharded safetensors resolved through the index JSON) and tokenizer.
cc_binary(
    name = "Llama-3.1-8B-Instruct",
    args = [
        "--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
        "--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)",
        "--num-heads=32",
        "--num-kv-heads=8",
        "--rope-freq-base=500000",
    ],
    data = [
        "@Meta-Llama-3.1-8B-Instruct//:model",
        "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
        "@Meta-Llama-3.1-8B-Instruct//:tokenizer",
    ],
    deps = [":llama_lib"],
)
# Same runner, configured for the 70B variant (64 attention heads, 8 KV heads).
cc_binary(
    name = "Llama-3.1-70B-Instruct",
    args = [
        "--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
        "--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)",
        "--num-heads=64",
        "--num-kv-heads=8",
        "--rope-freq-base=500000",
    ],
    data = [
        "@Meta-Llama-3.1-70B-Instruct//:model",
        "@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
        "@Meta-Llama-3.1-70B-Instruct//:tokenizer",
    ],
    deps = [":llama_lib"],
)
# OpenLLaMA-3B runner. Note: full multi-head attention here
# (num-kv-heads == num-heads == 32) and the classic 10000 RoPE base.
cc_binary(
    name = "OpenLLaMA-3B",
    args = [
        "--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)",
        "--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)",
        "--num-heads=32",
        "--num-kv-heads=32",
        "--rope-freq-base=10000",
    ],
    data = [
        "@OpenLM-Research-OpenLLaMA-3B//:model",
        "@OpenLM-Research-OpenLLaMA-3B//:tokenizer",
    ],
    deps = [":llama_lib"],
)
# TinyLlama 1.1B chat model runner (single safetensors file, no index).
cc_binary(
    name = "TinyLlama-1.1B-Chat",
    args = [
        "--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)",
        "--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)",
        "--num-heads=32",
        "--num-kv-heads=4",
        "--rope-freq-base=10000",
    ],
    data = [
        "@TinyLlama-1.1B-Chat-v1.0//:model.safetensors",
        "@TinyLlama-1.1B-Chat-v1.0//:tokenizer",
    ],
    deps = [":llama_lib"],
)
# Karpathy "stories" 110M checkpoint. No head/RoPE flags are passed here;
# presumably the binary derives them from the checkpoint or uses defaults —
# TODO(review): confirm against the flag definitions in the Zig sources.
cc_binary(
    name = "TinyLlama-Stories-110M",
    args = [
        "--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)",
        "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
    ],
    data = [
        "@Karpathy-TinyLlama-Stories//:stories110M",
        "@Karpathy-TinyLlama-Tokenizer//file",
    ],
    deps = [":llama_lib"],
)
# Karpathy "stories" 15M checkpoint — smallest model, useful as a smoke test.
cc_binary(
    name = "TinyLlama-Stories-15M",
    args = [
        "--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)",
        "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
    ],
    data = [
        "@Karpathy-TinyLlama-Stories//:stories15M",
        "@Karpathy-TinyLlama-Tokenizer//file",
    ],
    deps = [":llama_lib"],
)
# Implementation test harness: compiles the same llama.zig as the ":llama"
# target with a test entry point, exercised against the 8B reference weights.
# Fix: added "@zml//stdx", which ":llama" declares for the identical
# llama.zig source but was missing here — without it this target cannot
# resolve the same imports the main binary needs.
# NOTE(review): "@zml//metax" is kept because test.zig may use it, but no
# sibling target depends on it — confirm it is not a typo for "@zml//stdx".
# NOTE(review): no --tokenizer flag or tokenizer data here, unlike the
# model runners above — presumably the test path skips tokenization; verify.
zig_cc_binary(
    name = "test-implementation",
    srcs = ["llama.zig"],
    args = [
        "--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
        "--num-heads=32",
        "--num-kv-heads=8",
        "--rope-freq-base=500000",
    ],
    data = [
        "@Meta-Llama-3.1-8B-Instruct//:model",
        "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
    ],
    main = "test.zig",
    deps = [
        "//third_party/tigerbeetle:flags",
        "@zml//async",
        "@zml//metax",
        "@zml//stdx",
        "@zml//zml",
    ],
)
# mtree manifest of the ":llama" binary (input to the tar rule below).
mtree_spec(
    name = "mtree",
    srcs = [":llama"],
)
# zstd-compressed tar layer of the llama binary, laid out per the mtree
# manifest; compression level pinned to 9 via libarchive options.
tar(
    name = "archive",
    srcs = [":llama"],
    args = [
        "--options",
        "zstd:compression-level=9",
    ],
    compress = "zstd",
    mtree = ":mtree",
)
# OCI image: distroless cc base + ZML runtime layers + the llama archive.
# The entrypoint path mirrors this package's location inside the tar layer.
oci_image(
    name = "image_",
    base = "@distroless_cc_debian12_debug",
    entrypoint = ["./{}/llama".format(package_name())],
    tars = [
        "@zml//runtimes:layers",
        ":archive",
    ],
)
# Force the image build through a linux/amd64 platform transition so the
# published artifact is independent of the host platform.
platform_transition_filegroup(
    name = "image",
    srcs = [":image_"],
    target_platform = "@zml//platforms:linux_amd64",
)
# Loads the image into a local container runtime as distroless/llama:latest.
oci_load(
    name = "load",
    image = ":image",
    repo_tags = ["distroless/llama:latest"],
)
# Pushes the image to Docker Hub (steeve/llama) tagged "latest".
oci_push(
    name = "push",
    image = ":image",
    remote_tags = ["latest"],
    repository = "index.docker.io/steeve/llama",
)