Update Llama example docs and Bazel build files, and add tests for the new HuggingFace tokenizer integration.

Foke Singh 2024-03-04 12:11:13 +00:00
parent 959bc48c42
commit 76e314db9b
10 changed files with 1847 additions and 585 deletions


@@ -59,47 +59,41 @@ Llama is a family of "Large Language Models", trained to generate text, based
 on the beginning of a sentence/book/article. This "beginning" is generally
 referred to as the "prompt".
-#### TinyLlama, Stories 15M
-To start, you can use a small model trained specifically on children's history
-books. This model has been trained by [Andrej Karpathy](https://x.com/karpathy);
-you can read more about it on his
-[Github](https://github.com/karpathy/llama2.c).
-```
-cd examples
-bazel run -c opt //llama:TinyLlama-Stories-15M
-bazel run -c opt //llama:TinyLlama-Stories-15M -- --prompt="Once upon a time, there was a cute little dragon"
-```
-#### OpenLLama 3B
-```
-cd examples
-bazel run -c opt //llama:OpenLLaMA-3B
-bazel run -c opt //llama:OpenLLaMA-3B -- --prompt="Once upon a time,"
-```
-#### Meta Llama 3 8B
+#### Meta Llama 3.1 8B
 This model has restrictions, see
-[here](https://huggingface.co/meta-llama/Meta-Llama-3-8B): it **requires
+[here](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct). It **requires
 approval from Meta on Huggingface**, which can take a few hours to get granted.
 While waiting for approval, you can already
 [generate your Huggingface access token](../howtos/huggingface_access_token.md).
 Once you've been granted access, you're ready to download a gated model like
-`Meta-Llama-3-8b`!
+`Meta-Llama-3.1-8B-Instruct`!
 ```
 # requires token in $HOME/.cache/huggingface/token, as created by the
 # `huggingface-cli login` command, or the `HUGGINGFACE_TOKEN` environment variable.
 cd examples
-bazel run -c opt //llama:Meta-Llama-3-8b
-bazel run -c opt //llama:Meta-Llama-3-8b -- --promt="Once upon a time,"
+bazel run -c opt //llama:Llama-3.1-8B-Instruct
+bazel run -c opt //llama:Llama-3.1-8B-Instruct -- --prompt="What is the capital of France?"
 ```
+You can also try `Llama-3.1-70B-Instruct` if you have enough memory.
+### Meta Llama 3.2 1B
+Like the 8B model above, this model also requires approval. See
+[here](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) for access requirements.
+```
+cd examples
+bazel run -c opt //llama:Llama-3.2-1B-Instruct
+bazel run -c opt //llama:Llama-3.2-1B-Instruct -- --prompt="What is the capital of France?"
+```
+For a larger 3.2 model, you can also try `Llama-3.2-3B-Instruct`.
 ## Run Tests
@@ -126,9 +120,9 @@ run the following:
 ```
 cd examples
-bazel run -c opt //llama:OpenLLaMA-3B \
+bazel run -c opt //llama:Llama-3.2-1B-Instruct \
   --@zml//runtimes:cuda=true \
-  -- --prompt="Once upon a time,"
+  -- --prompt="What is the capital of France?"
 ```
examples/BUILD.bazel Normal file
@@ -7,6 +7,9 @@ bazel_dep(name = "zml", version = "0.1.0")
 bazel_dep(name = "aspect_bazel_lib", version = "2.11.0")
 bazel_dep(name = "rules_oci", version = "2.0.0")
+non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
+use_repo(non_module_deps, "com_github_hejsil_clap")
 oci = use_extension("@rules_oci//oci:extensions.bzl", "oci")
 oci.pull(
     name = "distroless_cc_debian12",
@@ -44,67 +47,24 @@ http_file(
     url = "https://github.com/ggerganov/ggml/raw/18703ad600cc68dbdb04d57434c876989a841d12/examples/mnist/models/mnist/t10k-images.idx3-ubyte",
 )
-# Llama weights
-huggingface = use_extension("@zml//bazel:huggingface.bzl", "huggingface")
-huggingface.model(
-    name = "Karpathy-TinyLlama-Stories",
-    build_file_content = """\
-load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
-# leverage copy_file to rename tokenizer extension
-# which allow zml.aio.detectFormatAndLoadTokenizer
-# to leverage the right tokenizer
-copy_file(
-    name = "stories15M",
-    src = "stories15M.bin",
-    out = "stories15M.tinyllama",
-    allow_symlink = True,
-    visibility = ["//visibility:public"],
-)
-copy_file(
-    name = "stories110M",
-    src = "stories110M.bin",
-    out = "stories110M.tinyllama",
-    allow_symlink = True,
-    visibility = ["//visibility:public"],
-)
-""",
-    commit = "0bd21da7698eaf29a0d7de3992de8a46ef624add",
-    includes = [
-        "stories15M.bin",
-        "stories110M.bin",
-    ],
-    model = "karpathy/tinyllamas",
-)
-use_repo(huggingface, "Karpathy-TinyLlama-Stories")
-http_file(
-    name = "Karpathy-TinyLlama-Tokenizer",
-    downloaded_file_path = "stories260K.tinyllama",
-    sha256 = "50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361",
-    url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
-)
 # Llama 3.2
+huggingface = use_extension("@zml//bazel:huggingface.bzl", "huggingface")
 huggingface.model(
     name = "Meta-Llama-3.2-1B-Instruct",
     build_file_content = """\
 package(default_visibility = ["//visibility:public"])
 filegroup(
-    name = "model",
-    srcs = ["model.safetensors"],
-)
-filegroup(
-    name = "tokenizer",
-    srcs = ["tokenizer.json"],
+    name = "Meta-Llama-3.2-1B-Instruct",
+    srcs = glob(["*.json", "*.safetensors"]),
 )
 """,
     commit = "9213176726f574b556790deb65791e0c5aa438b6",
     includes = [
-        "model.safetensors",
-        "tokenizer.json",
+        "*.safetensors",
+        "*.json",
     ],
     model = "meta-llama/Llama-3.2-1B-Instruct",
 )
@@ -115,129 +75,87 @@ huggingface.model(
     build_file_content = """\
 package(default_visibility = ["//visibility:public"])
 filegroup(
-    name = "model",
-    srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
-)
-filegroup(
-    name = "tokenizer",
-    srcs = ["tokenizer.json"],
+    name = "Meta-Llama-3.2-3B-Instruct",
+    srcs = glob(["*.json", "*.safetensors"]),
 )
 """,
     commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
     includes = [
         "*.safetensors",
-        "model.safetensors.index.json",
-        "tokenizer.json",
+        "*.json",
     ],
     model = "meta-llama/Llama-3.2-3B-Instruct",
 )
 use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
 # Llama 3.1
 huggingface.model(
     name = "Meta-Llama-3.1-8B-Instruct",
     build_file_content = """\
 package(default_visibility = ["//visibility:public"])
 filegroup(
-    name = "model",
-    srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
-)
-filegroup(
-    name = "tokenizer",
-    srcs = ["tokenizer.json"],
+    name = "Meta-Llama-3.1-8B-Instruct",
+    srcs = glob(["*.json", "*.safetensors"]),
 )
 """,
     commit = "5206a32e0bd3067aef1ce90f5528ade7d866253f",
     includes = [
         "*.safetensors",
-        "model.safetensors.index.json",
-        "tokenizer.json",
+        "*.json",
     ],
     model = "meta-llama/Meta-Llama-3.1-8B-Instruct",
 )
 use_repo(huggingface, "Meta-Llama-3.1-8B-Instruct")
 huggingface.model(
     name = "Meta-Llama-3.1-70B-Instruct",
     build_file_content = """\
 package(default_visibility = ["//visibility:public"])
 filegroup(
-    name = "model",
-    srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
-)
-filegroup(
-    name = "tokenizer",
-    srcs = ["tokenizer.json"],
+    name = "Meta-Llama-3.1-70B-Instruct",
+    srcs = glob(["*.json", "*.safetensors"]),
 )
 """,
     commit = "945c8663693130f8be2ee66210e062158b2a9693",
     includes = [
         "*.safetensors",
-        "model.safetensors.index.json",
-        "tokenizer.json",
+        "*.json",
     ],
     model = "meta-llama/Meta-Llama-3.1-70B-Instruct",
 )
 use_repo(huggingface, "Meta-Llama-3.1-70B-Instruct")
 huggingface.model(
-    name = "TinyLlama-1.1B-Chat-v1.0",
+    name = "TinyLlama-120M-scratch",
     build_file_content = """\
 package(default_visibility = ["//visibility:public"])
 filegroup(
-    name = "model",
-    srcs = ["model.safetensors"],
-)
-filegroup(
-    name = "tokenizer",
-    srcs = ["tokenizer.model"],
+    name = "TinyLlama-120M-scratch",
+    srcs = glob(["*.json", "*.safetensors"]),
 )
 """,
-    commit = "fe8a4ea1ffedaf415f4da2f062534de366a451e6",
+    commit = "89c1bb4ea00861ddaa26c55f102ccb25e161feee",
     includes = [
-        "model.safetensors",
-        "tokenizer.model",
+        "*.safetensors",
+        "*.json",
     ],
-    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    model = "Hoyeon/TinyLlama-120M-scratch",
 )
-use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
+use_repo(huggingface, "TinyLlama-120M-scratch")
-#OpenLLaMa
-huggingface.model(
-    name = "OpenLM-Research-OpenLLaMA-3B",
-    build_file_content = """\
-load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
-package(default_visibility = ["//visibility:public"])
-filegroup(
-    name = "model",
-    srcs = ["model.safetensors"],
-)
-filegroup(
-    name = "tokenizer",
-    srcs = [":tokenizer_pb"],
-)
-# leverage copy_file to rename tokenizer extension
-# which allow zml.aio.detectFormatAndLoadTokenizer
-# to leverage the right tokenizer
-copy_file(
-    name = "tokenizer_pb",
-    src = "tokenizer.model",
-    out = "tokenizer.pb",
-    allow_symlink = True,
-)
-""",
-    commit = "fcc2e809eb8f14dabba84d76a0ddc17b8ea05356",
-    includes = [
-        "model.safetensors",
-        "tokenizer.model",
-    ],
-    model = "openlm-research/open_llama_3b",
-)
-use_repo(huggingface, "OpenLM-Research-OpenLLaMA-3B")
+bazel_dep(name = "rules_rust", version = "0.57.0")
+rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
+rust.toolchain(
+    edition = "2021",
+    versions = ["1.84.0"],
+    extra_target_triples = [
+        "aarch64-apple-darwin",
+        "aarch64-unknown-linux-gnu",
+        "x86_64-unknown-linux-gnu",
+    ],
+)
+use_repo(rust, "rust_toolchains")
+register_toolchains("@rust_toolchains//:all")
File diff suppressed because one or more lines are too long
@@ -1,3 +1,4 @@
+load("@aspect_bazel_lib//lib:expand_template.bzl", "expand_template")
 load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
 load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
 load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
@@ -12,7 +13,7 @@ zig_cc_binary(
     ],
     main = "main.zig",
     deps = [
-        "//third_party/tigerbeetle:flags",
+        "@com_github_hejsil_clap//:clap",
         "@zml//async",
         "@zml//stdx",
         "@zml//zml",
@@ -20,18 +21,35 @@ zig_cc_binary(
 )
 cc_binary(
-    name = "Llama-3.1-8B-Instruct",
+    name = "TinyLlama-120M-scratch",
     args = [
-        "--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
-        "--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)",
-        "--num-heads=32",
-        "--num-kv-heads=8",
-        "--rope-freq-base=500000",
+        "--config=$(location @TinyLlama-120M-scratch//:config.json)",
+        "--weights=$(location @TinyLlama-120M-scratch//:model.safetensors)",
+        "--tokenizer=$(location @TinyLlama-120M-scratch//:tokenizer.json)",
+        "--no-llama3=true",  # don't do llama3 template prompt encoding
+        "--sharding=false",  # don't shard this
     ],
     data = [
-        "@Meta-Llama-3.1-8B-Instruct//:model",
+        "@TinyLlama-120M-scratch",
+        "@TinyLlama-120M-scratch//:config.json",
+        "@TinyLlama-120M-scratch//:model.safetensors",
+        "@TinyLlama-120M-scratch//:tokenizer.json",
+    ],
+    deps = [":llama_lib"],
+)
+cc_binary(
+    name = "Llama-3.1-8B-Instruct",
+    args = [
+        "--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
+        "--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
+        "--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer.json)",
+    ],
+    data = [
+        "@Meta-Llama-3.1-8B-Instruct",
+        "@Meta-Llama-3.1-8B-Instruct//:config.json",
         "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
-        "@Meta-Llama-3.1-8B-Instruct//:tokenizer",
+        "@Meta-Llama-3.1-8B-Instruct//:tokenizer.json",
     ],
     deps = [":llama_lib"],
 )
@ -39,32 +57,32 @@ cc_binary(
cc_binary( cc_binary(
name = "Llama-3.1-70B-Instruct", name = "Llama-3.1-70B-Instruct",
args = [ args = [
"--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)", "--config=$(location @Meta-Llama-3.1-70B-Instruct//:config.json)",
"--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)", "--weights=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
"--num-heads=64", "--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer.json)",
"--num-kv-heads=8",
"--rope-freq-base=500000",
], ],
data = [ data = [
"@Meta-Llama-3.1-70B-Instruct//:model", "@Meta-Llama-3.1-70B-Instruct",
"@Meta-Llama-3.1-70B-Instruct//:config.json",
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.1-70B-Instruct//:tokenizer", "@Meta-Llama-3.1-70B-Instruct//:tokenizer.json",
], ],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
cc_binary( cc_binary(
name = "Llama-3.2-1B-Instruct", name = "Llama-3.2-1B-Instruct",
args = [ args = [
"--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)", "--config=$(location @Meta-Llama-3.2-1B-Instruct//:config.json)",
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)", "--weights=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
"--num-heads=32", "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
"--num-kv-heads=8",
"--rope-freq-base=500000",
], ],
data = [ data = [
"@Meta-Llama-3.2-1B-Instruct",
"@Meta-Llama-3.2-1B-Instruct//:config.json",
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors", "@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
"@Meta-Llama-3.2-1B-Instruct//:tokenizer", "@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
], ],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
@ -72,86 +90,26 @@ cc_binary(
cc_binary( cc_binary(
name = "Llama-3.2-3B-Instruct", name = "Llama-3.2-3B-Instruct",
args = [ args = [
"--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)", "--config=$(location @Meta-Llama-3.2-3B-Instruct//:config.json)",
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)", "--weights=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
"--num-heads=24", "--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer.json)",
"--num-kv-heads=8",
"--rope-freq-base=500000",
], ],
data = [ data = [
"@Meta-Llama-3.2-3B-Instruct//:model", "@Meta-Llama-3.2-3B-Instruct",
"@Meta-Llama-3.2-3B-Instruct//:config.json",
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.2-3B-Instruct//:tokenizer", "@Meta-Llama-3.2-3B-Instruct//:tokenizer.json",
],
deps = [":llama_lib"],
)
cc_binary(
name = "OpenLLaMA-3B",
args = [
"--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)",
"--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)",
"--num-heads=32",
"--num-kv-heads=32",
"--rope-freq-base=10000",
],
data = [
"@OpenLM-Research-OpenLLaMA-3B//:model",
"@OpenLM-Research-OpenLLaMA-3B//:tokenizer",
],
deps = [":llama_lib"],
)
cc_binary(
name = "TinyLlama-1.1B-Chat",
args = [
"--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)",
"--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)",
"--num-heads=32",
"--num-kv-heads=4",
"--rope-freq-base=10000",
],
data = [
"@TinyLlama-1.1B-Chat-v1.0//:model.safetensors",
"@TinyLlama-1.1B-Chat-v1.0//:tokenizer",
],
deps = [":llama_lib"],
)
cc_binary(
name = "TinyLlama-Stories-110M",
args = [
"--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)",
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
],
data = [
"@Karpathy-TinyLlama-Stories//:stories110M",
"@Karpathy-TinyLlama-Tokenizer//file",
],
deps = [":llama_lib"],
)
cc_binary(
name = "TinyLlama-Stories-15M",
args = [
"--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)",
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
],
data = [
"@Karpathy-TinyLlama-Stories//:stories15M",
"@Karpathy-TinyLlama-Tokenizer//file",
], ],
deps = [":llama_lib"], deps = [":llama_lib"],
) )
#
zig_cc_binary( zig_cc_binary(
name = "test-implementation", name = "test-implementation",
srcs = ["llama.zig"], srcs = ["llama.zig"],
args = [ args = [
"--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)", "--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
"--num-heads=32", "--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
"--num-kv-heads=8",
"--rope-freq-base=500000",
], ],
data = [ data = [
"@Meta-Llama-3.1-8B-Instruct//:model", "@Meta-Llama-3.1-8B-Instruct//:model",
@ -184,12 +142,12 @@ zig_cc_binary(
mtree_spec( mtree_spec(
name = "mtree", name = "mtree",
srcs = [":llama"], srcs = [":Llama-3.2-1B-Instruct"],
) )
tar( tar(
name = "archive", name = "archive",
srcs = [":llama"], srcs = [":Llama-3.2-1B-Instruct"],
args = [ args = [
"--options", "--options",
"zstd:compression-level=9", "zstd:compression-level=9",
@ -198,10 +156,33 @@ tar(
mtree = ":mtree", mtree = ":mtree",
) )
expand_template(
name = "entrypoint",
data = [
":Llama-3.2-1B-Instruct",
"@Meta-Llama-3.2-1B-Instruct",
"@Meta-Llama-3.2-1B-Instruct//:config.json",
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
],
substitutions = {
":config": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:config.json)",
":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
},
template = [
"./{}/Llama-3.2-1B-Instruct".format(package_name()),
"--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()),
"--weights=./{}/Llama-3.2-1B-Instruct.runfiles/:weights".format(package_name()),
"--tokenizer=./{}/Llama-3.2-1B-Instruct.runfiles/:tokenizer".format(package_name()),
],
)
 oci_image(
     name = "image_",
     base = "@distroless_cc_debian12_debug",
-    entrypoint = ["./{}/llama".format(package_name())],
+    # entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())],
+    entrypoint = ":entrypoint",
     tars = [
         "@zml//runtimes:layers",
         ":archive",
@@ -218,7 +199,7 @@ oci_load(
     name = "load",
     image = ":image",
     repo_tags = [
-        "distroless/llama:latest",
+        "distroless/llama-3.2-1b-instruct:latest",
     ],
 )
@@ -226,5 +207,5 @@ oci_push(
     name = "push",
     image = ":image",
     remote_tags = ["latest"],
-    repository = "index.docker.io/steeve/llama",
+    repository = "index.docker.io/steeve/llama-3.2-1b-instruct",
 )
@ -1,4 +1,3 @@
const flags = @import("tigerbeetle/flags");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
const zml = @import("zml"); const zml = @import("zml");
@@ -12,36 +11,51 @@ const gguf = zml.io.gguf;
 const expectClose = zml.testing.expectClose;
 const log = std.log.scoped(.llama);
-pub const LlamaOptions = struct {
-    gen_opts: zml.nn.SamplingStrategy,
-    max_seq_len: u32,
-    num_heads: i64,
-    num_kv_heads: i64,
-    rms_norm_eps: f32,
-    rope_opts: zml.nn.RopeOpts,
-};
 /// Llama architecture, using huggingface transformers naming.
 /// Dimensions of activations: {.b, .s, .d}
 pub const LlamaLM = struct {
-    lm_head: ?zml.nn.Linear = null,
+    pub const Config = struct {
+        bos_token_id: u32,
+        eos_token_id: stdx.json.Union(union(enum) {
+            int: u32,
+            ints: []u32,
+        }),
+        num_hidden_layers: usize,
+        num_attention_heads: usize,
+        num_key_value_heads: usize,
+        rope_theta: f32,
+        max_position_embeddings: usize,
+        rms_norm_eps: f32,
+    };
+    pub const Options = struct {
+        sampling_strategy: ?zml.nn.SamplingStrategy,
+        max_seq_len: usize,
+    };
+    lm_head: ?zml.nn.Linear,
     model: Llama,
     // Options controlling generation
     gen_opts: zml.nn.SamplingStrategy = .{},
+    config: Config,
-    pub fn init(self: *LlamaLM, options: LlamaOptions) void {
-        self.gen_opts = options.gen_opts;
-        self.model.max_seq_len = options.max_seq_len;
-        self.model.num_heads = options.num_heads;
-        self.model.num_kv_heads = options.num_kv_heads;
-        self.model.rope_opts = options.rope_opts;
+    pub fn init(self: *LlamaLM, config: Config, options: Options) void {
+        self.config = config;
+        self.gen_opts = options.sampling_strategy orelse .{};
+        self.model.max_seq_len = @intCast(options.max_seq_len);
+        self.model.num_heads = @intCast(config.num_attention_heads);
+        self.model.num_kv_heads = @intCast(config.num_key_value_heads);
+        self.model.rope_opts = .{
+            .impl = .sequential,
+            .freq_base = config.rope_theta,
+        };
         for (self.model.layers) |*layer| {
-            layer.self_attn.num_heads = options.num_heads;
-            layer.self_attn.num_kv_heads = options.num_kv_heads;
-            layer.self_attn.rope_opts = options.rope_opts;
-            layer.input_layernorm.eps = options.rms_norm_eps;
-            layer.post_attention_layernorm.eps = options.rms_norm_eps;
+            layer.self_attn.num_heads = self.model.num_heads;
+            layer.self_attn.num_kv_heads = self.model.num_kv_heads;
+            layer.self_attn.rope_opts = self.model.rope_opts;
+            layer.input_layernorm.eps = config.rms_norm_eps;
+            layer.post_attention_layernorm.eps = config.rms_norm_eps;
             layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
             layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
             layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
@@ -54,88 +68,58 @@ pub const LlamaLM = struct {
         // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
         // It currently crashes/compilation fails
-        if (options.gen_opts.topk == 1) {
-            if (self.lm_head) |lm_head| {
-                self.lm_head.?.weight = lm_head.weight.withSharding(.{0});
-            }
+        if (self.gen_opts.topk == 1 and self.lm_head != null) {
+            self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
         }
     }
     /// Predicts the token at `token_index` position.
     /// Returns:
     /// - updated `tokens`,
-    /// - `token_idx` + 1,
    /// - updated KV cache
     /// - a Rng state to allow for probabilistic generation
     pub fn forward(
         self: LlamaLM,
         tokens_: Tensor,
         token_index: Tensor,
-        kv_cache: ?KvCache,
+        kv_cache: KvCache,
         rng: Tensor.Rng,
-    ) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
-        stdx.debug.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
+    ) struct { Tensor, KvCache, Tensor.Rng } {
+        stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
         var tokens = tokens_.withPartialTags(.{.s});
-        const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
-        tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts);
-        return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
+        const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
+        tokens, const new_rng = self.sampleTokens(self.lm_head, tokens, out, rng, self.gen_opts);
+        return .{ tokens, updated_kv_cache, new_rng };
     }
-    pub fn updateTokens(
+    pub fn sampleTokens(
         self: LlamaLM,
+        lm_head_: ?zml.nn.Linear,
         tokens_: Tensor,
-        token_index: Tensor,
         out_: Tensor,
         rng: Tensor.Rng,
         opts: zml.nn.SamplingStrategy,
     ) struct { Tensor, Tensor.Rng } {
-        const tokens = tokens_.withPartialTags(.{.s});
         const out = out_.withPartialTags(.{ .s, .d });
-        const next_token_pred = out.gatherValues(.s, token_index, .{});
-        var logits = if (self.lm_head) |lm_head|
-            zml.call(lm_head, .forward, .{next_token_pred})
-        else
-            self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d});
+        var logits = blk: {
+            if (lm_head_) |lm_head| {
+                break :blk zml.call(lm_head, .forward, .{out});
+            } else {
+                break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
+            }
+        };
         if (logits.shape().hasTag(.voc) == null)
             logits = logits.rename(.{ .d = .voc });
-        const next_token, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
-        const next_token_index = token_index.addConstant(1);
-        const new_tokens = tokens.dynamicUpdateSlice(.{ .s = next_token_index }, next_token);
-        return .{ new_tokens.reuseBuffer(tokens_), new_rng };
+        const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
+        return .{ next_tokens.reuseBuffer(tokens_), new_rng };
     }
pub fn increment(_: u8, token_index: Tensor) Tensor { pub fn increment(_: u8, token_index: Tensor) Tensor {
return token_index.addConstant(1); return token_index.addConstant(1).reuseBuffer(token_index);
}
/// Run the generation entirely within pjrt.
pub fn generate(self: LlamaLM, tokens: Tensor, token_index: Tensor, rng: Tensor.Rng) Tensor {
// Generate the first token using the prompt and generate the KV-cache initial values.
const prefill = zml.call(self, .forward, .{ tokens, token_index, null, rng });
const Gen = struct {
/// Same as LlamaLM.forward but without optional in the signature
pub fn forward(lm: LlamaLM, t_ids: Tensor, t_idx: Tensor, kv_cache_: KvCache, inner_rng: Tensor.Rng) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
var kv_cache = kv_cache_;
kv_cache.k = kv_cache.k.withPartialTags(.{ .layer, .h, .k, .hd });
kv_cache.v = kv_cache.v.withPartialTags(.{ .layer, .h, .k, .hd });
return zml.call(lm, .forward, .{ t_ids._ctx, t_ids, t_idx, kv_cache, inner_rng });
}
// / Stops when we generated `max_seq_len` tokens.
pub fn shouldContinue(lm: LlamaLM, t_ids: Tensor, t_idx: Tensor, kv_cache: KvCache, inner_rng: Tensor.Rng) Tensor {
_ = kv_cache;
_ = inner_rng;
std.debug.assert(t_ids.dim(1) == lm.model.max_seq_len);
return t_idx.cmp(.LT, Tensor.scalar(t_ids._ctx, lm.model.max_seq_len, t_idx.dtype()));
}
};
// Generate remaining tokens using the KV-cache, return tokens.
return zml.ops.while_(Gen.shouldContinue, Gen.forward, self, prefill)[0];
} }
}; };
@ -177,33 +161,28 @@ pub const Llama = struct {
/// Forward one token, using KV cache for previous tokens. /// Forward one token, using KV cache for previous tokens.
/// Returns result and updated KV cache. /// Returns result and updated KV cache.
pub fn forward(self: Llama, tokens: Tensor, token_index: ?Tensor, kv_cache: ?KvCache) struct { Tensor, KvCache } { pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
const embeds = embed(self.embed_tokens, tokens, token_index); const embeds = embed(self.embed_tokens, tokens);
var hidden = embeds; var hidden = embeds;
const kv_cache0 = kv_cache orelse self.initKvCache(embeds.shape());
var updated_kv_cache = kv_cache0; var updated_kv_cache = kv_cache;
for (self.layers, 0..) |layer, i| { for (self.layers, 0..) |layer, i| {
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) }); hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
} }
const output = zml.call(self.norm, .forward, .{hidden}); const output = zml.call(self.norm, .forward, .{hidden});
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) }; return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
} }
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor, token_index: ?Tensor) Tensor { pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
const tokens = if (token_index) |idx| return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
tokens_.dynamicSlice1d(-1, .{ .start = idx, .len = 1 })
else
tokens_;
return zml.call(embed_tokens_, .forward, .{tokens}).withPartialTags(.{ .s, .d });
} }
fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache { fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
const dims = self.shape(); const dims = self.shape();
var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } }); var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd }); const perm = kv_shape.contiguousPerm(.{ .k, .h, .hd });
kv_shape = kv_shape.transpose(perm.constSlice()); kv_shape = kv_shape.transpose(perm.constSlice());
return KvCache.init(kv_shape); return KvCache.init(kv_shape);
} }
@ -218,8 +197,8 @@ pub const TransformerLayer = struct {
pub fn forward( pub fn forward(
self: TransformerLayer, self: TransformerLayer,
x0: Tensor, x0: Tensor,
token_index: ?Tensor, token_index: Tensor,
kv_cache: ?KvCache, kv_cache: KvCache,
) struct { Tensor, KvCache } { ) struct { Tensor, KvCache } {
// Self Attention // Self Attention
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) }); //log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
@ -287,39 +266,41 @@ pub const SelfAttn = struct {
pub fn forward( pub fn forward(
self: SelfAttn, self: SelfAttn,
x: Tensor, x: Tensor,
token_index: ?Tensor, token_index: Tensor,
kv_cache_: ?KvCache, kv_cache: KvCache,
) struct { Tensor, KvCache } { ) struct { Tensor, KvCache } {
// log.debug("x.shape: {}", .{x.shape()});
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads; const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h}); var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h}); var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h}); var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
// Generate the attention mask. // Generate the attention mask.
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
const seq_len = kv_cache.k.dim(.k); const seq_len = kv_cache.k.dim(.k);
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null); var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
if (token_index) |idx| {
// Note: in Pytorch it would be very inefficient to generate the full attn_mask, // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
// then slice into it, but XLA is able to optimize this correctly. // then slice into it, but XLA is able to optimize this correctly.
attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } }); attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .b = token_index.shape().dim(0), .coord = 1 }), .{});
}
// In self-attention, .s axis is used both for keys and queries. // In self-attention, .s axis is used both for keys and queries.
q = zml.nn.rope(q, token_index, self.rope_opts); const pos_index = b: {
k = zml.nn.rope(k, token_index, self.rope_opts); const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .b = token_index.shape().dim(0), .s = x.dim(.s) }, token_index.dtype()));
break :b temp.add(token_index.withTags(.{.b}).broad(temp.shape()));
};
q = zml.nn.rope(q, pos_index, self.rope_opts);
k = zml.nn.rope(k, pos_index, self.rope_opts);
q = q.rename(.{ .s = .q }); q = q.rename(.{ .s = .q });
k = k.rename(.{ .s = .k }); k = k.rename(.{ .s = .k });
v = v.rename(.{ .s = .k }); v = v.rename(.{ .s = .k });
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32)); const dtype = q.dtype();
if (token_index) |_| { const new_kv_cache = kv_cache.update(k, v, token_index);
stdx.debug.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)}); k = new_kv_cache.keys().convert(dtype);
k = new_kv_cache.keys(); v = new_kv_cache.values().convert(dtype);
v = new_kv_cache.values();
}
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = false }); const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = true });
// const attn_output = zml.nn.sdpaMemEfficient(q, k, v, .{ .attn_mask = attn_mask }, .{ .q_chunk_size = 4096, .k_chunk_size = 1024 });
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s }); const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache }; return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
} }
@ -330,7 +311,7 @@ pub const SelfAttn = struct {
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd }); const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
kv_shape = kv_shape.transpose(perm.constSlice()); kv_shape = kv_shape.transpose(perm.constSlice());
var res = KvCache.init(kv_shape); var res = KvCache.init(kv_shape);
res.layer_index = Tensor.scalar(0, .i32); res.layer_index = Tensor.scalar(0, .u32);
return res; return res;
} }
}; };
@ -345,7 +326,7 @@ pub const KvCache = struct {
return .{ return .{
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}), .k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}), .v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.layer_index = Tensor.scalar(-1, .i32), .layer_index = Tensor.scalar(-1, .u32),
}; };
} }
@ -353,7 +334,15 @@ pub const KvCache = struct {
return .{ return .{
.k = kv_shape, .k = kv_shape,
.v = kv_shape, .v = kv_shape,
.layer_index = zml.Shape.init(.{}, .i32), .layer_index = zml.Shape.init(.{}, .u32),
};
}
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
return .{
.k = try zml.Buffer.constant(platform, kv_shape, 1),
.v = try zml.Buffer.constant(platform, kv_shape, 1),
.layer_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0),
}; };
} }
@ -365,17 +354,33 @@ pub const KvCache = struct {
return self.v.dynamicSlice(.{ .layer = .{ .start = self.layer_index, .len = 1 } }).squeeze(.layer); return self.v.dynamicSlice(.{ .layer = .{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
} }
pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: Tensor) KvCache { pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache {
return .{ const k_shape = self.k.shape().drop(.layer);
.k = self.k.dynamicUpdateSlice( var layer = self.layer_index;
.{ .layer = self.layer_index, .k = token_index }, layer = if (token_index) |idx| layer.broad(idx.shape()) else layer;
// transpose to match kv-cache layout
new_k.contiguous(.{ .h, .k, .hd }), return if (token_index) |idx| .{
.k = self.k.scatterSlices(
.{ .layer = layer, .k = idx },
new_k.convert(self.k.dtype()).transpose(k_shape),
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
).reuseBuffer(self.k), ).reuseBuffer(self.k),
.v = self.v.dynamicUpdateSlice( .v = self.v.scatterSlices(
.{ .layer = self.layer_index, .k = token_index }, .{ .layer = layer, .k = idx },
// transpose to match kv-cache layout new_v.convert(self.v.dtype()).transpose(k_shape),
new_v.contiguous(.{ .h, .k, .hd }), .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
).reuseBuffer(self.v),
.layer_index = self.layer_index,
} else .{
.k = self.k.scatterSlices(
.{ .layer = layer },
new_k.convert(self.k.dtype()).transpose(k_shape),
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
).reuseBuffer(self.k),
.v = self.v.scatterSlices(
.{ .layer = layer },
new_v.convert(self.v.dtype()).transpose(k_shape),
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
).reuseBuffer(self.v), ).reuseBuffer(self.v),
.layer_index = self.layer_index, .layer_index = self.layer_index,
}; };
@ -385,7 +390,7 @@ pub const KvCache = struct {
return .{ return .{
.k = self.k, .k = self.k,
.v = self.v, .v = self.v,
.layer_index = Tensor.scalar(layer_index, .i32), .layer_index = Tensor.scalar(layer_index, .u32),
}; };
} }
@@ -1,248 +1,330 @@
 const asynk = @import("async");
-const flags = @import("tigerbeetle/flags");
+const clap = @import("clap");
 const std = @import("std");
 const stdx = @import("stdx");
 const zml = @import("zml");
-const llama_mod = @import("llama.zig");
-const LlamaLM = llama_mod.LlamaLM;
-const Llama = llama_mod.Llama;
-const KvCache = llama_mod.KvCache;
-const TransformerLayer = llama_mod.TransformerLayer;
-const SelfAttn = llama_mod.SelfAttn;
+const llama = @import("llama.zig");
+const LlamaLM = llama.LlamaLM;
+const Llama = llama.Llama;
+const KvCache = llama.KvCache;
+const TransformerLayer = llama.TransformerLayer;
+const SelfAttn = llama.SelfAttn;
 const Buffer = zml.Buffer;
 const Tensor = zml.Tensor;
 const ShapeOf = zml.ShapeOf;
 const log = std.log.scoped(.llama);
-const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 };
-// set this to false to disable the verbose logging
-const show_mlir = true;
 pub const std_options = .{
-    .log_level = .warn,
-    .log_scope_levels = &[_]std.log.ScopeLevel{
-        .{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
-        .{ .scope = .llama, .level = .info },
-    },
-    .logFn = asynk.logFn,
+    .log_level = .info,
 };
+pub fn tokenizePromptLlama3(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8) ![]u32 {
+    var tokens = std.ArrayList(u32).init(allocator);
+    var encoder = try tokenizer.encoder();
+    defer encoder.deinit();
+    const start_header_id = tokenizer.token_to_id("<|start_header_id|>") orelse return error.NoSuchToken;
+    const end_header_id = tokenizer.token_to_id("<|end_header_id|>") orelse return error.NoSuchToken;
+    const eot_id = tokenizer.token_to_id("<|eot_id|>") orelse return error.NoSuchToken;
+    const newline_id = (try encoder.encode("\n"))[0];
+    try tokens.append(config.bos_token_id);
+    try tokens.append(start_header_id);
+    try tokens.appendSlice(try encoder.encode("user"));
+    try tokens.appendSlice(&.{ end_header_id, newline_id, newline_id });
+    try tokens.appendSlice(try encoder.encode(prompt));
+    try tokens.appendSlice(&.{ eot_id, newline_id });
+    try tokens.appendSlice(try encoder.encode("\n"));
+    try tokens.append(start_header_id);
+    try tokens.appendSlice(try encoder.encode("assistant"));
+    try tokens.append(end_header_id);
+    return tokens.toOwnedSlice();
+}
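For reference, here is a rough sketch (not part of the diff) of how the helper above is used and of the text its token ids correspond to. It assumes `allocator`, a loaded `tokenizer`, and a parsed `config` are in scope, as in `asyncMain` below; the exact special tokens come from the model's tokenizer.json, so treat the rendering as approximate:

```zig
const ids = try tokenizePromptLlama3(allocator, tokenizer, config, "What is the capital of France?");
defer allocator.free(ids);
// The ids decode back to roughly:
//   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
//
//   What is the capital of France?<|eot_id|>
//
//   <|start_header_id|>assistant<|end_header_id|>
// and generation then continues right after the assistant header.
log.info("prompt token ids: {d}", .{ids});
```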
 pub fn generateText(
-    llama: LlamaLM,
+    config: LlamaLM.Config,
+    llama_: LlamaLM,
     mod_prefill: zml.ModuleExe(LlamaLM.forward),
-    mod: zml.ModuleExe(LlamaLM.forward),
+    mod_generate: zml.ModuleExe(LlamaLM.forward),
+    kv_cache_: zml.Bufferized(llama.KvCache),
     tokenizer: zml.tokenizer.Tokenizer,
     allocator: std.mem.Allocator,
     seed: u128,
     prompt: []const u8,
+    skip_llama3_encoding: bool,
 ) ![]const u8 {
const prompt_tok = tokenizer.encode(allocator, prompt, .{}) catch unreachable; var tokenizer_encoder = try tokenizer.encoder();
log.debug("Tokenized Prompt {d}", .{prompt_tok}); defer tokenizer_encoder.deinit();
const dims = llama.model.shape(); var tokenizer_decoder = try tokenizer.decoder();
const max_seq_len = dims.s; defer tokenizer_decoder.deinit();
const token_buffer = try allocator.alloc(i32, @intCast(max_seq_len));
@memset(token_buffer, 0);
for (0..prompt_tok.len) |i| {
token_buffer[i] = @intCast(prompt_tok[i]);
}
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len)); const prompt_tok: []const u32 = if (skip_llama3_encoding) try tokenizer_encoder.encode(prompt) else try tokenizePromptLlama3(allocator, tokenizer, config, prompt);
defer allocator.free(token_buffer);
defer allocator.free(tracer_buffer);
defer allocator.free(prompt_tok); defer allocator.free(prompt_tok);
var output = std.ArrayList(u8).init(allocator);
defer output.deinit();
var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer); const dims = llama_.model.shape();
var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)}); const max_seq_len = dims.s;
// Prefill
// initialize a 0..max_seq_len buffer with the tokenized prompt
const prefill_buffer = try allocator.alloc(u32, @intCast(max_seq_len));
@memset(prefill_buffer, 0);
for (0..prompt_tok.len) |i| {
prefill_buffer[i] = @intCast(prompt_tok[i]);
}
defer allocator.free(prefill_buffer);
const platform = mod_generate.platform();
// prepare device buffers for the prefill tokens and the index
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
defer prefill_tokens.deinit();
var prefill_token_index = try zml.Buffer.fromSlice(platform, .{}, &[_]u32{0});
defer prefill_token_index.deinit(); defer prefill_token_index.deinit();
var rng = try zml.Tensor.Rng.init(mod.platform(), seed); // init RNG and prefill
tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng }); var rng = try zml.Tensor.Rng.init(platform, seed);
defer token_index.deinit(); prefill_tokens, var kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_index, kv_cache_, rng });
defer kv_cache.k.deinit(); defer kv_cache.k.deinit();
defer kv_cache.v.deinit(); defer kv_cache.v.deinit();
defer kv_cache.layer_index.deinit(); defer kv_cache.layer_index.deinit();
// Prepare for token-by-token generation
var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token
var current_token = try zml.Buffer.fromSlice(platform, .{}, &first_token_hostbuffer);
defer current_token.deinit();
// Here we will copy the generated token from device
var generated_token_buffer = [_]u32{0};
// Here we collect the generated text
var output = std.ArrayList(u8).init(allocator);
defer output.deinit();
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
defer allocator.free(tracer_buffer);
const tracer = zml.tools.Tracer.init("ai.zml.models.llama"); const tracer = zml.tools.Tracer.init("ai.zml.models.llama");
var decode_progress = prompt_tok.len;
const output_tokens_len = max_seq_len - prompt_tok.len - 1; const output_tokens_len = max_seq_len - prompt_tok.len - 1;
const start = std.time.microTimestamp(); const start = std.time.microTimestamp();
const output_freq: u8 = 1;
var eos_index: ?usize = null; var num_tokens_generated: usize = 0;
for (0..output_tokens_len) |i| {
//_ = i; generation: for (0..output_tokens_len) |i| {
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len })); const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
token_index.deinit(); // current token index needs to go into a zml.Buffer
token_index = new_token_index; const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
if ((i + 1) % output_freq == 0) { const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer);
const n = output.items.len; defer token_index.deinit();
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..][0..output_freq]), .{}); // call to generate the next token
decode_progress += output_freq; current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_index, kv_cache, rng });
std.debug.print("{s}", .{output.items[n..]});
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| {
// Handle strange scenarios when eos id isn't the very next token after decode_progress
eos_index = decode_progress - output_freq + index;
break;
}
} else {
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len })); tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
// extract the generated token from the buffer
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
const generated_token = generated_token_buffer[0];
// de-tokenize generated token into a string
const chunk = try tokenizer_decoder.next(@intCast(generated_token)) orelse unreachable;
num_tokens_generated = i;
// check for eos
switch (config.eos_token_id.value) {
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
.ints => |eos_list| {
for (eos_list) |eos| {
if (generated_token == @as(u32, @intCast(eos))) break :generation;
} }
},
} }
var total_token_count: usize = max_seq_len;
const n = output.items.len; // collect and print generated sequence
if (eos_index) |end_idx| { try output.appendSlice(chunk);
// count = eos index + 1 std.debug.print("{s}", .{chunk});
total_token_count = end_idx + 1;
} }
const generated_token_count = total_token_count - prompt_tok.len;
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{});
std.debug.print("{s}\n", .{output.items[n..]});
const end = std.time.microTimestamp(); const end = std.time.microTimestamp();
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s); const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
const speed = @as(f64, @floatFromInt(generated_token_count)) / duration; const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed }); std.debug.print("\n", .{});
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
output.clearRetainingCapacity();
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{});
return output.toOwnedSlice(); return output.toOwnedSlice();
} }
const params = clap.parseParamsComptime(
\\--help print this help
\\--prompt <STRING> the prompt
\\--config <PATH> config.json path
\\--weights <PATH> model weights path
\\--tokenizer <PATH> tokenizer path
\\--seed <UINT> random seed (optional)
\\--seq-len <UINT> sequence length
\\--create-options <STRING> platform creation options JSON, defaults to {}
\\--no-llama3 <BOOL> skip prompt template
\\--sharding <BOOL> default: true: sharding on or off
);
pub fn bool_parser(in: []const u8) error{}!bool {
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}
pub fn main() !void { pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
} }
pub fn asyncMain() !void { pub fn asyncMain() !void {
const CliArgs = struct {
pub const help =
\\ llama --model=llama3.7B.safetensors --tokenizer=vocab.json --num_layers=2
;
model: []const u8,
tokenizer: ?[]const u8 = null,
layer_start: u8 = 0,
num_layers: ?u8 = null,
seq_len: u32 = 256,
topk: u32 = 2,
temperature: u32 = 1,
num_heads: ?i64 = null,
num_kv_heads: ?i64 = null,
rope_freq_base: ?i64 = null,
prompt: ?[]const u8 = null,
test_activations: ?[]const u8 = null,
seed: ?u128 = null,
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
create_options: []const u8 = "{}",
};
log.info(" LLama was compiled with {}", .{@import("builtin").mode}); log.info(" LLama was compiled with {}", .{@import("builtin").mode});
const allocator = std.heap.c_allocator; const allocator = std.heap.c_allocator;
const tmp = try std.fs.openDirAbsolute("/tmp", .{}); const parsers = comptime .{
try tmp.makePath("zml/llama/cache"); .BOOL = bool_parser,
.UINT = clap.parsers.int(usize, 0),
.STRING = clap.parsers.string,
.PATH = clap.parsers.string,
};
var diag: clap.Diagnostic = .{};
const stderr = std.io.getStdErr().writer();
var res = clap.parse(clap.Help, &params, parsers, .{
.diagnostic = &diag,
.allocator = allocator,
}) catch |err| {
diag.report(stderr, err) catch {};
stderr.print("usage: ", .{}) catch {};
clap.usage(stderr, clap.Help, &params) catch {};
stderr.print("\n", .{}) catch {};
return;
};
defer res.deinit();
if (res.args.help != 0) {
clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}) catch {};
return;
}
const config = blk: {
if (res.args.config) |config_json_path| {
var config_json_file = try asynk.File.open(config_json_path, .{ .mode = .read_only });
defer config_json_file.close() catch unreachable;
var reader = std.json.reader(allocator, config_json_file.reader());
defer reader.deinit();
const config_obj = try std.json.parseFromTokenSourceLeaky(llama.LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
break :blk config_obj;
} else {
log.err("Missing --config", .{});
return;
}
};
var context = try zml.Context.init(); var context = try zml.Context.init();
defer context.deinit(); defer context.deinit();
const compilation_options = zml.CompilationOptions{ const compilation_options = zml.CompilationOptions{
.xla_dump_to = "/tmp/zml/llama", .xla_dump_to = "/tmp/zml/llama",
.sharding_enabled = true, .sharding_enabled = res.args.sharding orelse true,
}; };
var args = std.process.args(); // initialize ZML platform with optional create options
const cli_args = flags.parse(&args, CliArgs); // eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
const model_file = cli_args.model; const create_opts_json = res.args.@"create-options" orelse "{}";
const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
-    var arena_state = std.heap.ArenaAllocator.init(allocator);
-    defer arena_state.deinit();
-    const model_arena = arena_state.allocator();
-    const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{});
-    const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
+    const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options);
+    create_opts.deinit();
     context.printAvailablePlatforms(platform);
-    log.info("Model file: {s}", .{model_file});
-    var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);
+    var ts = try zml.aio.detectFormatAndOpen(allocator, res.args.weights.?);
     defer ts.deinit();
-    var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts);
-    const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
-    const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads;
-    const rope_impl = if (ts.metadata("rope_impl", .string)) |val|
-        std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
-    else
-        .sequential;
-    const llama_options: llama_mod.LlamaOptions = .{
-        .max_seq_len = cli_args.seq_len,
-        .num_kv_heads = num_kv_heads,
-        .num_heads = num_heads,
-        .gen_opts = .{
-            .topk = cli_args.topk,
-            .temperature = @floatFromInt(cli_args.temperature),
-        },
-        .rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5),
-        .rope_opts = .{
-            .impl = rope_impl,
-            .freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
+    var model_arena = std.heap.ArenaAllocator.init(allocator);
+    var model_instance = try zml.aio.populateModel(llama.LlamaLM, model_arena.allocator(), ts);
+    const llama_options: llama.LlamaLM.Options = .{
+        .max_seq_len = @intCast(res.args.@"seq-len" orelse 256),
+        .sampling_strategy = .{
+            .topk = 1,
+            .temperature = 1.0,
         },
     };
-    log.info("\tParsed llama config: {}", .{llama_options});
-    llama.init(llama_options);
-    if (cli_args.tokenizer == null and !std.mem.endsWith(u8, cli_args.model, ".gguf")) {
-        log.err("Model doesn't have an embbedded tokenizer, please provide a path to a tokenizer.", .{});
-        @panic("No tokenizer provided");
-    }
-    const tokenizer_path = cli_args.tokenizer orelse cli_args.model;
-    log.info("\tLoading tokenizer from {s}", .{tokenizer_path});
-    var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, tokenizer_path);
-    log.info("\tLoaded tokenizer from {s}", .{tokenizer_path});
-    defer tokenizer.deinit();
-    const dims = llama.model.shape();
-    const dtype = llama.model.embed_tokens.weight.dtype();
-    // Note: we compile the model without a batching dimension.
-    // To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
-    const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
-    const token_idx_shape = zml.Shape.init(.{}, .i32);
-    const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
-    // needs to be optional
-    const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
-    const rng_shape = Tensor.Rng.shape();
+    model_instance.init(config, llama_options);
+    const dims = model_instance.model.shape();
+    const dtype = model_instance.model.embed_tokens.weight.dtype();
+    const batch_size = 1;
+    const tokens_shape_prefill = zml.Shape.init(.{ .b = batch_size, .s = llama_options.max_seq_len }, .u32);
+    const tokens_shape = zml.Shape.init(.{ .b = batch_size, .s = 1 }, .u32);
+    const token_idx_shape = zml.Shape.init(.{ .b = batch_size }, .u32);
+    const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .b = batch_size, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
+    const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
+    const rng_shape = zml.Tensor.Rng.shape();
     var start = try std.time.Timer.start();
-    var fut_mod_prefill = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
-    var fut_mod = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
-    log.info("\tLoading Llama weights from {s}...", .{cli_args.model});
-    var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, ts, model_arena, platform);
+    var fut_mod_prefill = try asynk.asyncc(zml.compile, .{
+        allocator, llama.LlamaLM.forward, .{ config, llama_options },
+        .{
+            tokens_shape_prefill,
+            token_idx_shape,
+            kv_cache_shape,
+            rng_shape,
+        },
+        ts,
+        platform,
+    });
+    var fut_mod = try asynk.asyncc(zml.compile, .{
+        allocator, llama.LlamaLM.forward, .{ config, llama_options },
+        .{
+            tokens_shape,
+            token_idx_shape,
+            kv_cache_shape,
+            rng_shape,
+        },
+        ts,
+        platform,
+    });
+    log.info("\tLoading Llama weights from {?s}...", .{res.args.weights});
+    var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
     defer zml.aio.unloadBuffers(&llama_weights);
-    log.info("\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
+    log.info("\tLoaded weights in {}", .{std.fmt.fmtDuration(start.read())});
     var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
     defer llama_module_prefill.deinit();
     var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
     defer llama_module.deinit();
-    log.info("\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
-    const prompt = cli_args.prompt orelse "Once upon a time, there was a little girl named Lily.";
+    log.info("\tCompiled model in {}", .{std.fmt.fmtDuration(start.read())});
+    log.info("Creating KvCache", .{});
+    const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
+    var tokenizer = blk: {
+        if (res.args.tokenizer) |tok| {
+            log.info("Loading tokenizer from {s}", .{tok});
+            var timer = try stdx.time.Timer.start();
+            defer log.info("Loaded tokenizer from {s} [{}]", .{ tok, timer.read() });
+            break :blk try zml.tokenizer.Tokenizer.from_file(model_arena.allocator(), tok);
+        } else {
+            log.err("Missing --tokenizer", .{});
+            return;
+        }
+    };
+    errdefer tokenizer.deinit();
+    const prompt = res.args.prompt orelse "What is the capital of France?";
     log.info("\tPrompt: {s}", .{prompt});
-    const seed = cli_args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
-    const story = try generateText(llama, llama_module_prefill, llama_module, tokenizer, allocator, seed, prompt);
-    defer allocator.free(story);
+    const seed = res.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
+    const skip_llama3_encoding = res.args.@"no-llama3" orelse false;
+    const generated_text = try generateText(config, model_instance, llama_module_prefill, llama_module, kv_cache, tokenizer, allocator, seed, prompt[0..], skip_llama3_encoding);
+    // generated text will be printed token by token.
+    defer allocator.free(generated_text);
 }

View File

@@ -16,9 +16,10 @@ pub fn main() !void {
 pub fn asyncMain() !void {
     const CliArgs = struct {
         pub const help =
-            \\ test-implementation --model=llama3.8B.safetensors --reference=activation.safetensors
+            \\ test-implementation --weights=llama3.8B.safetensors --config=config.json --reference=activation.safetensors
         ;
-        model: []const u8,
+        weights: []const u8,
+        config: []const u8,
         reference: []const u8,
         num_heads: ?i64 = null,
         num_kv_heads: ?i64 = null,
@@ -38,7 +39,7 @@ pub fn asyncMain() !void {
     // Parse program args
     var args = std.process.args();
     const cli_args = flags.parse(&args, CliArgs);
-    const model_file = cli_args.model;
+    const model_file = cli_args.weights;
     // Memory arena dedicated to model shapes and weights
     var arena_state = std.heap.ArenaAllocator.init(allocator);
@@ -61,6 +62,16 @@ pub fn asyncMain() !void {
     else
         .sequential;
+    const config = blk: {
+        var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only });
+        defer config_json_file.close() catch unreachable;
+        var reader = std.json.reader(allocator, config_json_file.reader());
+        defer reader.deinit();
+        const config_obj = try std.json.parseFromTokenSourceLeaky(LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
+        break :blk config_obj;
+    };
+    std.log.info("Parsed llama config: {}", .{config});
     const llama_options: llama_mod.LlamaOptions = .{
         .max_seq_len = 256,
         .num_kv_heads = num_kv_heads,
@@ -101,26 +112,4 @@
     try zml.testing.testLayer(platform, buffer_store, "layers.0.mlp", llama.model.layers[0].mlp, llama_weights.model.layers[0].mlp, 1e-2);
     try zml.testing.testLayer(platform, buffer_store, "layers.0.input_layernorm", llama.model.layers[0].input_layernorm, llama_weights.model.layers[0].input_layernorm, 1e-2);
     try zml.testing.testLayer(platform, buffer_store, "layers.0.post_attention_layernorm", llama.model.layers[0].post_attention_layernorm, llama_weights.model.layers[0].post_attention_layernorm, 1e-2);
-    {
-        const test_case = "layers.0.self_attn";
-        std.log.info("Testing {s}", .{test_case});
-        // Small wrapper to explicitly tag the input, and ignore the extra arguments used in HF implementation.
-        const SelfAttnPrefill = struct {
-            inner: llama_mod.SelfAttn,
-            pub fn forward(self: @This(), x_: Tensor) struct { Tensor, llama_mod.KvCache } {
-                return self.inner.forward(x_.withTags(.{ .b, .s, .d }), null, null);
-            }
-        };
-        try zml.testing.testLayer(
-            platform,
-            buffer_store,
-            "layers.0.self_attn",
-            SelfAttnPrefill{ .inner = llama.model.layers[0].self_attn },
-            .{ .inner = llama_weights.model.layers[0].self_attn },
-            1e-3,
-        );
-    }
 }

View File

@@ -0,0 +1,9 @@
load("@rules_zig//zig:defs.bzl", "zig_library")

zig_library(
    name = "clap",
    import_name = "clap",
    srcs = glob(["clap/*.zig"]),
    main = "clap.zig",
    visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,20 @@
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")

def _non_module_deps_impl(mctx):
    new_git_repository(
        name = "com_github_hejsil_clap",
        remote = "https://github.com/Hejsil/zig-clap.git",
        commit = "d71cc39a94f3e6ccbad00c25d350c9147de4df9f",
        build_file = "//:third_party/com_github_hejsil_clap/clap.bazel",
    )
    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )

non_module_deps = module_extension(
    implementation = _non_module_deps_impl,
)