load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
load("@zml//bazel:zig.bzl", "zig_cc_binary")

# The llama inference binary. Model-specific settings (weights, tokenizer,
# attention-head counts, RoPE base) are supplied at run time via the
# per-checkpoint native_binary wrappers below.
zig_cc_binary(
    name = "llama",
    srcs = [
        "llama.zig",
    ],
    main = "main.zig",
    deps = [
        "//third_party/tigerbeetle:flags",
        "@zml//async",
        "@zml//stdx",
        "@zml//zml",
    ],
)

# Runs :llama against the Meta Llama 3.1 8B Instruct checkpoint.
# The sharded-weights index and tokenizer paths are resolved via $(location);
# the model repository itself is kept in data so all shards are available.
native_binary(
    name = "Llama-3.1-8B-Instruct",
    src = ":llama",
    args = [
        "--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
        "--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)",
        "--num-heads=32",
        "--num-kv-heads=8",
        "--rope-freq-base=500000",
    ],
    data = [
        "@Meta-Llama-3.1-8B-Instruct//:model",
        "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
        "@Meta-Llama-3.1-8B-Instruct//:tokenizer",
    ],
)

# Runs :llama against the Meta Llama 3.1 70B Instruct checkpoint
# (64 attention heads vs. 32 for the 8B variant; same 8 KV heads and RoPE base).
native_binary(
    name = "Llama-3.1-70B-Instruct",
    src = ":llama",
    args = [
        "--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
        "--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)",
        "--num-heads=64",
        "--num-kv-heads=8",
        "--rope-freq-base=500000",
    ],
    data = [
        "@Meta-Llama-3.1-70B-Instruct//:model",
        "@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
        "@Meta-Llama-3.1-70B-Instruct//:tokenizer",
    ],
)

# Runs :llama against the OpenLM Research OpenLLaMA 3B checkpoint.
# Uses multi-head attention (num-kv-heads == num-heads) and the
# original LLaMA RoPE base of 10000.
native_binary(
    name = "OpenLLaMA-3B",
    src = ":llama",
    args = [
        "--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)",
        "--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)",
        "--num-heads=32",
        "--num-kv-heads=32",
        "--rope-freq-base=10000",
    ],
    data = [
        "@OpenLM-Research-OpenLLaMA-3B//:model",
        "@OpenLM-Research-OpenLLaMA-3B//:tokenizer",
    ],
)

# Runs :llama against the TinyLlama 1.1B Chat v1.0 checkpoint
# (single-file safetensors weights; grouped-query attention with 4 KV heads).
native_binary(
    name = "TinyLlama-1.1B-Chat",
    src = ":llama",
    args = [
        "--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)",
        "--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)",
        "--num-heads=32",
        "--num-kv-heads=4",
        "--rope-freq-base=10000",
    ],
    data = [
        "@TinyLlama-1.1B-Chat-v1.0//:model.safetensors",
        "@TinyLlama-1.1B-Chat-v1.0//:tokenizer",
    ],
)

# Runs :llama against Karpathy's 110M-parameter "stories" checkpoint.
# No head-count/RoPE flags are passed, so the binary's defaults apply here.
native_binary(
    name = "TinyLlama-Stories-110M",
    src = ":llama",
    args = [
        "--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)",
        "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
    ],
    data = [
        "@Karpathy-TinyLlama-Stories//:stories110M",
        "@Karpathy-TinyLlama-Tokenizer//file",
    ],
)

# Runs :llama against Karpathy's 15M-parameter "stories" checkpoint
# (smallest model here; handy for quick smoke runs). Shares the tokenizer
# with the 110M variant above.
native_binary(
    name = "TinyLlama-Stories-15M",
    src = ":llama",
    args = [
        "--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)",
        "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
    ],
    data = [
        "@Karpathy-TinyLlama-Stories//:stories15M",
        "@Karpathy-TinyLlama-Tokenizer//file",
    ],
)

# Test driver for the llama implementation: same llama.zig sources as :llama
# but with test.zig as the entry point, pre-wired to the Llama 3.1 8B weights.
# NOTE(review): deps here include "@zml//metax", which :llama does not use,
# and omit "@zml//stdx", which :llama does — confirm this asymmetry is intended.
zig_cc_binary(
    name = "test-implementation",
    srcs = ["llama.zig"],
    args = [
        "--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
        "--num-heads=32",
        "--num-kv-heads=8",
        "--rope-freq-base=500000",
    ],
    data = [
        "@Meta-Llama-3.1-8B-Instruct//:model",
        "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
    ],
    main = "test.zig",
    deps = [
        "//third_party/tigerbeetle:flags",
        "@zml//async",
        "@zml//metax",
        "@zml//zml",
    ],
)

# mtree manifest describing the :llama binary and its runfiles tree,
# consumed by :archive below to lay out the container filesystem.
mtree_spec(
    name = "mtree",
    srcs = [":llama"],
)

# zstd-compressed tar layer containing :llama, laid out per :mtree.
# The extra bsdtar options raise the zstd compression level to 9.
tar(
    name = "archive",
    srcs = [":llama"],
    args = [
        "--options",
        "zstd:compression-level=9",
    ],
    compress = "zstd",
    mtree = ":mtree",
)

# OCI image with the llama binary layered onto a distroless cc base plus the
# zml runtime layers. The entrypoint path mirrors where the tar layer places
# the binary (under this package's path inside the image).
oci_image(
    name = "image_",
    base = "@distroless_cc_debian12_debug",
    entrypoint = ["./{}/llama".format(package_name())],
    tars = [
        "@zml//runtimes:layers",
        ":archive",
    ],
)

# Transitions :image_ to linux/amd64 so the image is built for the container
# platform regardless of the host platform.
platform_transition_filegroup(
    name = "image",
    srcs = [":image_"],
    target_platform = "@zml//platforms:linux_amd64",
)

# Loads :image into the local container runtime as distroless/llama:latest
# (run with `bazel run //...:load`).
oci_load(
    name = "load",
    image = ":image",
    repo_tags = [
        "distroless/llama:latest",
    ],
)

# Pushes :image to Docker Hub as index.docker.io/steeve/llama:latest
# (run with `bazel run //...:push`).
oci_push(
    name = "push",
    image = ":image",
    remote_tags = ["latest"],
    repository = "index.docker.io/steeve/llama",
)