Update Llama example docs and Bazel build files, and add tests for the new HuggingFace tokenizer integration.
This commit is contained in:
parent
959bc48c42
commit
76e314db9b
@ -59,47 +59,41 @@ Llama is a family of "Large Language Models", trained to generate text, based
|
||||
on the beginning of a sentence/book/article. This "beginning" is generally
|
||||
referred to as the "prompt".
|
||||
|
||||
#### TinyLlama, Stories 15M
|
||||
|
||||
To start, you can use a small model trained specifically on children's history
|
||||
books. This model has been trained by [Andrej Karpathy](https://x.com/karpathy);
|
||||
you can read more about it on his
|
||||
[Github](https://github.com/karpathy/llama2.c).
|
||||
|
||||
```
|
||||
cd examples
|
||||
bazel run -c opt //llama:TinyLlama-Stories-15M
|
||||
bazel run -c opt //llama:TinyLlama-Stories-15M -- --prompt="Once upon a time, there was a cute little dragon"
|
||||
```
|
||||
|
||||
#### OpenLLama 3B
|
||||
|
||||
```
|
||||
cd examples
|
||||
bazel run -c opt //llama:OpenLLaMA-3B
|
||||
bazel run -c opt //llama:OpenLLaMA-3B -- --prompt="Once upon a time,"
|
||||
```
|
||||
|
||||
#### Meta Llama 3 8B
|
||||
#### Meta Llama 3.1 8B
|
||||
|
||||
This model has restrictions, see
|
||||
[here](https://huggingface.co/meta-llama/Meta-Llama-3-8B): it **requires
|
||||
[here](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct). It **requires
|
||||
approval from Meta on Huggingface**, which can take a few hours to get granted.
|
||||
|
||||
While waiting for approval, you can already
|
||||
[generate your Huggingface access token](../howtos/huggingface_access_token.md).
|
||||
|
||||
Once you've been granted access, you're ready to download a gated model like
|
||||
`Meta-Llama-3-8b`!
|
||||
`Meta-Llama-3.1-8B-Instruct`!
|
||||
|
||||
```
|
||||
# requires token in $HOME/.cache/huggingface/token, as created by the
|
||||
# `huggingface-cli login` command, or the `HUGGINGFACE_TOKEN` environment variable.
|
||||
cd examples
|
||||
bazel run -c opt //llama:Meta-Llama-3-8b
|
||||
bazel run -c opt //llama:Meta-Llama-3-8b -- --promt="Once upon a time,"
|
||||
bazel run -c opt //llama:Llama-3.1-8B-Instruct
|
||||
bazel run -c opt //llama:Llama-3.1-8B-Instruct -- --prompt="What is the capital of France?"
|
||||
```
|
||||
|
||||
You can also try `Llama-3.1-70B-Instruct` if you have enough memory.
|
||||
|
||||
### Meta Llama 3.2 1B
|
||||
|
||||
Like the 8B model above, this model also requires approval. See
|
||||
[here](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) for access requirements.
|
||||
|
||||
```
|
||||
cd examples
|
||||
bazel run -c opt //llama:Llama-3.2-1B-Instruct
|
||||
bazel run -c opt //llama:Llama-3.2-1B-Instruct -- --prompt="What is the capital of France?"
|
||||
```
|
||||
|
||||
For a larger 3.2 model, you can also try `Llama-3.2-3B-Instruct`.
|
||||
|
||||
|
||||
## Run Tests
|
||||
|
||||
@ -126,9 +120,9 @@ run the following:
|
||||
|
||||
```
|
||||
cd examples
|
||||
bazel run -c opt //llama:OpenLLaMA-3B \
|
||||
--@zml//runtimes:cuda=true \
|
||||
-- --prompt="Once upon a time,"
|
||||
bazel run -c opt //llama:Llama-3.2-1B-Instruct \
|
||||
--@zml//runtimes:cuda=true \
|
||||
-- --prompt="What is the capital of France?"
|
||||
```
|
||||
|
||||
|
||||
|
||||
0
examples/BUILD.bazel
Normal file
0
examples/BUILD.bazel
Normal file
@ -7,6 +7,9 @@ bazel_dep(name = "zml", version = "0.1.0")
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.11.0")
|
||||
bazel_dep(name = "rules_oci", version = "2.0.0")
|
||||
|
||||
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
|
||||
use_repo(non_module_deps, "com_github_hejsil_clap")
|
||||
|
||||
oci = use_extension("@rules_oci//oci:extensions.bzl", "oci")
|
||||
oci.pull(
|
||||
name = "distroless_cc_debian12",
|
||||
@ -44,67 +47,24 @@ http_file(
|
||||
url = "https://github.com/ggerganov/ggml/raw/18703ad600cc68dbdb04d57434c876989a841d12/examples/mnist/models/mnist/t10k-images.idx3-ubyte",
|
||||
)
|
||||
|
||||
# Llama weights
|
||||
huggingface = use_extension("@zml//bazel:huggingface.bzl", "huggingface")
|
||||
huggingface.model(
|
||||
name = "Karpathy-TinyLlama-Stories",
|
||||
build_file_content = """\
|
||||
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
||||
|
||||
# leverage copy_file to rename tokenizer extension
|
||||
# which allow zml.aio.detectFormatAndLoadTokenizer
|
||||
# to leverage the right tokenizer
|
||||
copy_file(
|
||||
name = "stories15M",
|
||||
src = "stories15M.bin",
|
||||
out = "stories15M.tinyllama",
|
||||
allow_symlink = True,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
copy_file(
|
||||
name = "stories110M",
|
||||
src = "stories110M.bin",
|
||||
out = "stories110M.tinyllama",
|
||||
allow_symlink = True,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
""",
|
||||
commit = "0bd21da7698eaf29a0d7de3992de8a46ef624add",
|
||||
includes = [
|
||||
"stories15M.bin",
|
||||
"stories110M.bin",
|
||||
],
|
||||
model = "karpathy/tinyllamas",
|
||||
)
|
||||
use_repo(huggingface, "Karpathy-TinyLlama-Stories")
|
||||
|
||||
http_file(
|
||||
name = "Karpathy-TinyLlama-Tokenizer",
|
||||
downloaded_file_path = "stories260K.tinyllama",
|
||||
sha256 = "50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361",
|
||||
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
|
||||
)
|
||||
|
||||
# Llama 3.2
|
||||
|
||||
huggingface = use_extension("@zml//bazel:huggingface.bzl", "huggingface")
|
||||
|
||||
huggingface.model(
|
||||
name = "Meta-Llama-3.2-1B-Instruct",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = ["model.safetensors"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
name = "Meta-Llama-3.2-1B-Instruct",
|
||||
srcs = glob(["*.json", "*.safetensors"]),
|
||||
)
|
||||
""",
|
||||
commit = "9213176726f574b556790deb65791e0c5aa438b6",
|
||||
includes = [
|
||||
"model.safetensors",
|
||||
"tokenizer.json",
|
||||
"*.safetensors",
|
||||
"*.json",
|
||||
],
|
||||
model = "meta-llama/Llama-3.2-1B-Instruct",
|
||||
)
|
||||
@ -115,129 +75,87 @@ huggingface.model(
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
name = "Meta-Llama-3.2-3B-Instruct",
|
||||
srcs = glob(["*.json", "*.safetensors"]),
|
||||
)
|
||||
""",
|
||||
commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
|
||||
includes = [
|
||||
"*.safetensors",
|
||||
"model.safetensors.index.json",
|
||||
"tokenizer.json",
|
||||
"*.json",
|
||||
],
|
||||
model = "meta-llama/Llama-3.2-3B-Instruct",
|
||||
)
|
||||
use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
|
||||
|
||||
# Llama 3.1
|
||||
|
||||
huggingface.model(
|
||||
name = "Meta-Llama-3.1-8B-Instruct",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
name = "Meta-Llama-3.1-8B-Instruct",
|
||||
srcs = glob(["*.json", "*.safetensors"]),
|
||||
)
|
||||
""",
|
||||
commit = "5206a32e0bd3067aef1ce90f5528ade7d866253f",
|
||||
includes = [
|
||||
"*.safetensors",
|
||||
"model.safetensors.index.json",
|
||||
"tokenizer.json",
|
||||
"*.json",
|
||||
],
|
||||
model = "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
)
|
||||
use_repo(huggingface, "Meta-Llama-3.1-8B-Instruct")
|
||||
|
||||
huggingface.model(
|
||||
name = "Meta-Llama-3.1-70B-Instruct",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
name = "Meta-Llama-3.1-70B-Instruct",
|
||||
srcs = glob(["*.json", "*.safetensors"]),
|
||||
)
|
||||
""",
|
||||
commit = "945c8663693130f8be2ee66210e062158b2a9693",
|
||||
includes = [
|
||||
"*.safetensors",
|
||||
"model.safetensors.index.json",
|
||||
"tokenizer.json",
|
||||
"*.json",
|
||||
],
|
||||
model = "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
)
|
||||
use_repo(huggingface, "Meta-Llama-3.1-70B-Instruct")
|
||||
|
||||
|
||||
huggingface.model(
|
||||
name = "TinyLlama-1.1B-Chat-v1.0",
|
||||
name = "TinyLlama-120M-scratch",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = ["model.safetensors"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.model"],
|
||||
name = "TinyLlama-120M-scratch",
|
||||
srcs = glob(["*.json", "*.safetensors"]),
|
||||
)
|
||||
""",
|
||||
commit = "fe8a4ea1ffedaf415f4da2f062534de366a451e6",
|
||||
commit = "89c1bb4ea00861ddaa26c55f102ccb25e161feee",
|
||||
includes = [
|
||||
"model.safetensors",
|
||||
"tokenizer.model",
|
||||
"*.safetensors",
|
||||
"*.json",
|
||||
],
|
||||
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
model = "Hoyeon/TinyLlama-120M-scratch",
|
||||
)
|
||||
use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
|
||||
use_repo(huggingface, "TinyLlama-120M-scratch")
|
||||
|
||||
#OpenLLaMa
|
||||
huggingface.model(
|
||||
name = "OpenLM-Research-OpenLLaMA-3B",
|
||||
build_file_content = """\
|
||||
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = ["model.safetensors"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = [":tokenizer_pb"],
|
||||
)
|
||||
|
||||
# leverage copy_file to rename tokenizer extension
|
||||
# which allow zml.aio.detectFormatAndLoadTokenizer
|
||||
# to leverage the right tokenizer
|
||||
copy_file(
|
||||
name = "tokenizer_pb",
|
||||
src = "tokenizer.model",
|
||||
out = "tokenizer.pb",
|
||||
allow_symlink = True,
|
||||
)
|
||||
""",
|
||||
commit = "fcc2e809eb8f14dabba84d76a0ddc17b8ea05356",
|
||||
includes = [
|
||||
"model.safetensors",
|
||||
"tokenizer.model",
|
||||
bazel_dep(name = "rules_rust", version = "0.57.0")
|
||||
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
|
||||
rust.toolchain(
|
||||
edition = "2021",
|
||||
versions = ["1.84.0"],
|
||||
extra_target_triples = [
|
||||
"aarch64-apple-darwin",
|
||||
"aarch64-unknown-linux-gnu",
|
||||
"x86_64-unknown-linux-gnu",
|
||||
],
|
||||
model = "openlm-research/open_llama_3b",
|
||||
)
|
||||
use_repo(huggingface, "OpenLM-Research-OpenLLaMA-3B")
|
||||
use_repo(rust, "rust_toolchains")
|
||||
register_toolchains("@rust_toolchains//:all")
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
||||
load("@aspect_bazel_lib//lib:expand_template.bzl", "expand_template")
|
||||
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
|
||||
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
|
||||
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
|
||||
@ -12,7 +13,7 @@ zig_cc_binary(
|
||||
],
|
||||
main = "main.zig",
|
||||
deps = [
|
||||
"//third_party/tigerbeetle:flags",
|
||||
"@com_github_hejsil_clap//:clap",
|
||||
"@zml//async",
|
||||
"@zml//stdx",
|
||||
"@zml//zml",
|
||||
@ -20,18 +21,35 @@ zig_cc_binary(
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "Llama-3.1-8B-Instruct",
|
||||
name = "TinyLlama-120M-scratch",
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)",
|
||||
"--num-heads=32",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
"--config=$(location @TinyLlama-120M-scratch//:config.json)",
|
||||
"--weights=$(location @TinyLlama-120M-scratch//:model.safetensors)",
|
||||
"--tokenizer=$(location @TinyLlama-120M-scratch//:tokenizer.json)",
|
||||
"--no-llama3=true", # don't do llama3 template prompt encoding
|
||||
"--sharding=false", # don't shard this
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.1-8B-Instruct//:model",
|
||||
"@TinyLlama-120M-scratch",
|
||||
"@TinyLlama-120M-scratch//:config.json",
|
||||
"@TinyLlama-120M-scratch//:model.safetensors",
|
||||
"@TinyLlama-120M-scratch//:tokenizer.json",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "Llama-3.1-8B-Instruct",
|
||||
args = [
|
||||
"--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
|
||||
"--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer.json)",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.1-8B-Instruct",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:config.json",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:tokenizer",
|
||||
"@Meta-Llama-3.1-8B-Instruct//:tokenizer.json",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
@ -39,32 +57,32 @@ cc_binary(
|
||||
cc_binary(
|
||||
name = "Llama-3.1-70B-Instruct",
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)",
|
||||
"--num-heads=64",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
"--config=$(location @Meta-Llama-3.1-70B-Instruct//:config.json)",
|
||||
"--weights=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer.json)",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.1-70B-Instruct//:model",
|
||||
"@Meta-Llama-3.1-70B-Instruct",
|
||||
"@Meta-Llama-3.1-70B-Instruct//:config.json",
|
||||
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.1-70B-Instruct//:tokenizer",
|
||||
"@Meta-Llama-3.1-70B-Instruct//:tokenizer.json",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
|
||||
cc_binary(
|
||||
name = "Llama-3.2-1B-Instruct",
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
|
||||
"--num-heads=32",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
"--config=$(location @Meta-Llama-3.2-1B-Instruct//:config.json)",
|
||||
"--weights=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.2-1B-Instruct",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:config.json",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:tokenizer",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
@ -72,86 +90,26 @@ cc_binary(
|
||||
cc_binary(
|
||||
name = "Llama-3.2-3B-Instruct",
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)",
|
||||
"--num-heads=24",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
"--config=$(location @Meta-Llama-3.2-3B-Instruct//:config.json)",
|
||||
"--weights=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer.json)",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.2-3B-Instruct//:model",
|
||||
"@Meta-Llama-3.2-3B-Instruct",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:config.json",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:tokenizer",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "OpenLLaMA-3B",
|
||||
args = [
|
||||
"--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)",
|
||||
"--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)",
|
||||
"--num-heads=32",
|
||||
"--num-kv-heads=32",
|
||||
"--rope-freq-base=10000",
|
||||
],
|
||||
data = [
|
||||
"@OpenLM-Research-OpenLLaMA-3B//:model",
|
||||
"@OpenLM-Research-OpenLLaMA-3B//:tokenizer",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "TinyLlama-1.1B-Chat",
|
||||
args = [
|
||||
"--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)",
|
||||
"--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)",
|
||||
"--num-heads=32",
|
||||
"--num-kv-heads=4",
|
||||
"--rope-freq-base=10000",
|
||||
],
|
||||
data = [
|
||||
"@TinyLlama-1.1B-Chat-v1.0//:model.safetensors",
|
||||
"@TinyLlama-1.1B-Chat-v1.0//:tokenizer",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "TinyLlama-Stories-110M",
|
||||
args = [
|
||||
"--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)",
|
||||
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
|
||||
],
|
||||
data = [
|
||||
"@Karpathy-TinyLlama-Stories//:stories110M",
|
||||
"@Karpathy-TinyLlama-Tokenizer//file",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "TinyLlama-Stories-15M",
|
||||
args = [
|
||||
"--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)",
|
||||
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
|
||||
],
|
||||
data = [
|
||||
"@Karpathy-TinyLlama-Stories//:stories15M",
|
||||
"@Karpathy-TinyLlama-Tokenizer//file",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:tokenizer.json",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
#
|
||||
|
||||
zig_cc_binary(
|
||||
name = "test-implementation",
|
||||
srcs = ["llama.zig"],
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
||||
"--num-heads=32",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
"--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
||||
"--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.1-8B-Instruct//:model",
|
||||
@ -184,12 +142,12 @@ zig_cc_binary(
|
||||
|
||||
mtree_spec(
|
||||
name = "mtree",
|
||||
srcs = [":llama"],
|
||||
srcs = [":Llama-3.2-1B-Instruct"],
|
||||
)
|
||||
|
||||
tar(
|
||||
name = "archive",
|
||||
srcs = [":llama"],
|
||||
srcs = [":Llama-3.2-1B-Instruct"],
|
||||
args = [
|
||||
"--options",
|
||||
"zstd:compression-level=9",
|
||||
@ -198,10 +156,33 @@ tar(
|
||||
mtree = ":mtree",
|
||||
)
|
||||
|
||||
expand_template(
|
||||
name = "entrypoint",
|
||||
data = [
|
||||
":Llama-3.2-1B-Instruct",
|
||||
"@Meta-Llama-3.2-1B-Instruct",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:config.json",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
|
||||
],
|
||||
substitutions = {
|
||||
":config": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:config.json)",
|
||||
":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||
":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||
},
|
||||
template = [
|
||||
"./{}/Llama-3.2-1B-Instruct".format(package_name()),
|
||||
"--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()),
|
||||
"--weights=./{}/Llama-3.2-1B-Instruct.runfiles/:weights".format(package_name()),
|
||||
"--tokenizer=./{}/Llama-3.2-1B-Instruct.runfiles/:tokenizer".format(package_name()),
|
||||
],
|
||||
)
|
||||
|
||||
oci_image(
|
||||
name = "image_",
|
||||
base = "@distroless_cc_debian12_debug",
|
||||
entrypoint = ["./{}/llama".format(package_name())],
|
||||
# entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())],
|
||||
entrypoint = ":entrypoint",
|
||||
tars = [
|
||||
"@zml//runtimes:layers",
|
||||
":archive",
|
||||
@ -218,7 +199,7 @@ oci_load(
|
||||
name = "load",
|
||||
image = ":image",
|
||||
repo_tags = [
|
||||
"distroless/llama:latest",
|
||||
"distroless/llama-3.2-1b-instruct:latest",
|
||||
],
|
||||
)
|
||||
|
||||
@ -226,5 +207,5 @@ oci_push(
|
||||
name = "push",
|
||||
image = ":image",
|
||||
remote_tags = ["latest"],
|
||||
repository = "index.docker.io/steeve/llama",
|
||||
repository = "index.docker.io/steeve/llama-3.2-1b-instruct",
|
||||
)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
@ -12,36 +11,51 @@ const gguf = zml.io.gguf;
|
||||
const expectClose = zml.testing.expectClose;
|
||||
const log = std.log.scoped(.llama);
|
||||
|
||||
pub const LlamaOptions = struct {
|
||||
gen_opts: zml.nn.SamplingStrategy,
|
||||
max_seq_len: u32,
|
||||
num_heads: i64,
|
||||
num_kv_heads: i64,
|
||||
rms_norm_eps: f32,
|
||||
rope_opts: zml.nn.RopeOpts,
|
||||
};
|
||||
|
||||
/// Llama architecture, using huggingface transformers naming.
|
||||
/// Dimensions of activations: {.b, .s, .d}
|
||||
pub const LlamaLM = struct {
|
||||
lm_head: ?zml.nn.Linear = null,
|
||||
pub const Config = struct {
|
||||
bos_token_id: u32,
|
||||
eos_token_id: stdx.json.Union(union(enum) {
|
||||
int: u32,
|
||||
ints: []u32,
|
||||
}),
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
num_key_value_heads: usize,
|
||||
rope_theta: f32,
|
||||
max_position_embeddings: usize,
|
||||
rms_norm_eps: f32,
|
||||
};
|
||||
|
||||
pub const Options = struct {
|
||||
sampling_strategy: ?zml.nn.SamplingStrategy,
|
||||
max_seq_len: usize,
|
||||
};
|
||||
|
||||
lm_head: ?zml.nn.Linear,
|
||||
model: Llama,
|
||||
|
||||
// Options controlling generation
|
||||
gen_opts: zml.nn.SamplingStrategy = .{},
|
||||
config: Config,
|
||||
|
||||
pub fn init(self: *LlamaLM, options: LlamaOptions) void {
|
||||
self.gen_opts = options.gen_opts;
|
||||
self.model.max_seq_len = options.max_seq_len;
|
||||
self.model.num_heads = options.num_heads;
|
||||
self.model.num_kv_heads = options.num_kv_heads;
|
||||
self.model.rope_opts = options.rope_opts;
|
||||
pub fn init(self: *LlamaLM, config: Config, options: Options) void {
|
||||
self.config = config;
|
||||
self.gen_opts = options.sampling_strategy orelse .{};
|
||||
self.model.max_seq_len = @intCast(options.max_seq_len);
|
||||
self.model.num_heads = @intCast(config.num_attention_heads);
|
||||
self.model.num_kv_heads = @intCast(config.num_key_value_heads);
|
||||
self.model.rope_opts = .{
|
||||
.impl = .sequential,
|
||||
.freq_base = config.rope_theta,
|
||||
};
|
||||
for (self.model.layers) |*layer| {
|
||||
layer.self_attn.num_heads = options.num_heads;
|
||||
layer.self_attn.num_kv_heads = options.num_kv_heads;
|
||||
layer.self_attn.rope_opts = options.rope_opts;
|
||||
layer.input_layernorm.eps = options.rms_norm_eps;
|
||||
layer.post_attention_layernorm.eps = options.rms_norm_eps;
|
||||
layer.self_attn.num_heads = self.model.num_heads;
|
||||
layer.self_attn.num_kv_heads = self.model.num_kv_heads;
|
||||
layer.self_attn.rope_opts = self.model.rope_opts;
|
||||
layer.input_layernorm.eps = config.rms_norm_eps;
|
||||
layer.post_attention_layernorm.eps = config.rms_norm_eps;
|
||||
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
|
||||
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
|
||||
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
|
||||
@ -54,88 +68,58 @@ pub const LlamaLM = struct {
|
||||
|
||||
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
||||
// It currently crashes/compilation fails
|
||||
if (options.gen_opts.topk == 1) {
|
||||
if (self.lm_head) |lm_head| {
|
||||
self.lm_head.?.weight = lm_head.weight.withSharding(.{0});
|
||||
}
|
||||
if (self.gen_opts.topk == 1 and self.lm_head != null) {
|
||||
self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
|
||||
}
|
||||
}
|
||||
|
||||
/// Predicts the token at `token_index` position.
|
||||
/// Returns:
|
||||
/// - updated `tokens`,
|
||||
/// - `token_idx` + 1,
|
||||
/// - updated KV cache
|
||||
/// - a Rng state to allow for probabilistic generation
|
||||
pub fn forward(
|
||||
self: LlamaLM,
|
||||
tokens_: Tensor,
|
||||
token_index: Tensor,
|
||||
kv_cache: ?KvCache,
|
||||
kv_cache: KvCache,
|
||||
rng: Tensor.Rng,
|
||||
) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
|
||||
stdx.debug.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
||||
) struct { Tensor, KvCache, Tensor.Rng } {
|
||||
stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
||||
|
||||
var tokens = tokens_.withPartialTags(.{.s});
|
||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
|
||||
tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts);
|
||||
return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
|
||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
||||
tokens, const new_rng = self.sampleTokens(self.lm_head, tokens, out, rng, self.gen_opts);
|
||||
return .{ tokens, updated_kv_cache, new_rng };
|
||||
}
|
||||
|
||||
pub fn updateTokens(
|
||||
pub fn sampleTokens(
|
||||
self: LlamaLM,
|
||||
lm_head_: ?zml.nn.Linear,
|
||||
tokens_: Tensor,
|
||||
token_index: Tensor,
|
||||
out_: Tensor,
|
||||
rng: Tensor.Rng,
|
||||
opts: zml.nn.SamplingStrategy,
|
||||
) struct { Tensor, Tensor.Rng } {
|
||||
const tokens = tokens_.withPartialTags(.{.s});
|
||||
const out = out_.withPartialTags(.{ .s, .d });
|
||||
|
||||
const next_token_pred = out.gatherValues(.s, token_index, .{});
|
||||
var logits = if (self.lm_head) |lm_head|
|
||||
zml.call(lm_head, .forward, .{next_token_pred})
|
||||
else
|
||||
self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d});
|
||||
var logits = blk: {
|
||||
if (lm_head_) |lm_head| {
|
||||
break :blk zml.call(lm_head, .forward, .{out});
|
||||
} else {
|
||||
break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
|
||||
}
|
||||
};
|
||||
|
||||
if (logits.shape().hasTag(.voc) == null)
|
||||
logits = logits.rename(.{ .d = .voc });
|
||||
|
||||
const next_token, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
||||
const next_token_index = token_index.addConstant(1);
|
||||
const new_tokens = tokens.dynamicUpdateSlice(.{ .s = next_token_index }, next_token);
|
||||
|
||||
return .{ new_tokens.reuseBuffer(tokens_), new_rng };
|
||||
const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
||||
return .{ next_tokens.reuseBuffer(tokens_), new_rng };
|
||||
}
|
||||
|
||||
pub fn increment(_: u8, token_index: Tensor) Tensor {
|
||||
return token_index.addConstant(1);
|
||||
}
|
||||
|
||||
/// Run the generation entirely within pjrt.
|
||||
pub fn generate(self: LlamaLM, tokens: Tensor, token_index: Tensor, rng: Tensor.Rng) Tensor {
|
||||
// Generate the first token using the prompt and generate the KV-cache initial values.
|
||||
const prefill = zml.call(self, .forward, .{ tokens, token_index, null, rng });
|
||||
|
||||
const Gen = struct {
|
||||
/// Same as LlamaLM.forward but without optional in the signature
|
||||
pub fn forward(lm: LlamaLM, t_ids: Tensor, t_idx: Tensor, kv_cache_: KvCache, inner_rng: Tensor.Rng) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
|
||||
var kv_cache = kv_cache_;
|
||||
kv_cache.k = kv_cache.k.withPartialTags(.{ .layer, .h, .k, .hd });
|
||||
kv_cache.v = kv_cache.v.withPartialTags(.{ .layer, .h, .k, .hd });
|
||||
return zml.call(lm, .forward, .{ t_ids._ctx, t_ids, t_idx, kv_cache, inner_rng });
|
||||
}
|
||||
// / Stops when we generated `max_seq_len` tokens.
|
||||
pub fn shouldContinue(lm: LlamaLM, t_ids: Tensor, t_idx: Tensor, kv_cache: KvCache, inner_rng: Tensor.Rng) Tensor {
|
||||
_ = kv_cache;
|
||||
_ = inner_rng;
|
||||
std.debug.assert(t_ids.dim(1) == lm.model.max_seq_len);
|
||||
return t_idx.cmp(.LT, Tensor.scalar(t_ids._ctx, lm.model.max_seq_len, t_idx.dtype()));
|
||||
}
|
||||
};
|
||||
// Generate remaining tokens using the KV-cache, return tokens.
|
||||
return zml.ops.while_(Gen.shouldContinue, Gen.forward, self, prefill)[0];
|
||||
return token_index.addConstant(1).reuseBuffer(token_index);
|
||||
}
|
||||
};
|
||||
|
||||
@ -177,33 +161,28 @@ pub const Llama = struct {
|
||||
|
||||
/// Forward one token, using KV cache for previous tokens.
|
||||
/// Returns result and updated KV cache.
|
||||
pub fn forward(self: Llama, tokens: Tensor, token_index: ?Tensor, kv_cache: ?KvCache) struct { Tensor, KvCache } {
|
||||
const embeds = embed(self.embed_tokens, tokens, token_index);
|
||||
pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
|
||||
const embeds = embed(self.embed_tokens, tokens);
|
||||
|
||||
var hidden = embeds;
|
||||
const kv_cache0 = kv_cache orelse self.initKvCache(embeds.shape());
|
||||
|
||||
var updated_kv_cache = kv_cache0;
|
||||
var updated_kv_cache = kv_cache;
|
||||
for (self.layers, 0..) |layer, i| {
|
||||
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
||||
}
|
||||
const output = zml.call(self.norm, .forward, .{hidden});
|
||||
|
||||
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
|
||||
return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
|
||||
}
|
||||
|
||||
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor, token_index: ?Tensor) Tensor {
|
||||
const tokens = if (token_index) |idx|
|
||||
tokens_.dynamicSlice1d(-1, .{ .start = idx, .len = 1 })
|
||||
else
|
||||
tokens_;
|
||||
return zml.call(embed_tokens_, .forward, .{tokens}).withPartialTags(.{ .s, .d });
|
||||
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
|
||||
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
|
||||
}
|
||||
|
||||
fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
|
||||
const dims = self.shape();
|
||||
var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
|
||||
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
|
||||
const perm = kv_shape.contiguousPerm(.{ .k, .h, .hd });
|
||||
kv_shape = kv_shape.transpose(perm.constSlice());
|
||||
return KvCache.init(kv_shape);
|
||||
}
|
||||
@ -218,8 +197,8 @@ pub const TransformerLayer = struct {
|
||||
pub fn forward(
|
||||
self: TransformerLayer,
|
||||
x0: Tensor,
|
||||
token_index: ?Tensor,
|
||||
kv_cache: ?KvCache,
|
||||
token_index: Tensor,
|
||||
kv_cache: KvCache,
|
||||
) struct { Tensor, KvCache } {
|
||||
// Self Attention
|
||||
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
|
||||
@ -287,39 +266,41 @@ pub const SelfAttn = struct {
|
||||
pub fn forward(
|
||||
self: SelfAttn,
|
||||
x: Tensor,
|
||||
token_index: ?Tensor,
|
||||
kv_cache_: ?KvCache,
|
||||
token_index: Tensor,
|
||||
kv_cache: KvCache,
|
||||
) struct { Tensor, KvCache } {
|
||||
// log.debug("x.shape: {}", .{x.shape()});
|
||||
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
|
||||
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
|
||||
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||
|
||||
// Generate the attention mask.
|
||||
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
||||
const seq_len = kv_cache.k.dim(.k);
|
||||
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
|
||||
if (token_index) |idx| {
|
||||
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
||||
// then slice into it, but XLA is able to optimize this correctly.
|
||||
attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
|
||||
}
|
||||
|
||||
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
||||
// then slice into it, but XLA is able to optimize this correctly.
|
||||
attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .b = token_index.shape().dim(0), .coord = 1 }), .{});
|
||||
|
||||
// In self-attention, .s axis is used both for keys and queries.
|
||||
q = zml.nn.rope(q, token_index, self.rope_opts);
|
||||
k = zml.nn.rope(k, token_index, self.rope_opts);
|
||||
const pos_index = b: {
|
||||
const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .b = token_index.shape().dim(0), .s = x.dim(.s) }, token_index.dtype()));
|
||||
break :b temp.add(token_index.withTags(.{.b}).broad(temp.shape()));
|
||||
};
|
||||
|
||||
q = zml.nn.rope(q, pos_index, self.rope_opts);
|
||||
k = zml.nn.rope(k, pos_index, self.rope_opts);
|
||||
q = q.rename(.{ .s = .q });
|
||||
k = k.rename(.{ .s = .k });
|
||||
v = v.rename(.{ .s = .k });
|
||||
|
||||
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
|
||||
if (token_index) |_| {
|
||||
stdx.debug.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
|
||||
k = new_kv_cache.keys();
|
||||
v = new_kv_cache.values();
|
||||
}
|
||||
const dtype = q.dtype();
|
||||
const new_kv_cache = kv_cache.update(k, v, token_index);
|
||||
k = new_kv_cache.keys().convert(dtype);
|
||||
v = new_kv_cache.values().convert(dtype);
|
||||
|
||||
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = false });
|
||||
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = true });
|
||||
// const attn_output = zml.nn.sdpaMemEfficient(q, k, v, .{ .attn_mask = attn_mask }, .{ .q_chunk_size = 4096, .k_chunk_size = 1024 });
|
||||
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
||||
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
|
||||
}
|
||||
@ -330,7 +311,7 @@ pub const SelfAttn = struct {
|
||||
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
|
||||
kv_shape = kv_shape.transpose(perm.constSlice());
|
||||
var res = KvCache.init(kv_shape);
|
||||
res.layer_index = Tensor.scalar(0, .i32);
|
||||
res.layer_index = Tensor.scalar(0, .u32);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
@ -345,7 +326,7 @@ pub const KvCache = struct {
|
||||
return .{
|
||||
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||
.layer_index = Tensor.scalar(-1, .i32),
|
||||
.layer_index = Tensor.scalar(-1, .u32),
|
||||
};
|
||||
}
|
||||
|
||||
@ -353,7 +334,15 @@ pub const KvCache = struct {
|
||||
return .{
|
||||
.k = kv_shape,
|
||||
.v = kv_shape,
|
||||
.layer_index = zml.Shape.init(.{}, .i32),
|
||||
.layer_index = zml.Shape.init(.{}, .u32),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
|
||||
return .{
|
||||
.k = try zml.Buffer.constant(platform, kv_shape, 1),
|
||||
.v = try zml.Buffer.constant(platform, kv_shape, 1),
|
||||
.layer_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0),
|
||||
};
|
||||
}
|
||||
|
||||
@ -365,17 +354,33 @@ pub const KvCache = struct {
|
||||
return self.v.dynamicSlice(.{ .layer = .{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
|
||||
}
|
||||
|
||||
pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: Tensor) KvCache {
|
||||
return .{
|
||||
.k = self.k.dynamicUpdateSlice(
|
||||
.{ .layer = self.layer_index, .k = token_index },
|
||||
// transpose to match kv-cache layout
|
||||
new_k.contiguous(.{ .h, .k, .hd }),
|
||||
pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache {
|
||||
const k_shape = self.k.shape().drop(.layer);
|
||||
var layer = self.layer_index;
|
||||
layer = if (token_index) |idx| layer.broad(idx.shape()) else layer;
|
||||
|
||||
return if (token_index) |idx| .{
|
||||
.k = self.k.scatterSlices(
|
||||
.{ .layer = layer, .k = idx },
|
||||
new_k.convert(self.k.dtype()).transpose(k_shape),
|
||||
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
).reuseBuffer(self.k),
|
||||
.v = self.v.dynamicUpdateSlice(
|
||||
.{ .layer = self.layer_index, .k = token_index },
|
||||
// transpose to match kv-cache layout
|
||||
new_v.contiguous(.{ .h, .k, .hd }),
|
||||
.v = self.v.scatterSlices(
|
||||
.{ .layer = layer, .k = idx },
|
||||
new_v.convert(self.v.dtype()).transpose(k_shape),
|
||||
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
).reuseBuffer(self.v),
|
||||
.layer_index = self.layer_index,
|
||||
} else .{
|
||||
.k = self.k.scatterSlices(
|
||||
.{ .layer = layer },
|
||||
new_k.convert(self.k.dtype()).transpose(k_shape),
|
||||
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
).reuseBuffer(self.k),
|
||||
.v = self.v.scatterSlices(
|
||||
.{ .layer = layer },
|
||||
new_v.convert(self.v.dtype()).transpose(k_shape),
|
||||
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
).reuseBuffer(self.v),
|
||||
.layer_index = self.layer_index,
|
||||
};
|
||||
@ -385,7 +390,7 @@ pub const KvCache = struct {
|
||||
return .{
|
||||
.k = self.k,
|
||||
.v = self.v,
|
||||
.layer_index = Tensor.scalar(layer_index, .i32),
|
||||
.layer_index = Tensor.scalar(layer_index, .u32),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -1,248 +1,330 @@
|
||||
const asynk = @import("async");
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const clap = @import("clap");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const llama_mod = @import("llama.zig");
|
||||
const llama = @import("llama.zig");
|
||||
|
||||
const LlamaLM = llama_mod.LlamaLM;
|
||||
const Llama = llama_mod.Llama;
|
||||
const KvCache = llama_mod.KvCache;
|
||||
const TransformerLayer = llama_mod.TransformerLayer;
|
||||
const SelfAttn = llama_mod.SelfAttn;
|
||||
const LlamaLM = llama.LlamaLM;
|
||||
const Llama = llama.Llama;
|
||||
const KvCache = llama.KvCache;
|
||||
const TransformerLayer = llama.TransformerLayer;
|
||||
const SelfAttn = llama.SelfAttn;
|
||||
const Buffer = zml.Buffer;
|
||||
const Tensor = zml.Tensor;
|
||||
const ShapeOf = zml.ShapeOf;
|
||||
|
||||
const log = std.log.scoped(.llama);
|
||||
|
||||
const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 };
|
||||
|
||||
// set this to false to disable the verbose logging
|
||||
const show_mlir = true;
|
||||
|
||||
pub const std_options = .{
|
||||
.log_level = .warn,
|
||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
||||
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
|
||||
.{ .scope = .llama, .level = .info },
|
||||
},
|
||||
.logFn = asynk.logFn,
|
||||
.log_level = .info,
|
||||
};
|
||||
|
||||
pub fn tokenizePromptLlama3(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8) ![]u32 {
|
||||
var tokens = std.ArrayList(u32).init(allocator);
|
||||
var encoder = try tokenizer.encoder();
|
||||
defer encoder.deinit();
|
||||
|
||||
const start_header_id = tokenizer.token_to_id("<|start_header_id|>") orelse return error.NoSuchToken;
|
||||
const end_header_id = tokenizer.token_to_id("<|end_header_id|>") orelse return error.NoSuchToken;
|
||||
const eot_id = tokenizer.token_to_id("<|eot_id|>") orelse return error.NoSuchToken;
|
||||
const newline_id = (try encoder.encode("\n"))[0];
|
||||
|
||||
try tokens.append(config.bos_token_id);
|
||||
|
||||
try tokens.append(start_header_id);
|
||||
try tokens.appendSlice(try encoder.encode("user"));
|
||||
try tokens.appendSlice(&.{ end_header_id, newline_id, newline_id });
|
||||
|
||||
try tokens.appendSlice(try encoder.encode(prompt));
|
||||
try tokens.appendSlice(&.{ eot_id, newline_id });
|
||||
try tokens.appendSlice(try encoder.encode("\n"));
|
||||
try tokens.append(start_header_id);
|
||||
try tokens.appendSlice(try encoder.encode("assistant"));
|
||||
try tokens.append(end_header_id);
|
||||
|
||||
return tokens.toOwnedSlice();
|
||||
}
|
||||
|
||||
pub fn generateText(
|
||||
llama: LlamaLM,
|
||||
config: LlamaLM.Config,
|
||||
llama_: LlamaLM,
|
||||
mod_prefill: zml.ModuleExe(LlamaLM.forward),
|
||||
mod: zml.ModuleExe(LlamaLM.forward),
|
||||
mod_generate: zml.ModuleExe(LlamaLM.forward),
|
||||
kv_cache_: zml.Bufferized(llama.KvCache),
|
||||
tokenizer: zml.tokenizer.Tokenizer,
|
||||
allocator: std.mem.Allocator,
|
||||
seed: u128,
|
||||
prompt: []const u8,
|
||||
skip_llama3_encoding: bool,
|
||||
) ![]const u8 {
|
||||
const prompt_tok = tokenizer.encode(allocator, prompt, .{}) catch unreachable;
|
||||
log.debug("Tokenized Prompt {d}", .{prompt_tok});
|
||||
const dims = llama.model.shape();
|
||||
const max_seq_len = dims.s;
|
||||
const token_buffer = try allocator.alloc(i32, @intCast(max_seq_len));
|
||||
@memset(token_buffer, 0);
|
||||
for (0..prompt_tok.len) |i| {
|
||||
token_buffer[i] = @intCast(prompt_tok[i]);
|
||||
}
|
||||
var tokenizer_encoder = try tokenizer.encoder();
|
||||
defer tokenizer_encoder.deinit();
|
||||
var tokenizer_decoder = try tokenizer.decoder();
|
||||
defer tokenizer_decoder.deinit();
|
||||
|
||||
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
|
||||
defer allocator.free(token_buffer);
|
||||
defer allocator.free(tracer_buffer);
|
||||
const prompt_tok: []const u32 = if (skip_llama3_encoding) try tokenizer_encoder.encode(prompt) else try tokenizePromptLlama3(allocator, tokenizer, config, prompt);
|
||||
defer allocator.free(prompt_tok);
|
||||
var output = std.ArrayList(u8).init(allocator);
|
||||
defer output.deinit();
|
||||
|
||||
var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer);
|
||||
var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
|
||||
const dims = llama_.model.shape();
|
||||
const max_seq_len = dims.s;
|
||||
|
||||
// Prefill
|
||||
// initialize a 0..max_seq_len buffer with the tokenized prompt
|
||||
const prefill_buffer = try allocator.alloc(u32, @intCast(max_seq_len));
|
||||
@memset(prefill_buffer, 0);
|
||||
for (0..prompt_tok.len) |i| {
|
||||
prefill_buffer[i] = @intCast(prompt_tok[i]);
|
||||
}
|
||||
defer allocator.free(prefill_buffer);
|
||||
|
||||
const platform = mod_generate.platform();
|
||||
|
||||
// prepare device buffers for the prefill tokens and the index
|
||||
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
||||
defer prefill_tokens.deinit();
|
||||
var prefill_token_index = try zml.Buffer.fromSlice(platform, .{}, &[_]u32{0});
|
||||
defer prefill_token_index.deinit();
|
||||
|
||||
var rng = try zml.Tensor.Rng.init(mod.platform(), seed);
|
||||
tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng });
|
||||
defer token_index.deinit();
|
||||
// init RNG and prefill
|
||||
var rng = try zml.Tensor.Rng.init(platform, seed);
|
||||
prefill_tokens, var kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_index, kv_cache_, rng });
|
||||
defer kv_cache.k.deinit();
|
||||
defer kv_cache.v.deinit();
|
||||
defer kv_cache.layer_index.deinit();
|
||||
|
||||
// Prepare for token-by-token generation
|
||||
var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token
|
||||
var current_token = try zml.Buffer.fromSlice(platform, .{}, &first_token_hostbuffer);
|
||||
defer current_token.deinit();
|
||||
|
||||
// Here we will copy the generated token from device
|
||||
var generated_token_buffer = [_]u32{0};
|
||||
|
||||
// Here we collect the generated text
|
||||
var output = std.ArrayList(u8).init(allocator);
|
||||
defer output.deinit();
|
||||
|
||||
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
|
||||
defer allocator.free(tracer_buffer);
|
||||
const tracer = zml.tools.Tracer.init("ai.zml.models.llama");
|
||||
var decode_progress = prompt_tok.len;
|
||||
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
||||
|
||||
const start = std.time.microTimestamp();
|
||||
const output_freq: u8 = 1;
|
||||
var eos_index: ?usize = null;
|
||||
for (0..output_tokens_len) |i| {
|
||||
//_ = i;
|
||||
|
||||
var num_tokens_generated: usize = 0;
|
||||
|
||||
generation: for (0..output_tokens_len) |i| {
|
||||
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
||||
tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
|
||||
token_index.deinit();
|
||||
token_index = new_token_index;
|
||||
if ((i + 1) % output_freq == 0) {
|
||||
const n = output.items.len;
|
||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..][0..output_freq]), .{});
|
||||
decode_progress += output_freq;
|
||||
std.debug.print("{s}", .{output.items[n..]});
|
||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
|
||||
if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| {
|
||||
// Handle strange scenarios when eos id isn't the very next token after decode_progress
|
||||
eos_index = decode_progress - output_freq + index;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
||||
|
||||
// current token index needs to go into a zml.Buffer
|
||||
const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
|
||||
const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer);
|
||||
defer token_index.deinit();
|
||||
|
||||
// call to generate the next token
|
||||
current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_index, kv_cache, rng });
|
||||
|
||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
||||
|
||||
// extract the generated token from the buffer
|
||||
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
|
||||
const generated_token = generated_token_buffer[0];
|
||||
// de-tokenize generated token into a string
|
||||
const chunk = try tokenizer_decoder.next(@intCast(generated_token)) orelse unreachable;
|
||||
num_tokens_generated = i;
|
||||
|
||||
// check for eos
|
||||
switch (config.eos_token_id.value) {
|
||||
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
|
||||
.ints => |eos_list| {
|
||||
for (eos_list) |eos| {
|
||||
if (generated_token == @as(u32, @intCast(eos))) break :generation;
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// collect and print generated sequence
|
||||
try output.appendSlice(chunk);
|
||||
std.debug.print("{s}", .{chunk});
|
||||
}
|
||||
var total_token_count: usize = max_seq_len;
|
||||
const n = output.items.len;
|
||||
if (eos_index) |end_idx| {
|
||||
// count = eos index + 1
|
||||
total_token_count = end_idx + 1;
|
||||
}
|
||||
const generated_token_count = total_token_count - prompt_tok.len;
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{});
|
||||
std.debug.print("{s}\n", .{output.items[n..]});
|
||||
const end = std.time.microTimestamp();
|
||||
|
||||
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
||||
const speed = @as(f64, @floatFromInt(generated_token_count)) / duration;
|
||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed });
|
||||
|
||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
||||
output.clearRetainingCapacity();
|
||||
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{});
|
||||
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
|
||||
std.debug.print("\n", .{});
|
||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
|
||||
return output.toOwnedSlice();
|
||||
}
|
||||
|
||||
const params = clap.parseParamsComptime(
|
||||
\\--help print this help
|
||||
\\--prompt <STRING> the prompt
|
||||
\\--config <PATH> config.json path
|
||||
\\--weights <PATH> model weights path
|
||||
\\--tokenizer <PATH> tokenizer path
|
||||
\\--seed <UINT> random seed (optional)
|
||||
\\--seq-len <UINT> sequence length
|
||||
\\--create-options <STRING> platform creation options JSON, defaults to {}
|
||||
\\--no-llama3 <BOOL> skip prompt template
|
||||
\\--sharding <BOOL> default: true: sharding on or off
|
||||
);
|
||||
|
||||
pub fn bool_parser(in: []const u8) error{}!bool {
|
||||
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
const CliArgs = struct {
|
||||
pub const help =
|
||||
\\ llama --model=llama3.7B.safetensors --tokenizer=vocab.json --num_layers=2
|
||||
;
|
||||
model: []const u8,
|
||||
tokenizer: ?[]const u8 = null,
|
||||
layer_start: u8 = 0,
|
||||
num_layers: ?u8 = null,
|
||||
seq_len: u32 = 256,
|
||||
topk: u32 = 2,
|
||||
temperature: u32 = 1,
|
||||
num_heads: ?i64 = null,
|
||||
num_kv_heads: ?i64 = null,
|
||||
rope_freq_base: ?i64 = null,
|
||||
prompt: ?[]const u8 = null,
|
||||
test_activations: ?[]const u8 = null,
|
||||
seed: ?u128 = null,
|
||||
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
||||
create_options: []const u8 = "{}",
|
||||
};
|
||||
|
||||
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
|
||||
|
||||
const allocator = std.heap.c_allocator;
|
||||
|
||||
const tmp = try std.fs.openDirAbsolute("/tmp", .{});
|
||||
try tmp.makePath("zml/llama/cache");
|
||||
const parsers = comptime .{
|
||||
.BOOL = bool_parser,
|
||||
.UINT = clap.parsers.int(usize, 0),
|
||||
.STRING = clap.parsers.string,
|
||||
.PATH = clap.parsers.string,
|
||||
};
|
||||
var diag: clap.Diagnostic = .{};
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
var res = clap.parse(clap.Help, ¶ms, parsers, .{
|
||||
.diagnostic = &diag,
|
||||
.allocator = allocator,
|
||||
}) catch |err| {
|
||||
diag.report(stderr, err) catch {};
|
||||
stderr.print("usage: ", .{}) catch {};
|
||||
clap.usage(stderr, clap.Help, ¶ms) catch {};
|
||||
stderr.print("\n", .{}) catch {};
|
||||
return;
|
||||
};
|
||||
defer res.deinit();
|
||||
|
||||
if (res.args.help != 0) {
|
||||
clap.help(std.io.getStdErr().writer(), clap.Help, ¶ms, .{}) catch {};
|
||||
return;
|
||||
}
|
||||
|
||||
const config = blk: {
|
||||
if (res.args.config) |config_json_path| {
|
||||
var config_json_file = try asynk.File.open(config_json_path, .{ .mode = .read_only });
|
||||
defer config_json_file.close() catch unreachable;
|
||||
var reader = std.json.reader(allocator, config_json_file.reader());
|
||||
defer reader.deinit();
|
||||
const config_obj = try std.json.parseFromTokenSourceLeaky(llama.LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
|
||||
break :blk config_obj;
|
||||
} else {
|
||||
log.err("Missing --config", .{});
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
const compilation_options = zml.CompilationOptions{
|
||||
.xla_dump_to = "/tmp/zml/llama",
|
||||
.sharding_enabled = true,
|
||||
.sharding_enabled = res.args.sharding orelse true,
|
||||
};
|
||||
|
||||
var args = std.process.args();
|
||||
const cli_args = flags.parse(&args, CliArgs);
|
||||
const model_file = cli_args.model;
|
||||
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena_state.deinit();
|
||||
const model_arena = arena_state.allocator();
|
||||
|
||||
const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{});
|
||||
const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
|
||||
// initialize ZML platform with optional create options
|
||||
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
||||
const create_opts_json = res.args.@"create-options" orelse "{}";
|
||||
const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
|
||||
const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options);
|
||||
create_opts.deinit();
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
log.info("Model file: {s}", .{model_file});
|
||||
|
||||
var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);
|
||||
var ts = try zml.aio.detectFormatAndOpen(allocator, res.args.weights.?);
|
||||
defer ts.deinit();
|
||||
|
||||
var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts);
|
||||
const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
|
||||
const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads;
|
||||
var model_arena = std.heap.ArenaAllocator.init(allocator);
|
||||
var model_instance = try zml.aio.populateModel(llama.LlamaLM, model_arena.allocator(), ts);
|
||||
|
||||
const rope_impl = if (ts.metadata("rope_impl", .string)) |val|
|
||||
std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
|
||||
else
|
||||
.sequential;
|
||||
|
||||
const llama_options: llama_mod.LlamaOptions = .{
|
||||
.max_seq_len = cli_args.seq_len,
|
||||
.num_kv_heads = num_kv_heads,
|
||||
.num_heads = num_heads,
|
||||
.gen_opts = .{
|
||||
.topk = cli_args.topk,
|
||||
.temperature = @floatFromInt(cli_args.temperature),
|
||||
},
|
||||
.rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5),
|
||||
.rope_opts = .{
|
||||
.impl = rope_impl,
|
||||
.freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
|
||||
const llama_options: llama.LlamaLM.Options = .{
|
||||
.max_seq_len = @intCast(res.args.@"seq-len" orelse 256),
|
||||
.sampling_strategy = .{
|
||||
.topk = 1,
|
||||
.temperature = 1.0,
|
||||
},
|
||||
};
|
||||
log.info("✅\tParsed llama config: {}", .{llama_options});
|
||||
llama.init(llama_options);
|
||||
model_instance.init(config, llama_options);
|
||||
|
||||
if (cli_args.tokenizer == null and !std.mem.endsWith(u8, cli_args.model, ".gguf")) {
|
||||
log.err("Model doesn't have an embbedded tokenizer, please provide a path to a tokenizer.", .{});
|
||||
@panic("No tokenizer provided");
|
||||
}
|
||||
const tokenizer_path = cli_args.tokenizer orelse cli_args.model;
|
||||
log.info("\tLoading tokenizer from {s}", .{tokenizer_path});
|
||||
var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, tokenizer_path);
|
||||
log.info("✅\tLoaded tokenizer from {s}", .{tokenizer_path});
|
||||
defer tokenizer.deinit();
|
||||
const dims = model_instance.model.shape();
|
||||
const dtype = model_instance.model.embed_tokens.weight.dtype();
|
||||
|
||||
const dims = llama.model.shape();
|
||||
const dtype = llama.model.embed_tokens.weight.dtype();
|
||||
const batch_size = 1;
|
||||
|
||||
// Note: we compile the model without a batching dimension.
|
||||
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
||||
const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
|
||||
const token_idx_shape = zml.Shape.init(.{}, .i32);
|
||||
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
|
||||
// needs to be optional
|
||||
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
|
||||
const rng_shape = Tensor.Rng.shape();
|
||||
const tokens_shape_prefill = zml.Shape.init(.{ .b = batch_size, .s = llama_options.max_seq_len }, .u32);
|
||||
const tokens_shape = zml.Shape.init(.{ .b = batch_size, .s = 1 }, .u32);
|
||||
const token_idx_shape = zml.Shape.init(.{ .b = batch_size }, .u32);
|
||||
|
||||
const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .b = batch_size, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
|
||||
|
||||
const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
|
||||
const rng_shape = zml.Tensor.Rng.shape();
|
||||
|
||||
var start = try std.time.Timer.start();
|
||||
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
|
||||
var fut_mod = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
|
||||
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{
|
||||
allocator, llama.LlamaLM.forward, .{ config, llama_options },
|
||||
.{
|
||||
tokens_shape_prefill,
|
||||
token_idx_shape,
|
||||
kv_cache_shape,
|
||||
rng_shape,
|
||||
},
|
||||
ts,
|
||||
platform,
|
||||
});
|
||||
|
||||
log.info("\tLoading Llama weights from {s}...", .{cli_args.model});
|
||||
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, ts, model_arena, platform);
|
||||
var fut_mod = try asynk.asyncc(zml.compile, .{
|
||||
allocator, llama.LlamaLM.forward, .{ config, llama_options },
|
||||
.{
|
||||
tokens_shape,
|
||||
token_idx_shape,
|
||||
kv_cache_shape,
|
||||
rng_shape,
|
||||
},
|
||||
ts,
|
||||
platform,
|
||||
});
|
||||
|
||||
log.info("\tLoading Llama weights from {?s}...", .{res.args.weights});
|
||||
var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
|
||||
defer zml.aio.unloadBuffers(&llama_weights);
|
||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
log.info("✅\tLoaded weights in {}", .{std.fmt.fmtDuration(start.read())});
|
||||
|
||||
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
||||
defer llama_module_prefill.deinit();
|
||||
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
||||
defer llama_module.deinit();
|
||||
log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
log.info("✅\tCompiled model in {}", .{std.fmt.fmtDuration(start.read())});
|
||||
|
||||
const prompt = cli_args.prompt orelse "Once upon a time, there was a little girl named Lily.";
|
||||
log.info("Creating KvCache", .{});
|
||||
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
|
||||
|
||||
var tokenizer = blk: {
|
||||
if (res.args.tokenizer) |tok| {
|
||||
log.info("Loading tokenizer from {s}", .{tok});
|
||||
var timer = try stdx.time.Timer.start();
|
||||
defer log.info("Loaded tokenizer from {s} [{}]", .{ tok, timer.read() });
|
||||
|
||||
break :blk try zml.tokenizer.Tokenizer.from_file(model_arena.allocator(), tok);
|
||||
} else {
|
||||
log.err("Missing --tokenizer", .{});
|
||||
return;
|
||||
}
|
||||
};
|
||||
errdefer tokenizer.deinit();
|
||||
|
||||
const prompt = res.args.prompt orelse "What is the capital of France?";
|
||||
log.info("✅\tPrompt: {s}", .{prompt});
|
||||
|
||||
const seed = cli_args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
||||
const story = try generateText(llama, llama_module_prefill, llama_module, tokenizer, allocator, seed, prompt);
|
||||
defer allocator.free(story);
|
||||
const seed = res.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
||||
const skip_llama3_encoding = res.args.@"no-llama3" orelse false;
|
||||
const generated_text = try generateText(config, model_instance, llama_module_prefill, llama_module, kv_cache, tokenizer, allocator, seed, prompt[0..], skip_llama3_encoding);
|
||||
// generated text will be printed token by token.
|
||||
defer allocator.free(generated_text);
|
||||
}
|
||||
|
||||
@ -16,9 +16,10 @@ pub fn main() !void {
|
||||
pub fn asyncMain() !void {
|
||||
const CliArgs = struct {
|
||||
pub const help =
|
||||
\\ test-implementation --model=llama3.8B.safetensors --reference=activation.safetensors
|
||||
\\ test-implementation --weights=llama3.8B.safetensors --config=config.json --reference=activation.safetensors
|
||||
;
|
||||
model: []const u8,
|
||||
weights: []const u8,
|
||||
config: []const u8,
|
||||
reference: []const u8,
|
||||
num_heads: ?i64 = null,
|
||||
num_kv_heads: ?i64 = null,
|
||||
@ -38,7 +39,7 @@ pub fn asyncMain() !void {
|
||||
// Parse program args
|
||||
var args = std.process.args();
|
||||
const cli_args = flags.parse(&args, CliArgs);
|
||||
const model_file = cli_args.model;
|
||||
const model_file = cli_args.weights;
|
||||
|
||||
// Memory arena dedicated to model shapes and weights
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
@ -61,6 +62,16 @@ pub fn asyncMain() !void {
|
||||
else
|
||||
.sequential;
|
||||
|
||||
const config = blk: {
|
||||
var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only });
|
||||
defer config_json_file.close() catch unreachable;
|
||||
var reader = std.json.reader(allocator, config_json_file.reader());
|
||||
defer reader.deinit();
|
||||
const config_obj = try std.json.parseFromTokenSourceLeaky(LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
|
||||
break :blk config_obj;
|
||||
};
|
||||
std.log.info("Parsed llama config: {}", .{config});
|
||||
|
||||
const llama_options: llama_mod.LlamaOptions = .{
|
||||
.max_seq_len = 256,
|
||||
.num_kv_heads = num_kv_heads,
|
||||
@ -101,26 +112,4 @@ fn testImplementation(
|
||||
try zml.testing.testLayer(platform, buffer_store, "layers.0.mlp", llama.model.layers[0].mlp, llama_weights.model.layers[0].mlp, 1e-2);
|
||||
try zml.testing.testLayer(platform, buffer_store, "layers.0.input_layernorm", llama.model.layers[0].input_layernorm, llama_weights.model.layers[0].input_layernorm, 1e-2);
|
||||
try zml.testing.testLayer(platform, buffer_store, "layers.0.post_attention_layernorm", llama.model.layers[0].post_attention_layernorm, llama_weights.model.layers[0].post_attention_layernorm, 1e-2);
|
||||
|
||||
{
|
||||
const test_case = "layers.0.self_attn";
|
||||
std.log.info("Testing {s}", .{test_case});
|
||||
// Small wrapper to explicitly tag the input, and ignore the extra arguments used in HF implementation.
|
||||
const SelfAttnPrefill = struct {
|
||||
inner: llama_mod.SelfAttn,
|
||||
|
||||
pub fn forward(self: @This(), x_: Tensor) struct { Tensor, llama_mod.KvCache } {
|
||||
return self.inner.forward(x_.withTags(.{ .b, .s, .d }), null, null);
|
||||
}
|
||||
};
|
||||
|
||||
try zml.testing.testLayer(
|
||||
platform,
|
||||
buffer_store,
|
||||
"layers.0.self_attn",
|
||||
SelfAttnPrefill{ .inner = llama.model.layers[0].self_attn },
|
||||
.{ .inner = llama_weights.model.layers[0].self_attn },
|
||||
1e-3,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
9
examples/third_party/com_github_hejsil_clap/clap.bazel
vendored
Normal file
9
examples/third_party/com_github_hejsil_clap/clap.bazel
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
zig_library(
|
||||
name = "clap",
|
||||
import_name = "clap",
|
||||
srcs = glob(["clap/*.zig"]),
|
||||
main = "clap.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
20
examples/third_party/non_module_deps.bzl
vendored
Normal file
20
examples/third_party/non_module_deps.bzl
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
|
||||
|
||||
def _non_module_deps_impl(mctx):
|
||||
|
||||
new_git_repository(
|
||||
name = "com_github_hejsil_clap",
|
||||
remote = "https://github.com/Hejsil/zig-clap.git",
|
||||
commit = "d71cc39a94f3e6ccbad00c25d350c9147de4df9f",
|
||||
build_file = "//:third_party/com_github_hejsil_clap/clap.bazel",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
non_module_deps = module_extension(
|
||||
implementation = _non_module_deps_impl,
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user