Update Llama example docs and Bazel build files, and add tests for the new HuggingFace tokenizer integration.
This commit is contained in:
parent
959bc48c42
commit
76e314db9b
@ -59,47 +59,41 @@ Llama is a family of "Large Language Models", trained to generate text, based
|
|||||||
on the beginning of a sentence/book/article. This "beginning" is generally
|
on the beginning of a sentence/book/article. This "beginning" is generally
|
||||||
referred to as the "prompt".
|
referred to as the "prompt".
|
||||||
|
|
||||||
#### TinyLlama, Stories 15M
|
#### Meta Llama 3.1 8B
|
||||||
|
|
||||||
To start, you can use a small model trained specifically on children's story
|
|
||||||
books. This model has been trained by [Andrej Karpathy](https://x.com/karpathy);
|
|
||||||
you can read more about it on his
|
|
||||||
[Github](https://github.com/karpathy/llama2.c).
|
|
||||||
|
|
||||||
```
|
|
||||||
cd examples
|
|
||||||
bazel run -c opt //llama:TinyLlama-Stories-15M
|
|
||||||
bazel run -c opt //llama:TinyLlama-Stories-15M -- --prompt="Once upon a time, there was a cute little dragon"
|
|
||||||
```
|
|
||||||
|
|
||||||
#### OpenLLama 3B
|
|
||||||
|
|
||||||
```
|
|
||||||
cd examples
|
|
||||||
bazel run -c opt //llama:OpenLLaMA-3B
|
|
||||||
bazel run -c opt //llama:OpenLLaMA-3B -- --prompt="Once upon a time,"
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Meta Llama 3 8B
|
|
||||||
|
|
||||||
This model has restrictions, see
|
This model has restrictions, see
|
||||||
[here](https://huggingface.co/meta-llama/Meta-Llama-3-8B): it **requires
|
[here](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct). It **requires
|
||||||
approval from Meta on Huggingface**, which can take a few hours to get granted.
|
approval from Meta on Huggingface**, which can take a few hours to get granted.
|
||||||
|
|
||||||
While waiting for approval, you can already
|
While waiting for approval, you can already
|
||||||
[generate your Huggingface access token](../howtos/huggingface_access_token.md).
|
[generate your Huggingface access token](../howtos/huggingface_access_token.md).
|
||||||
|
|
||||||
Once you've been granted access, you're ready to download a gated model like
|
Once you've been granted access, you're ready to download a gated model like
|
||||||
`Meta-Llama-3-8b`!
|
`Meta-Llama-3.1-8B-Instruct`!
|
||||||
|
|
||||||
```
|
```
|
||||||
# requires token in $HOME/.cache/huggingface/token, as created by the
|
# requires token in $HOME/.cache/huggingface/token, as created by the
|
||||||
# `huggingface-cli login` command, or the `HUGGINGFACE_TOKEN` environment variable.
|
# `huggingface-cli login` command, or the `HUGGINGFACE_TOKEN` environment variable.
|
||||||
cd examples
|
cd examples
|
||||||
bazel run -c opt //llama:Meta-Llama-3-8b
|
bazel run -c opt //llama:Llama-3.1-8B-Instruct
|
||||||
bazel run -c opt //llama:Meta-Llama-3-8b -- --prompt="Once upon a time,"
|
bazel run -c opt //llama:Llama-3.1-8B-Instruct -- --prompt="What is the capital of France?"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can also try `Llama-3.1-70B-Instruct` if you have enough memory.
|
||||||
|
|
||||||
|
#### Meta Llama 3.2 1B
|
||||||
|
|
||||||
|
Like the 8B model above, this model also requires approval. See
|
||||||
|
[here](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) for access requirements.
|
||||||
|
|
||||||
|
```
|
||||||
|
cd examples
|
||||||
|
bazel run -c opt //llama:Llama-3.2-1B-Instruct
|
||||||
|
bazel run -c opt //llama:Llama-3.2-1B-Instruct -- --prompt="What is the capital of France?"
|
||||||
|
```
|
||||||
|
|
||||||
|
For a larger 3.2 model, you can also try `Llama-3.2-3B-Instruct`.
|
||||||
|
|
||||||
|
|
||||||
## Run Tests
|
## Run Tests
|
||||||
|
|
||||||
@ -126,9 +120,9 @@ run the following:
|
|||||||
|
|
||||||
```
|
```
|
||||||
cd examples
|
cd examples
|
||||||
bazel run -c opt //llama:OpenLLaMA-3B \
|
bazel run -c opt //llama:Llama-3.2-1B-Instruct \
|
||||||
--@zml//runtimes:cuda=true \
|
--@zml//runtimes:cuda=true \
|
||||||
-- --prompt="Once upon a time,"
|
-- --prompt="What is the capital of France?"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
examples/BUILD.bazel
Normal file
0
examples/BUILD.bazel
Normal file
@ -7,6 +7,9 @@ bazel_dep(name = "zml", version = "0.1.0")
|
|||||||
bazel_dep(name = "aspect_bazel_lib", version = "2.11.0")
|
bazel_dep(name = "aspect_bazel_lib", version = "2.11.0")
|
||||||
bazel_dep(name = "rules_oci", version = "2.0.0")
|
bazel_dep(name = "rules_oci", version = "2.0.0")
|
||||||
|
|
||||||
|
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
|
||||||
|
use_repo(non_module_deps, "com_github_hejsil_clap")
|
||||||
|
|
||||||
oci = use_extension("@rules_oci//oci:extensions.bzl", "oci")
|
oci = use_extension("@rules_oci//oci:extensions.bzl", "oci")
|
||||||
oci.pull(
|
oci.pull(
|
||||||
name = "distroless_cc_debian12",
|
name = "distroless_cc_debian12",
|
||||||
@ -44,67 +47,24 @@ http_file(
|
|||||||
url = "https://github.com/ggerganov/ggml/raw/18703ad600cc68dbdb04d57434c876989a841d12/examples/mnist/models/mnist/t10k-images.idx3-ubyte",
|
url = "https://github.com/ggerganov/ggml/raw/18703ad600cc68dbdb04d57434c876989a841d12/examples/mnist/models/mnist/t10k-images.idx3-ubyte",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Llama weights
|
|
||||||
huggingface = use_extension("@zml//bazel:huggingface.bzl", "huggingface")
|
|
||||||
huggingface.model(
|
|
||||||
name = "Karpathy-TinyLlama-Stories",
|
|
||||||
build_file_content = """\
|
|
||||||
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
|
||||||
|
|
||||||
# leverage copy_file to rename tokenizer extension
|
|
||||||
# which allow zml.aio.detectFormatAndLoadTokenizer
|
|
||||||
# to leverage the right tokenizer
|
|
||||||
copy_file(
|
|
||||||
name = "stories15M",
|
|
||||||
src = "stories15M.bin",
|
|
||||||
out = "stories15M.tinyllama",
|
|
||||||
allow_symlink = True,
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
|
||||||
copy_file(
|
|
||||||
name = "stories110M",
|
|
||||||
src = "stories110M.bin",
|
|
||||||
out = "stories110M.tinyllama",
|
|
||||||
allow_symlink = True,
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
commit = "0bd21da7698eaf29a0d7de3992de8a46ef624add",
|
|
||||||
includes = [
|
|
||||||
"stories15M.bin",
|
|
||||||
"stories110M.bin",
|
|
||||||
],
|
|
||||||
model = "karpathy/tinyllamas",
|
|
||||||
)
|
|
||||||
use_repo(huggingface, "Karpathy-TinyLlama-Stories")
|
|
||||||
|
|
||||||
http_file(
|
|
||||||
name = "Karpathy-TinyLlama-Tokenizer",
|
|
||||||
downloaded_file_path = "stories260K.tinyllama",
|
|
||||||
sha256 = "50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361",
|
|
||||||
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Llama 3.2
|
# Llama 3.2
|
||||||
|
|
||||||
|
huggingface = use_extension("@zml//bazel:huggingface.bzl", "huggingface")
|
||||||
|
|
||||||
huggingface.model(
|
huggingface.model(
|
||||||
name = "Meta-Llama-3.2-1B-Instruct",
|
name = "Meta-Llama-3.2-1B-Instruct",
|
||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "model",
|
name = "Meta-Llama-3.2-1B-Instruct",
|
||||||
srcs = ["model.safetensors"],
|
srcs = glob(["*.json", "*.safetensors"]),
|
||||||
)
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "tokenizer",
|
|
||||||
srcs = ["tokenizer.json"],
|
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
commit = "9213176726f574b556790deb65791e0c5aa438b6",
|
commit = "9213176726f574b556790deb65791e0c5aa438b6",
|
||||||
includes = [
|
includes = [
|
||||||
"model.safetensors",
|
"*.safetensors",
|
||||||
"tokenizer.json",
|
"*.json",
|
||||||
],
|
],
|
||||||
model = "meta-llama/Llama-3.2-1B-Instruct",
|
model = "meta-llama/Llama-3.2-1B-Instruct",
|
||||||
)
|
)
|
||||||
@ -115,129 +75,87 @@ huggingface.model(
|
|||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "model",
|
name = "Meta-Llama-3.2-3B-Instruct",
|
||||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
srcs = glob(["*.json", "*.safetensors"]),
|
||||||
)
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "tokenizer",
|
|
||||||
srcs = ["tokenizer.json"],
|
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
|
commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
|
||||||
includes = [
|
includes = [
|
||||||
"*.safetensors",
|
"*.safetensors",
|
||||||
"model.safetensors.index.json",
|
"*.json",
|
||||||
"tokenizer.json",
|
|
||||||
],
|
],
|
||||||
model = "meta-llama/Llama-3.2-3B-Instruct",
|
model = "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
)
|
)
|
||||||
use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
|
use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
|
||||||
|
|
||||||
# Llama 3.1
|
# Llama 3.1
|
||||||
|
|
||||||
huggingface.model(
|
huggingface.model(
|
||||||
name = "Meta-Llama-3.1-8B-Instruct",
|
name = "Meta-Llama-3.1-8B-Instruct",
|
||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "model",
|
name = "Meta-Llama-3.1-8B-Instruct",
|
||||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
srcs = glob(["*.json", "*.safetensors"]),
|
||||||
)
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "tokenizer",
|
|
||||||
srcs = ["tokenizer.json"],
|
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
commit = "5206a32e0bd3067aef1ce90f5528ade7d866253f",
|
commit = "5206a32e0bd3067aef1ce90f5528ade7d866253f",
|
||||||
includes = [
|
includes = [
|
||||||
"*.safetensors",
|
"*.safetensors",
|
||||||
"model.safetensors.index.json",
|
"*.json",
|
||||||
"tokenizer.json",
|
|
||||||
],
|
],
|
||||||
model = "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
model = "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
)
|
)
|
||||||
use_repo(huggingface, "Meta-Llama-3.1-8B-Instruct")
|
use_repo(huggingface, "Meta-Llama-3.1-8B-Instruct")
|
||||||
|
|
||||||
huggingface.model(
|
huggingface.model(
|
||||||
name = "Meta-Llama-3.1-70B-Instruct",
|
name = "Meta-Llama-3.1-70B-Instruct",
|
||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "model",
|
name = "Meta-Llama-3.1-70B-Instruct",
|
||||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
srcs = glob(["*.json", "*.safetensors"]),
|
||||||
)
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "tokenizer",
|
|
||||||
srcs = ["tokenizer.json"],
|
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
commit = "945c8663693130f8be2ee66210e062158b2a9693",
|
commit = "945c8663693130f8be2ee66210e062158b2a9693",
|
||||||
includes = [
|
includes = [
|
||||||
"*.safetensors",
|
"*.safetensors",
|
||||||
"model.safetensors.index.json",
|
"*.json",
|
||||||
"tokenizer.json",
|
|
||||||
],
|
],
|
||||||
model = "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
model = "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
)
|
)
|
||||||
use_repo(huggingface, "Meta-Llama-3.1-70B-Instruct")
|
use_repo(huggingface, "Meta-Llama-3.1-70B-Instruct")
|
||||||
|
|
||||||
|
|
||||||
huggingface.model(
|
huggingface.model(
|
||||||
name = "TinyLlama-1.1B-Chat-v1.0",
|
name = "TinyLlama-120M-scratch",
|
||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "model",
|
name = "TinyLlama-120M-scratch",
|
||||||
srcs = ["model.safetensors"],
|
srcs = glob(["*.json", "*.safetensors"]),
|
||||||
)
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "tokenizer",
|
|
||||||
srcs = ["tokenizer.model"],
|
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
commit = "fe8a4ea1ffedaf415f4da2f062534de366a451e6",
|
commit = "89c1bb4ea00861ddaa26c55f102ccb25e161feee",
|
||||||
includes = [
|
includes = [
|
||||||
"model.safetensors",
|
"*.safetensors",
|
||||||
"tokenizer.model",
|
"*.json",
|
||||||
],
|
],
|
||||||
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
model = "Hoyeon/TinyLlama-120M-scratch",
|
||||||
)
|
)
|
||||||
use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
|
use_repo(huggingface, "TinyLlama-120M-scratch")
|
||||||
|
|
||||||
#OpenLLaMa
|
|
||||||
huggingface.model(
|
|
||||||
name = "OpenLM-Research-OpenLLaMA-3B",
|
|
||||||
build_file_content = """\
|
|
||||||
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:public"])
|
bazel_dep(name = "rules_rust", version = "0.57.0")
|
||||||
|
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
|
||||||
filegroup(
|
rust.toolchain(
|
||||||
name = "model",
|
edition = "2021",
|
||||||
srcs = ["model.safetensors"],
|
versions = ["1.84.0"],
|
||||||
)
|
extra_target_triples = [
|
||||||
|
"aarch64-apple-darwin",
|
||||||
filegroup(
|
"aarch64-unknown-linux-gnu",
|
||||||
name = "tokenizer",
|
"x86_64-unknown-linux-gnu",
|
||||||
srcs = [":tokenizer_pb"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# leverage copy_file to rename tokenizer extension
|
|
||||||
# which allow zml.aio.detectFormatAndLoadTokenizer
|
|
||||||
# to leverage the right tokenizer
|
|
||||||
copy_file(
|
|
||||||
name = "tokenizer_pb",
|
|
||||||
src = "tokenizer.model",
|
|
||||||
out = "tokenizer.pb",
|
|
||||||
allow_symlink = True,
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
commit = "fcc2e809eb8f14dabba84d76a0ddc17b8ea05356",
|
|
||||||
includes = [
|
|
||||||
"model.safetensors",
|
|
||||||
"tokenizer.model",
|
|
||||||
],
|
],
|
||||||
model = "openlm-research/open_llama_3b",
|
|
||||||
)
|
)
|
||||||
use_repo(huggingface, "OpenLM-Research-OpenLLaMA-3B")
|
use_repo(rust, "rust_toolchains")
|
||||||
|
register_toolchains("@rust_toolchains//:all")
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
|||||||
|
load("@aspect_bazel_lib//lib:expand_template.bzl", "expand_template")
|
||||||
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
|
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
|
||||||
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
|
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
|
||||||
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
|
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
|
||||||
@ -12,7 +13,7 @@ zig_cc_binary(
|
|||||||
],
|
],
|
||||||
main = "main.zig",
|
main = "main.zig",
|
||||||
deps = [
|
deps = [
|
||||||
"//third_party/tigerbeetle:flags",
|
"@com_github_hejsil_clap//:clap",
|
||||||
"@zml//async",
|
"@zml//async",
|
||||||
"@zml//stdx",
|
"@zml//stdx",
|
||||||
"@zml//zml",
|
"@zml//zml",
|
||||||
@ -20,18 +21,35 @@ zig_cc_binary(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "Llama-3.1-8B-Instruct",
|
name = "TinyLlama-120M-scratch",
|
||||||
args = [
|
args = [
|
||||||
"--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
"--config=$(location @TinyLlama-120M-scratch//:config.json)",
|
||||||
"--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)",
|
"--weights=$(location @TinyLlama-120M-scratch//:model.safetensors)",
|
||||||
"--num-heads=32",
|
"--tokenizer=$(location @TinyLlama-120M-scratch//:tokenizer.json)",
|
||||||
"--num-kv-heads=8",
|
"--no-llama3=true", # don't do llama3 template prompt encoding
|
||||||
"--rope-freq-base=500000",
|
"--sharding=false", # don't shard this
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
"@Meta-Llama-3.1-8B-Instruct//:model",
|
"@TinyLlama-120M-scratch",
|
||||||
|
"@TinyLlama-120M-scratch//:config.json",
|
||||||
|
"@TinyLlama-120M-scratch//:model.safetensors",
|
||||||
|
"@TinyLlama-120M-scratch//:tokenizer.json",
|
||||||
|
],
|
||||||
|
deps = [":llama_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "Llama-3.1-8B-Instruct",
|
||||||
|
args = [
|
||||||
|
"--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
|
||||||
|
"--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
||||||
|
"--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer.json)",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"@Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"@Meta-Llama-3.1-8B-Instruct//:config.json",
|
||||||
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
|
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
|
||||||
"@Meta-Llama-3.1-8B-Instruct//:tokenizer",
|
"@Meta-Llama-3.1-8B-Instruct//:tokenizer.json",
|
||||||
],
|
],
|
||||||
deps = [":llama_lib"],
|
deps = [":llama_lib"],
|
||||||
)
|
)
|
||||||
@ -39,32 +57,32 @@ cc_binary(
|
|||||||
cc_binary(
|
cc_binary(
|
||||||
name = "Llama-3.1-70B-Instruct",
|
name = "Llama-3.1-70B-Instruct",
|
||||||
args = [
|
args = [
|
||||||
"--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
|
"--config=$(location @Meta-Llama-3.1-70B-Instruct//:config.json)",
|
||||||
"--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)",
|
"--weights=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
|
||||||
"--num-heads=64",
|
"--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer.json)",
|
||||||
"--num-kv-heads=8",
|
|
||||||
"--rope-freq-base=500000",
|
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
"@Meta-Llama-3.1-70B-Instruct//:model",
|
"@Meta-Llama-3.1-70B-Instruct",
|
||||||
|
"@Meta-Llama-3.1-70B-Instruct//:config.json",
|
||||||
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
|
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
|
||||||
"@Meta-Llama-3.1-70B-Instruct//:tokenizer",
|
"@Meta-Llama-3.1-70B-Instruct//:tokenizer.json",
|
||||||
],
|
],
|
||||||
deps = [":llama_lib"],
|
deps = [":llama_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "Llama-3.2-1B-Instruct",
|
name = "Llama-3.2-1B-Instruct",
|
||||||
args = [
|
args = [
|
||||||
"--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
"--config=$(location @Meta-Llama-3.2-1B-Instruct//:config.json)",
|
||||||
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
|
"--weights=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||||
"--num-heads=32",
|
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||||
"--num-kv-heads=8",
|
|
||||||
"--rope-freq-base=500000",
|
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct",
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct//:config.json",
|
||||||
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||||
"@Meta-Llama-3.2-1B-Instruct//:tokenizer",
|
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
|
||||||
],
|
],
|
||||||
deps = [":llama_lib"],
|
deps = [":llama_lib"],
|
||||||
)
|
)
|
||||||
@ -72,86 +90,26 @@ cc_binary(
|
|||||||
cc_binary(
|
cc_binary(
|
||||||
name = "Llama-3.2-3B-Instruct",
|
name = "Llama-3.2-3B-Instruct",
|
||||||
args = [
|
args = [
|
||||||
"--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
|
"--config=$(location @Meta-Llama-3.2-3B-Instruct//:config.json)",
|
||||||
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)",
|
"--weights=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
|
||||||
"--num-heads=24",
|
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer.json)",
|
||||||
"--num-kv-heads=8",
|
|
||||||
"--rope-freq-base=500000",
|
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
"@Meta-Llama-3.2-3B-Instruct//:model",
|
"@Meta-Llama-3.2-3B-Instruct",
|
||||||
|
"@Meta-Llama-3.2-3B-Instruct//:config.json",
|
||||||
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
|
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
|
||||||
"@Meta-Llama-3.2-3B-Instruct//:tokenizer",
|
"@Meta-Llama-3.2-3B-Instruct//:tokenizer.json",
|
||||||
],
|
|
||||||
deps = [":llama_lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "OpenLLaMA-3B",
|
|
||||||
args = [
|
|
||||||
"--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)",
|
|
||||||
"--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)",
|
|
||||||
"--num-heads=32",
|
|
||||||
"--num-kv-heads=32",
|
|
||||||
"--rope-freq-base=10000",
|
|
||||||
],
|
|
||||||
data = [
|
|
||||||
"@OpenLM-Research-OpenLLaMA-3B//:model",
|
|
||||||
"@OpenLM-Research-OpenLLaMA-3B//:tokenizer",
|
|
||||||
],
|
|
||||||
deps = [":llama_lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "TinyLlama-1.1B-Chat",
|
|
||||||
args = [
|
|
||||||
"--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)",
|
|
||||||
"--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)",
|
|
||||||
"--num-heads=32",
|
|
||||||
"--num-kv-heads=4",
|
|
||||||
"--rope-freq-base=10000",
|
|
||||||
],
|
|
||||||
data = [
|
|
||||||
"@TinyLlama-1.1B-Chat-v1.0//:model.safetensors",
|
|
||||||
"@TinyLlama-1.1B-Chat-v1.0//:tokenizer",
|
|
||||||
],
|
|
||||||
deps = [":llama_lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "TinyLlama-Stories-110M",
|
|
||||||
args = [
|
|
||||||
"--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)",
|
|
||||||
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
|
|
||||||
],
|
|
||||||
data = [
|
|
||||||
"@Karpathy-TinyLlama-Stories//:stories110M",
|
|
||||||
"@Karpathy-TinyLlama-Tokenizer//file",
|
|
||||||
],
|
|
||||||
deps = [":llama_lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "TinyLlama-Stories-15M",
|
|
||||||
args = [
|
|
||||||
"--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)",
|
|
||||||
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
|
|
||||||
],
|
|
||||||
data = [
|
|
||||||
"@Karpathy-TinyLlama-Stories//:stories15M",
|
|
||||||
"@Karpathy-TinyLlama-Tokenizer//file",
|
|
||||||
],
|
],
|
||||||
deps = [":llama_lib"],
|
deps = [":llama_lib"],
|
||||||
)
|
)
|
||||||
|
#
|
||||||
|
|
||||||
zig_cc_binary(
|
zig_cc_binary(
|
||||||
name = "test-implementation",
|
name = "test-implementation",
|
||||||
srcs = ["llama.zig"],
|
srcs = ["llama.zig"],
|
||||||
args = [
|
args = [
|
||||||
"--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
"--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
|
||||||
"--num-heads=32",
|
"--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
|
||||||
"--num-kv-heads=8",
|
|
||||||
"--rope-freq-base=500000",
|
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
"@Meta-Llama-3.1-8B-Instruct//:model",
|
"@Meta-Llama-3.1-8B-Instruct//:model",
|
||||||
@ -184,12 +142,12 @@ zig_cc_binary(
|
|||||||
|
|
||||||
mtree_spec(
|
mtree_spec(
|
||||||
name = "mtree",
|
name = "mtree",
|
||||||
srcs = [":llama"],
|
srcs = [":Llama-3.2-1B-Instruct"],
|
||||||
)
|
)
|
||||||
|
|
||||||
tar(
|
tar(
|
||||||
name = "archive",
|
name = "archive",
|
||||||
srcs = [":llama"],
|
srcs = [":Llama-3.2-1B-Instruct"],
|
||||||
args = [
|
args = [
|
||||||
"--options",
|
"--options",
|
||||||
"zstd:compression-level=9",
|
"zstd:compression-level=9",
|
||||||
@ -198,10 +156,33 @@ tar(
|
|||||||
mtree = ":mtree",
|
mtree = ":mtree",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
expand_template(
|
||||||
|
name = "entrypoint",
|
||||||
|
data = [
|
||||||
|
":Llama-3.2-1B-Instruct",
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct",
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct//:config.json",
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct//:tokenizer.json",
|
||||||
|
],
|
||||||
|
substitutions = {
|
||||||
|
":config": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:config.json)",
|
||||||
|
":weights": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||||
|
":tokenizer": "$(rlocationpath @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
|
||||||
|
},
|
||||||
|
template = [
|
||||||
|
"./{}/Llama-3.2-1B-Instruct".format(package_name()),
|
||||||
|
"--config=./{}/Llama-3.2-1B-Instruct.runfiles/:config".format(package_name()),
|
||||||
|
"--weights=./{}/Llama-3.2-1B-Instruct.runfiles/:weights".format(package_name()),
|
||||||
|
"--tokenizer=./{}/Llama-3.2-1B-Instruct.runfiles/:tokenizer".format(package_name()),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
oci_image(
|
oci_image(
|
||||||
name = "image_",
|
name = "image_",
|
||||||
base = "@distroless_cc_debian12_debug",
|
base = "@distroless_cc_debian12_debug",
|
||||||
entrypoint = ["./{}/llama".format(package_name())],
|
# entrypoint = ["./{}/Llama-3.2-1B-Instruct".format(package_name())],
|
||||||
|
entrypoint = ":entrypoint",
|
||||||
tars = [
|
tars = [
|
||||||
"@zml//runtimes:layers",
|
"@zml//runtimes:layers",
|
||||||
":archive",
|
":archive",
|
||||||
@ -218,7 +199,7 @@ oci_load(
|
|||||||
name = "load",
|
name = "load",
|
||||||
image = ":image",
|
image = ":image",
|
||||||
repo_tags = [
|
repo_tags = [
|
||||||
"distroless/llama:latest",
|
"distroless/llama-3.2-1b-instruct:latest",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -226,5 +207,5 @@ oci_push(
|
|||||||
name = "push",
|
name = "push",
|
||||||
image = ":image",
|
image = ":image",
|
||||||
remote_tags = ["latest"],
|
remote_tags = ["latest"],
|
||||||
repository = "index.docker.io/steeve/llama",
|
repository = "index.docker.io/steeve/llama-3.2-1b-instruct",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
const flags = @import("tigerbeetle/flags");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
const zml = @import("zml");
|
const zml = @import("zml");
|
||||||
@ -12,36 +11,51 @@ const gguf = zml.io.gguf;
|
|||||||
const expectClose = zml.testing.expectClose;
|
const expectClose = zml.testing.expectClose;
|
||||||
const log = std.log.scoped(.llama);
|
const log = std.log.scoped(.llama);
|
||||||
|
|
||||||
pub const LlamaOptions = struct {
|
|
||||||
gen_opts: zml.nn.SamplingStrategy,
|
|
||||||
max_seq_len: u32,
|
|
||||||
num_heads: i64,
|
|
||||||
num_kv_heads: i64,
|
|
||||||
rms_norm_eps: f32,
|
|
||||||
rope_opts: zml.nn.RopeOpts,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Llama architecture, using huggingface transformers naming.
|
/// Llama architecture, using huggingface transformers naming.
|
||||||
/// Dimensions of activations: {.b, .s, .d}
|
/// Dimensions of activations: {.b, .s, .d}
|
||||||
pub const LlamaLM = struct {
|
pub const LlamaLM = struct {
|
||||||
lm_head: ?zml.nn.Linear = null,
|
pub const Config = struct {
|
||||||
|
bos_token_id: u32,
|
||||||
|
eos_token_id: stdx.json.Union(union(enum) {
|
||||||
|
int: u32,
|
||||||
|
ints: []u32,
|
||||||
|
}),
|
||||||
|
num_hidden_layers: usize,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
num_key_value_heads: usize,
|
||||||
|
rope_theta: f32,
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
rms_norm_eps: f32,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const Options = struct {
|
||||||
|
sampling_strategy: ?zml.nn.SamplingStrategy,
|
||||||
|
max_seq_len: usize,
|
||||||
|
};
|
||||||
|
|
||||||
|
lm_head: ?zml.nn.Linear,
|
||||||
model: Llama,
|
model: Llama,
|
||||||
|
|
||||||
// Options controlling generation
|
// Options controlling generation
|
||||||
gen_opts: zml.nn.SamplingStrategy = .{},
|
gen_opts: zml.nn.SamplingStrategy = .{},
|
||||||
|
config: Config,
|
||||||
|
|
||||||
pub fn init(self: *LlamaLM, options: LlamaOptions) void {
|
pub fn init(self: *LlamaLM, config: Config, options: Options) void {
|
||||||
self.gen_opts = options.gen_opts;
|
self.config = config;
|
||||||
self.model.max_seq_len = options.max_seq_len;
|
self.gen_opts = options.sampling_strategy orelse .{};
|
||||||
self.model.num_heads = options.num_heads;
|
self.model.max_seq_len = @intCast(options.max_seq_len);
|
||||||
self.model.num_kv_heads = options.num_kv_heads;
|
self.model.num_heads = @intCast(config.num_attention_heads);
|
||||||
self.model.rope_opts = options.rope_opts;
|
self.model.num_kv_heads = @intCast(config.num_key_value_heads);
|
||||||
|
self.model.rope_opts = .{
|
||||||
|
.impl = .sequential,
|
||||||
|
.freq_base = config.rope_theta,
|
||||||
|
};
|
||||||
for (self.model.layers) |*layer| {
|
for (self.model.layers) |*layer| {
|
||||||
layer.self_attn.num_heads = options.num_heads;
|
layer.self_attn.num_heads = self.model.num_heads;
|
||||||
layer.self_attn.num_kv_heads = options.num_kv_heads;
|
layer.self_attn.num_kv_heads = self.model.num_kv_heads;
|
||||||
layer.self_attn.rope_opts = options.rope_opts;
|
layer.self_attn.rope_opts = self.model.rope_opts;
|
||||||
layer.input_layernorm.eps = options.rms_norm_eps;
|
layer.input_layernorm.eps = config.rms_norm_eps;
|
||||||
layer.post_attention_layernorm.eps = options.rms_norm_eps;
|
layer.post_attention_layernorm.eps = config.rms_norm_eps;
|
||||||
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
|
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
|
||||||
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
|
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
|
||||||
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
|
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
|
||||||
@ -54,88 +68,58 @@ pub const LlamaLM = struct {
|
|||||||
|
|
||||||
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
||||||
// It currently crashes/compilation fails
|
// It currently crashes/compilation fails
|
||||||
if (options.gen_opts.topk == 1) {
|
if (self.gen_opts.topk == 1 and self.lm_head != null) {
|
||||||
if (self.lm_head) |lm_head| {
|
self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
|
||||||
self.lm_head.?.weight = lm_head.weight.withSharding(.{0});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predicts the token at `token_index` position.
|
/// Predicts the token at `token_index` position.
|
||||||
/// Returns:
|
/// Returns:
|
||||||
/// - updated `tokens`,
|
/// - updated `tokens`,
|
||||||
/// - `token_idx` + 1,
|
|
||||||
/// - updated KV cache
|
/// - updated KV cache
|
||||||
/// - a Rng state to allow for probabilistic generation
|
/// - a Rng state to allow for probabilistic generation
|
||||||
pub fn forward(
|
pub fn forward(
|
||||||
self: LlamaLM,
|
self: LlamaLM,
|
||||||
tokens_: Tensor,
|
tokens_: Tensor,
|
||||||
token_index: Tensor,
|
token_index: Tensor,
|
||||||
kv_cache: ?KvCache,
|
kv_cache: KvCache,
|
||||||
rng: Tensor.Rng,
|
rng: Tensor.Rng,
|
||||||
) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
|
) struct { Tensor, KvCache, Tensor.Rng } {
|
||||||
stdx.debug.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
||||||
|
|
||||||
var tokens = tokens_.withPartialTags(.{.s});
|
var tokens = tokens_.withPartialTags(.{.s});
|
||||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
|
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
||||||
tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts);
|
tokens, const new_rng = self.sampleTokens(self.lm_head, tokens, out, rng, self.gen_opts);
|
||||||
return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
|
return .{ tokens, updated_kv_cache, new_rng };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn updateTokens(
|
pub fn sampleTokens(
|
||||||
self: LlamaLM,
|
self: LlamaLM,
|
||||||
|
lm_head_: ?zml.nn.Linear,
|
||||||
tokens_: Tensor,
|
tokens_: Tensor,
|
||||||
token_index: Tensor,
|
|
||||||
out_: Tensor,
|
out_: Tensor,
|
||||||
rng: Tensor.Rng,
|
rng: Tensor.Rng,
|
||||||
opts: zml.nn.SamplingStrategy,
|
opts: zml.nn.SamplingStrategy,
|
||||||
) struct { Tensor, Tensor.Rng } {
|
) struct { Tensor, Tensor.Rng } {
|
||||||
const tokens = tokens_.withPartialTags(.{.s});
|
|
||||||
const out = out_.withPartialTags(.{ .s, .d });
|
const out = out_.withPartialTags(.{ .s, .d });
|
||||||
|
|
||||||
const next_token_pred = out.gatherValues(.s, token_index, .{});
|
var logits = blk: {
|
||||||
var logits = if (self.lm_head) |lm_head|
|
if (lm_head_) |lm_head| {
|
||||||
zml.call(lm_head, .forward, .{next_token_pred})
|
break :blk zml.call(lm_head, .forward, .{out});
|
||||||
else
|
} else {
|
||||||
self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d});
|
break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if (logits.shape().hasTag(.voc) == null)
|
if (logits.shape().hasTag(.voc) == null)
|
||||||
logits = logits.rename(.{ .d = .voc });
|
logits = logits.rename(.{ .d = .voc });
|
||||||
|
|
||||||
const next_token, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
||||||
const next_token_index = token_index.addConstant(1);
|
return .{ next_tokens.reuseBuffer(tokens_), new_rng };
|
||||||
const new_tokens = tokens.dynamicUpdateSlice(.{ .s = next_token_index }, next_token);
|
|
||||||
|
|
||||||
return .{ new_tokens.reuseBuffer(tokens_), new_rng };
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment(_: u8, token_index: Tensor) Tensor {
|
pub fn increment(_: u8, token_index: Tensor) Tensor {
|
||||||
return token_index.addConstant(1);
|
return token_index.addConstant(1).reuseBuffer(token_index);
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the generation entirely within pjrt.
|
|
||||||
pub fn generate(self: LlamaLM, tokens: Tensor, token_index: Tensor, rng: Tensor.Rng) Tensor {
|
|
||||||
// Generate the first token using the prompt and generate the KV-cache initial values.
|
|
||||||
const prefill = zml.call(self, .forward, .{ tokens, token_index, null, rng });
|
|
||||||
|
|
||||||
const Gen = struct {
|
|
||||||
/// Same as LlamaLM.forward but without optional in the signature
|
|
||||||
pub fn forward(lm: LlamaLM, t_ids: Tensor, t_idx: Tensor, kv_cache_: KvCache, inner_rng: Tensor.Rng) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
|
|
||||||
var kv_cache = kv_cache_;
|
|
||||||
kv_cache.k = kv_cache.k.withPartialTags(.{ .layer, .h, .k, .hd });
|
|
||||||
kv_cache.v = kv_cache.v.withPartialTags(.{ .layer, .h, .k, .hd });
|
|
||||||
return zml.call(lm, .forward, .{ t_ids._ctx, t_ids, t_idx, kv_cache, inner_rng });
|
|
||||||
}
|
|
||||||
// / Stops when we generated `max_seq_len` tokens.
|
|
||||||
pub fn shouldContinue(lm: LlamaLM, t_ids: Tensor, t_idx: Tensor, kv_cache: KvCache, inner_rng: Tensor.Rng) Tensor {
|
|
||||||
_ = kv_cache;
|
|
||||||
_ = inner_rng;
|
|
||||||
std.debug.assert(t_ids.dim(1) == lm.model.max_seq_len);
|
|
||||||
return t_idx.cmp(.LT, Tensor.scalar(t_ids._ctx, lm.model.max_seq_len, t_idx.dtype()));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// Generate remaining tokens using the KV-cache, return tokens.
|
|
||||||
return zml.ops.while_(Gen.shouldContinue, Gen.forward, self, prefill)[0];
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -177,33 +161,28 @@ pub const Llama = struct {
|
|||||||
|
|
||||||
/// Forward one token, using KV cache for previous tokens.
|
/// Forward one token, using KV cache for previous tokens.
|
||||||
/// Returns result and updated KV cache.
|
/// Returns result and updated KV cache.
|
||||||
pub fn forward(self: Llama, tokens: Tensor, token_index: ?Tensor, kv_cache: ?KvCache) struct { Tensor, KvCache } {
|
pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
|
||||||
const embeds = embed(self.embed_tokens, tokens, token_index);
|
const embeds = embed(self.embed_tokens, tokens);
|
||||||
|
|
||||||
var hidden = embeds;
|
var hidden = embeds;
|
||||||
const kv_cache0 = kv_cache orelse self.initKvCache(embeds.shape());
|
|
||||||
|
|
||||||
var updated_kv_cache = kv_cache0;
|
var updated_kv_cache = kv_cache;
|
||||||
for (self.layers, 0..) |layer, i| {
|
for (self.layers, 0..) |layer, i| {
|
||||||
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
||||||
}
|
}
|
||||||
const output = zml.call(self.norm, .forward, .{hidden});
|
const output = zml.call(self.norm, .forward, .{hidden});
|
||||||
|
|
||||||
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
|
return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor, token_index: ?Tensor) Tensor {
|
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
|
||||||
const tokens = if (token_index) |idx|
|
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
|
||||||
tokens_.dynamicSlice1d(-1, .{ .start = idx, .len = 1 })
|
|
||||||
else
|
|
||||||
tokens_;
|
|
||||||
return zml.call(embed_tokens_, .forward, .{tokens}).withPartialTags(.{ .s, .d });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
|
fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
|
||||||
const dims = self.shape();
|
const dims = self.shape();
|
||||||
var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
|
var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
|
||||||
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
|
const perm = kv_shape.contiguousPerm(.{ .k, .h, .hd });
|
||||||
kv_shape = kv_shape.transpose(perm.constSlice());
|
kv_shape = kv_shape.transpose(perm.constSlice());
|
||||||
return KvCache.init(kv_shape);
|
return KvCache.init(kv_shape);
|
||||||
}
|
}
|
||||||
@ -218,8 +197,8 @@ pub const TransformerLayer = struct {
|
|||||||
pub fn forward(
|
pub fn forward(
|
||||||
self: TransformerLayer,
|
self: TransformerLayer,
|
||||||
x0: Tensor,
|
x0: Tensor,
|
||||||
token_index: ?Tensor,
|
token_index: Tensor,
|
||||||
kv_cache: ?KvCache,
|
kv_cache: KvCache,
|
||||||
) struct { Tensor, KvCache } {
|
) struct { Tensor, KvCache } {
|
||||||
// Self Attention
|
// Self Attention
|
||||||
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
|
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
|
||||||
@ -287,39 +266,41 @@ pub const SelfAttn = struct {
|
|||||||
pub fn forward(
|
pub fn forward(
|
||||||
self: SelfAttn,
|
self: SelfAttn,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
token_index: ?Tensor,
|
token_index: Tensor,
|
||||||
kv_cache_: ?KvCache,
|
kv_cache: KvCache,
|
||||||
) struct { Tensor, KvCache } {
|
) struct { Tensor, KvCache } {
|
||||||
// log.debug("x.shape: {}", .{x.shape()});
|
|
||||||
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
|
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
|
||||||
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
|
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
|
||||||
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||||
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||||
|
|
||||||
// Generate the attention mask.
|
// Generate the attention mask.
|
||||||
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
|
||||||
const seq_len = kv_cache.k.dim(.k);
|
const seq_len = kv_cache.k.dim(.k);
|
||||||
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
|
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
|
||||||
if (token_index) |idx| {
|
|
||||||
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
||||||
// then slice into it, but XLA is able to optimize this correctly.
|
// then slice into it, but XLA is able to optimize this correctly.
|
||||||
attn_mask = attn_mask.dynamicSlice(.{ .q = .{ .start = idx, .len = 1 } });
|
attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .b = token_index.shape().dim(0), .coord = 1 }), .{});
|
||||||
}
|
|
||||||
|
|
||||||
// In self-attention, .s axis is used both for keys and queries.
|
// In self-attention, .s axis is used both for keys and queries.
|
||||||
q = zml.nn.rope(q, token_index, self.rope_opts);
|
const pos_index = b: {
|
||||||
k = zml.nn.rope(k, token_index, self.rope_opts);
|
const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .b = token_index.shape().dim(0), .s = x.dim(.s) }, token_index.dtype()));
|
||||||
|
break :b temp.add(token_index.withTags(.{.b}).broad(temp.shape()));
|
||||||
|
};
|
||||||
|
|
||||||
|
q = zml.nn.rope(q, pos_index, self.rope_opts);
|
||||||
|
k = zml.nn.rope(k, pos_index, self.rope_opts);
|
||||||
q = q.rename(.{ .s = .q });
|
q = q.rename(.{ .s = .q });
|
||||||
k = k.rename(.{ .s = .k });
|
k = k.rename(.{ .s = .k });
|
||||||
v = v.rename(.{ .s = .k });
|
v = v.rename(.{ .s = .k });
|
||||||
|
|
||||||
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
|
const dtype = q.dtype();
|
||||||
if (token_index) |_| {
|
const new_kv_cache = kv_cache.update(k, v, token_index);
|
||||||
stdx.debug.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
|
k = new_kv_cache.keys().convert(dtype);
|
||||||
k = new_kv_cache.keys();
|
v = new_kv_cache.values().convert(dtype);
|
||||||
v = new_kv_cache.values();
|
|
||||||
}
|
|
||||||
|
|
||||||
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = false });
|
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = true });
|
||||||
|
// const attn_output = zml.nn.sdpaMemEfficient(q, k, v, .{ .attn_mask = attn_mask }, .{ .q_chunk_size = 4096, .k_chunk_size = 1024 });
|
||||||
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
||||||
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
|
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
|
||||||
}
|
}
|
||||||
@ -330,7 +311,7 @@ pub const SelfAttn = struct {
|
|||||||
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
|
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
|
||||||
kv_shape = kv_shape.transpose(perm.constSlice());
|
kv_shape = kv_shape.transpose(perm.constSlice());
|
||||||
var res = KvCache.init(kv_shape);
|
var res = KvCache.init(kv_shape);
|
||||||
res.layer_index = Tensor.scalar(0, .i32);
|
res.layer_index = Tensor.scalar(0, .u32);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -345,7 +326,7 @@ pub const KvCache = struct {
|
|||||||
return .{
|
return .{
|
||||||
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||||
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||||
.layer_index = Tensor.scalar(-1, .i32),
|
.layer_index = Tensor.scalar(-1, .u32),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -353,7 +334,15 @@ pub const KvCache = struct {
|
|||||||
return .{
|
return .{
|
||||||
.k = kv_shape,
|
.k = kv_shape,
|
||||||
.v = kv_shape,
|
.v = kv_shape,
|
||||||
.layer_index = zml.Shape.init(.{}, .i32),
|
.layer_index = zml.Shape.init(.{}, .u32),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
|
||||||
|
return .{
|
||||||
|
.k = try zml.Buffer.constant(platform, kv_shape, 1),
|
||||||
|
.v = try zml.Buffer.constant(platform, kv_shape, 1),
|
||||||
|
.layer_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -365,17 +354,33 @@ pub const KvCache = struct {
|
|||||||
return self.v.dynamicSlice(.{ .layer = .{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
|
return self.v.dynamicSlice(.{ .layer = .{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: Tensor) KvCache {
|
pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache {
|
||||||
return .{
|
const k_shape = self.k.shape().drop(.layer);
|
||||||
.k = self.k.dynamicUpdateSlice(
|
var layer = self.layer_index;
|
||||||
.{ .layer = self.layer_index, .k = token_index },
|
layer = if (token_index) |idx| layer.broad(idx.shape()) else layer;
|
||||||
// transpose to match kv-cache layout
|
|
||||||
new_k.contiguous(.{ .h, .k, .hd }),
|
return if (token_index) |idx| .{
|
||||||
|
.k = self.k.scatterSlices(
|
||||||
|
.{ .layer = layer, .k = idx },
|
||||||
|
new_k.convert(self.k.dtype()).transpose(k_shape),
|
||||||
|
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||||
).reuseBuffer(self.k),
|
).reuseBuffer(self.k),
|
||||||
.v = self.v.dynamicUpdateSlice(
|
.v = self.v.scatterSlices(
|
||||||
.{ .layer = self.layer_index, .k = token_index },
|
.{ .layer = layer, .k = idx },
|
||||||
// transpose to match kv-cache layout
|
new_v.convert(self.v.dtype()).transpose(k_shape),
|
||||||
new_v.contiguous(.{ .h, .k, .hd }),
|
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||||
|
).reuseBuffer(self.v),
|
||||||
|
.layer_index = self.layer_index,
|
||||||
|
} else .{
|
||||||
|
.k = self.k.scatterSlices(
|
||||||
|
.{ .layer = layer },
|
||||||
|
new_k.convert(self.k.dtype()).transpose(k_shape),
|
||||||
|
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||||
|
).reuseBuffer(self.k),
|
||||||
|
.v = self.v.scatterSlices(
|
||||||
|
.{ .layer = layer },
|
||||||
|
new_v.convert(self.v.dtype()).transpose(k_shape),
|
||||||
|
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||||
).reuseBuffer(self.v),
|
).reuseBuffer(self.v),
|
||||||
.layer_index = self.layer_index,
|
.layer_index = self.layer_index,
|
||||||
};
|
};
|
||||||
@ -385,7 +390,7 @@ pub const KvCache = struct {
|
|||||||
return .{
|
return .{
|
||||||
.k = self.k,
|
.k = self.k,
|
||||||
.v = self.v,
|
.v = self.v,
|
||||||
.layer_index = Tensor.scalar(layer_index, .i32),
|
.layer_index = Tensor.scalar(layer_index, .u32),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,248 +1,330 @@
|
|||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const flags = @import("tigerbeetle/flags");
|
const clap = @import("clap");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
const zml = @import("zml");
|
const zml = @import("zml");
|
||||||
|
|
||||||
const llama_mod = @import("llama.zig");
|
const llama = @import("llama.zig");
|
||||||
|
|
||||||
const LlamaLM = llama_mod.LlamaLM;
|
const LlamaLM = llama.LlamaLM;
|
||||||
const Llama = llama_mod.Llama;
|
const Llama = llama.Llama;
|
||||||
const KvCache = llama_mod.KvCache;
|
const KvCache = llama.KvCache;
|
||||||
const TransformerLayer = llama_mod.TransformerLayer;
|
const TransformerLayer = llama.TransformerLayer;
|
||||||
const SelfAttn = llama_mod.SelfAttn;
|
const SelfAttn = llama.SelfAttn;
|
||||||
const Buffer = zml.Buffer;
|
const Buffer = zml.Buffer;
|
||||||
const Tensor = zml.Tensor;
|
const Tensor = zml.Tensor;
|
||||||
const ShapeOf = zml.ShapeOf;
|
const ShapeOf = zml.ShapeOf;
|
||||||
|
|
||||||
const log = std.log.scoped(.llama);
|
const log = std.log.scoped(.llama);
|
||||||
|
|
||||||
const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 };
|
|
||||||
|
|
||||||
// set this to false to disable the verbose logging
|
|
||||||
const show_mlir = true;
|
|
||||||
|
|
||||||
pub const std_options = .{
|
pub const std_options = .{
|
||||||
.log_level = .warn,
|
.log_level = .info,
|
||||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
|
||||||
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
|
|
||||||
.{ .scope = .llama, .level = .info },
|
|
||||||
},
|
|
||||||
.logFn = asynk.logFn,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub fn tokenizePromptLlama3(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8) ![]u32 {
|
||||||
|
var tokens = std.ArrayList(u32).init(allocator);
|
||||||
|
var encoder = try tokenizer.encoder();
|
||||||
|
defer encoder.deinit();
|
||||||
|
|
||||||
|
const start_header_id = tokenizer.token_to_id("<|start_header_id|>") orelse return error.NoSuchToken;
|
||||||
|
const end_header_id = tokenizer.token_to_id("<|end_header_id|>") orelse return error.NoSuchToken;
|
||||||
|
const eot_id = tokenizer.token_to_id("<|eot_id|>") orelse return error.NoSuchToken;
|
||||||
|
const newline_id = (try encoder.encode("\n"))[0];
|
||||||
|
|
||||||
|
try tokens.append(config.bos_token_id);
|
||||||
|
|
||||||
|
try tokens.append(start_header_id);
|
||||||
|
try tokens.appendSlice(try encoder.encode("user"));
|
||||||
|
try tokens.appendSlice(&.{ end_header_id, newline_id, newline_id });
|
||||||
|
|
||||||
|
try tokens.appendSlice(try encoder.encode(prompt));
|
||||||
|
try tokens.appendSlice(&.{ eot_id, newline_id });
|
||||||
|
try tokens.appendSlice(try encoder.encode("\n"));
|
||||||
|
try tokens.append(start_header_id);
|
||||||
|
try tokens.appendSlice(try encoder.encode("assistant"));
|
||||||
|
try tokens.append(end_header_id);
|
||||||
|
|
||||||
|
return tokens.toOwnedSlice();
|
||||||
|
}
|
||||||
|
|
||||||
pub fn generateText(
|
pub fn generateText(
|
||||||
llama: LlamaLM,
|
config: LlamaLM.Config,
|
||||||
|
llama_: LlamaLM,
|
||||||
mod_prefill: zml.ModuleExe(LlamaLM.forward),
|
mod_prefill: zml.ModuleExe(LlamaLM.forward),
|
||||||
mod: zml.ModuleExe(LlamaLM.forward),
|
mod_generate: zml.ModuleExe(LlamaLM.forward),
|
||||||
|
kv_cache_: zml.Bufferized(llama.KvCache),
|
||||||
tokenizer: zml.tokenizer.Tokenizer,
|
tokenizer: zml.tokenizer.Tokenizer,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
seed: u128,
|
seed: u128,
|
||||||
prompt: []const u8,
|
prompt: []const u8,
|
||||||
|
skip_llama3_encoding: bool,
|
||||||
) ![]const u8 {
|
) ![]const u8 {
|
||||||
const prompt_tok = tokenizer.encode(allocator, prompt, .{}) catch unreachable;
|
var tokenizer_encoder = try tokenizer.encoder();
|
||||||
log.debug("Tokenized Prompt {d}", .{prompt_tok});
|
defer tokenizer_encoder.deinit();
|
||||||
const dims = llama.model.shape();
|
var tokenizer_decoder = try tokenizer.decoder();
|
||||||
const max_seq_len = dims.s;
|
defer tokenizer_decoder.deinit();
|
||||||
const token_buffer = try allocator.alloc(i32, @intCast(max_seq_len));
|
|
||||||
@memset(token_buffer, 0);
|
|
||||||
for (0..prompt_tok.len) |i| {
|
|
||||||
token_buffer[i] = @intCast(prompt_tok[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
|
const prompt_tok: []const u32 = if (skip_llama3_encoding) try tokenizer_encoder.encode(prompt) else try tokenizePromptLlama3(allocator, tokenizer, config, prompt);
|
||||||
defer allocator.free(token_buffer);
|
|
||||||
defer allocator.free(tracer_buffer);
|
|
||||||
defer allocator.free(prompt_tok);
|
defer allocator.free(prompt_tok);
|
||||||
var output = std.ArrayList(u8).init(allocator);
|
|
||||||
defer output.deinit();
|
|
||||||
|
|
||||||
var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer);
|
const dims = llama_.model.shape();
|
||||||
var prefill_token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
|
const max_seq_len = dims.s;
|
||||||
|
|
||||||
|
// Prefill
|
||||||
|
// initialize a 0..max_seq_len buffer with the tokenized prompt
|
||||||
|
const prefill_buffer = try allocator.alloc(u32, @intCast(max_seq_len));
|
||||||
|
@memset(prefill_buffer, 0);
|
||||||
|
for (0..prompt_tok.len) |i| {
|
||||||
|
prefill_buffer[i] = @intCast(prompt_tok[i]);
|
||||||
|
}
|
||||||
|
defer allocator.free(prefill_buffer);
|
||||||
|
|
||||||
|
const platform = mod_generate.platform();
|
||||||
|
|
||||||
|
// prepare device buffers for the prefill tokens and the index
|
||||||
|
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
||||||
|
defer prefill_tokens.deinit();
|
||||||
|
var prefill_token_index = try zml.Buffer.fromSlice(platform, .{}, &[_]u32{0});
|
||||||
defer prefill_token_index.deinit();
|
defer prefill_token_index.deinit();
|
||||||
|
|
||||||
var rng = try zml.Tensor.Rng.init(mod.platform(), seed);
|
// init RNG and prefill
|
||||||
tokens, var token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, prefill_token_index, null, rng });
|
var rng = try zml.Tensor.Rng.init(platform, seed);
|
||||||
defer token_index.deinit();
|
prefill_tokens, var kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_index, kv_cache_, rng });
|
||||||
defer kv_cache.k.deinit();
|
defer kv_cache.k.deinit();
|
||||||
defer kv_cache.v.deinit();
|
defer kv_cache.v.deinit();
|
||||||
defer kv_cache.layer_index.deinit();
|
defer kv_cache.layer_index.deinit();
|
||||||
|
|
||||||
|
// Prepare for token-by-token generation
|
||||||
|
var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token
|
||||||
|
var current_token = try zml.Buffer.fromSlice(platform, .{}, &first_token_hostbuffer);
|
||||||
|
defer current_token.deinit();
|
||||||
|
|
||||||
|
// Here we will copy the generated token from device
|
||||||
|
var generated_token_buffer = [_]u32{0};
|
||||||
|
|
||||||
|
// Here we collect the generated text
|
||||||
|
var output = std.ArrayList(u8).init(allocator);
|
||||||
|
defer output.deinit();
|
||||||
|
|
||||||
|
const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
|
||||||
|
defer allocator.free(tracer_buffer);
|
||||||
const tracer = zml.tools.Tracer.init("ai.zml.models.llama");
|
const tracer = zml.tools.Tracer.init("ai.zml.models.llama");
|
||||||
var decode_progress = prompt_tok.len;
|
|
||||||
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
||||||
|
|
||||||
const start = std.time.microTimestamp();
|
const start = std.time.microTimestamp();
|
||||||
const output_freq: u8 = 1;
|
|
||||||
var eos_index: ?usize = null;
|
var num_tokens_generated: usize = 0;
|
||||||
for (0..output_tokens_len) |i| {
|
|
||||||
//_ = i;
|
generation: for (0..output_tokens_len) |i| {
|
||||||
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
||||||
tokens, const new_token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
|
|
||||||
token_index.deinit();
|
// current token index needs to go into a zml.Buffer
|
||||||
token_index = new_token_index;
|
const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
|
||||||
if ((i + 1) % output_freq == 0) {
|
const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer);
|
||||||
const n = output.items.len;
|
defer token_index.deinit();
|
||||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
|
||||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..][0..output_freq]), .{});
|
// call to generate the next token
|
||||||
decode_progress += output_freq;
|
current_token, kv_cache, rng = mod_generate.call(.{ current_token, token_index, kv_cache, rng });
|
||||||
std.debug.print("{s}", .{output.items[n..]});
|
|
||||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
|
|
||||||
if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| {
|
|
||||||
// Handle strange scenarios when eos id isn't the very next token after decode_progress
|
|
||||||
eos_index = decode_progress - output_freq + index;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
||||||
|
|
||||||
|
// extract the generated token from the buffer
|
||||||
|
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
|
||||||
|
const generated_token = generated_token_buffer[0];
|
||||||
|
// de-tokenize generated token into a string
|
||||||
|
const chunk = try tokenizer_decoder.next(@intCast(generated_token)) orelse unreachable;
|
||||||
|
num_tokens_generated = i;
|
||||||
|
|
||||||
|
// check for eos
|
||||||
|
switch (config.eos_token_id.value) {
|
||||||
|
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
|
||||||
|
.ints => |eos_list| {
|
||||||
|
for (eos_list) |eos| {
|
||||||
|
if (generated_token == @as(u32, @intCast(eos))) break :generation;
|
||||||
}
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
var total_token_count: usize = max_seq_len;
|
|
||||||
const n = output.items.len;
|
// collect and print generated sequence
|
||||||
if (eos_index) |end_idx| {
|
try output.appendSlice(chunk);
|
||||||
// count = eos index + 1
|
std.debug.print("{s}", .{chunk});
|
||||||
total_token_count = end_idx + 1;
|
|
||||||
}
|
}
|
||||||
const generated_token_count = total_token_count - prompt_tok.len;
|
|
||||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{});
|
|
||||||
std.debug.print("{s}\n", .{output.items[n..]});
|
|
||||||
const end = std.time.microTimestamp();
|
const end = std.time.microTimestamp();
|
||||||
|
|
||||||
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
||||||
const speed = @as(f64, @floatFromInt(generated_token_count)) / duration;
|
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
|
||||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed });
|
std.debug.print("\n", .{});
|
||||||
|
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
|
||||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
|
||||||
output.clearRetainingCapacity();
|
|
||||||
|
|
||||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{});
|
|
||||||
return output.toOwnedSlice();
|
return output.toOwnedSlice();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const params = clap.parseParamsComptime(
|
||||||
|
\\--help print this help
|
||||||
|
\\--prompt <STRING> the prompt
|
||||||
|
\\--config <PATH> config.json path
|
||||||
|
\\--weights <PATH> model weights path
|
||||||
|
\\--tokenizer <PATH> tokenizer path
|
||||||
|
\\--seed <UINT> random seed (optional)
|
||||||
|
\\--seq-len <UINT> sequence length
|
||||||
|
\\--create-options <STRING> platform creation options JSON, defaults to {}
|
||||||
|
\\--no-llama3 <BOOL> skip prompt template
|
||||||
|
\\--sharding <BOOL> default: true: sharding on or off
|
||||||
|
);
|
||||||
|
|
||||||
|
pub fn bool_parser(in: []const u8) error{}!bool {
|
||||||
|
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn main() !void {
|
pub fn main() !void {
|
||||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asyncMain() !void {
|
pub fn asyncMain() !void {
|
||||||
const CliArgs = struct {
|
|
||||||
pub const help =
|
|
||||||
\\ llama --model=llama3.7B.safetensors --tokenizer=vocab.json --num_layers=2
|
|
||||||
;
|
|
||||||
model: []const u8,
|
|
||||||
tokenizer: ?[]const u8 = null,
|
|
||||||
layer_start: u8 = 0,
|
|
||||||
num_layers: ?u8 = null,
|
|
||||||
seq_len: u32 = 256,
|
|
||||||
topk: u32 = 2,
|
|
||||||
temperature: u32 = 1,
|
|
||||||
num_heads: ?i64 = null,
|
|
||||||
num_kv_heads: ?i64 = null,
|
|
||||||
rope_freq_base: ?i64 = null,
|
|
||||||
prompt: ?[]const u8 = null,
|
|
||||||
test_activations: ?[]const u8 = null,
|
|
||||||
seed: ?u128 = null,
|
|
||||||
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
|
||||||
create_options: []const u8 = "{}",
|
|
||||||
};
|
|
||||||
|
|
||||||
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
|
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
|
||||||
|
|
||||||
const allocator = std.heap.c_allocator;
|
const allocator = std.heap.c_allocator;
|
||||||
|
|
||||||
const tmp = try std.fs.openDirAbsolute("/tmp", .{});
|
const parsers = comptime .{
|
||||||
try tmp.makePath("zml/llama/cache");
|
.BOOL = bool_parser,
|
||||||
|
.UINT = clap.parsers.int(usize, 0),
|
||||||
|
.STRING = clap.parsers.string,
|
||||||
|
.PATH = clap.parsers.string,
|
||||||
|
};
|
||||||
|
var diag: clap.Diagnostic = .{};
|
||||||
|
const stderr = std.io.getStdErr().writer();
|
||||||
|
var res = clap.parse(clap.Help, &params, parsers, .{
|
||||||
|
.diagnostic = &diag,
|
||||||
|
.allocator = allocator,
|
||||||
|
}) catch |err| {
|
||||||
|
diag.report(stderr, err) catch {};
|
||||||
|
stderr.print("usage: ", .{}) catch {};
|
||||||
|
clap.usage(stderr, clap.Help, &params) catch {};
|
||||||
|
stderr.print("\n", .{}) catch {};
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
defer res.deinit();
|
||||||
|
|
||||||
|
if (res.args.help != 0) {
|
||||||
|
clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}) catch {};
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const config = blk: {
|
||||||
|
if (res.args.config) |config_json_path| {
|
||||||
|
var config_json_file = try asynk.File.open(config_json_path, .{ .mode = .read_only });
|
||||||
|
defer config_json_file.close() catch unreachable;
|
||||||
|
var reader = std.json.reader(allocator, config_json_file.reader());
|
||||||
|
defer reader.deinit();
|
||||||
|
const config_obj = try std.json.parseFromTokenSourceLeaky(llama.LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
|
||||||
|
break :blk config_obj;
|
||||||
|
} else {
|
||||||
|
log.err("Missing --config", .{});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
var context = try zml.Context.init();
|
var context = try zml.Context.init();
|
||||||
defer context.deinit();
|
defer context.deinit();
|
||||||
|
|
||||||
const compilation_options = zml.CompilationOptions{
|
const compilation_options = zml.CompilationOptions{
|
||||||
.xla_dump_to = "/tmp/zml/llama",
|
.xla_dump_to = "/tmp/zml/llama",
|
||||||
.sharding_enabled = true,
|
.sharding_enabled = res.args.sharding orelse true,
|
||||||
};
|
};
|
||||||
|
|
||||||
var args = std.process.args();
|
// initialize ZML platform with optional create options
|
||||||
const cli_args = flags.parse(&args, CliArgs);
|
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
||||||
const model_file = cli_args.model;
|
const create_opts_json = res.args.@"create-options" orelse "{}";
|
||||||
|
const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options);
|
||||||
defer arena_state.deinit();
|
create_opts.deinit();
|
||||||
const model_arena = arena_state.allocator();
|
|
||||||
|
|
||||||
const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{});
|
|
||||||
const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
|
|
||||||
context.printAvailablePlatforms(platform);
|
context.printAvailablePlatforms(platform);
|
||||||
|
|
||||||
log.info("Model file: {s}", .{model_file});
|
var ts = try zml.aio.detectFormatAndOpen(allocator, res.args.weights.?);
|
||||||
|
|
||||||
var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);
|
|
||||||
defer ts.deinit();
|
defer ts.deinit();
|
||||||
|
|
||||||
var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts);
|
var model_arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
|
var model_instance = try zml.aio.populateModel(llama.LlamaLM, model_arena.allocator(), ts);
|
||||||
const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads;
|
|
||||||
|
|
||||||
const rope_impl = if (ts.metadata("rope_impl", .string)) |val|
|
const llama_options: llama.LlamaLM.Options = .{
|
||||||
std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
|
.max_seq_len = @intCast(res.args.@"seq-len" orelse 256),
|
||||||
else
|
.sampling_strategy = .{
|
||||||
.sequential;
|
.topk = 1,
|
||||||
|
.temperature = 1.0,
|
||||||
const llama_options: llama_mod.LlamaOptions = .{
|
|
||||||
.max_seq_len = cli_args.seq_len,
|
|
||||||
.num_kv_heads = num_kv_heads,
|
|
||||||
.num_heads = num_heads,
|
|
||||||
.gen_opts = .{
|
|
||||||
.topk = cli_args.topk,
|
|
||||||
.temperature = @floatFromInt(cli_args.temperature),
|
|
||||||
},
|
|
||||||
.rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5),
|
|
||||||
.rope_opts = .{
|
|
||||||
.impl = rope_impl,
|
|
||||||
.freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
log.info("✅\tParsed llama config: {}", .{llama_options});
|
model_instance.init(config, llama_options);
|
||||||
llama.init(llama_options);
|
|
||||||
|
|
||||||
if (cli_args.tokenizer == null and !std.mem.endsWith(u8, cli_args.model, ".gguf")) {
|
const dims = model_instance.model.shape();
|
||||||
log.err("Model doesn't have an embbedded tokenizer, please provide a path to a tokenizer.", .{});
|
const dtype = model_instance.model.embed_tokens.weight.dtype();
|
||||||
@panic("No tokenizer provided");
|
|
||||||
}
|
|
||||||
const tokenizer_path = cli_args.tokenizer orelse cli_args.model;
|
|
||||||
log.info("\tLoading tokenizer from {s}", .{tokenizer_path});
|
|
||||||
var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, tokenizer_path);
|
|
||||||
log.info("✅\tLoaded tokenizer from {s}", .{tokenizer_path});
|
|
||||||
defer tokenizer.deinit();
|
|
||||||
|
|
||||||
const dims = llama.model.shape();
|
const batch_size = 1;
|
||||||
const dtype = llama.model.embed_tokens.weight.dtype();
|
|
||||||
|
|
||||||
// Note: we compile the model without a batching dimension.
|
const tokens_shape_prefill = zml.Shape.init(.{ .b = batch_size, .s = llama_options.max_seq_len }, .u32);
|
||||||
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
const tokens_shape = zml.Shape.init(.{ .b = batch_size, .s = 1 }, .u32);
|
||||||
const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
|
const token_idx_shape = zml.Shape.init(.{ .b = batch_size }, .u32);
|
||||||
const token_idx_shape = zml.Shape.init(.{}, .i32);
|
|
||||||
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
|
const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .b = batch_size, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
|
||||||
// needs to be optional
|
|
||||||
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
|
const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
|
||||||
const rng_shape = Tensor.Rng.shape();
|
const rng_shape = zml.Tensor.Rng.shape();
|
||||||
|
|
||||||
var start = try std.time.Timer.start();
|
var start = try std.time.Timer.start();
|
||||||
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
|
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{
|
||||||
var fut_mod = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
|
allocator, llama.LlamaLM.forward, .{ config, llama_options },
|
||||||
|
.{
|
||||||
|
tokens_shape_prefill,
|
||||||
|
token_idx_shape,
|
||||||
|
kv_cache_shape,
|
||||||
|
rng_shape,
|
||||||
|
},
|
||||||
|
ts,
|
||||||
|
platform,
|
||||||
|
});
|
||||||
|
|
||||||
log.info("\tLoading Llama weights from {s}...", .{cli_args.model});
|
var fut_mod = try asynk.asyncc(zml.compile, .{
|
||||||
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, ts, model_arena, platform);
|
allocator, llama.LlamaLM.forward, .{ config, llama_options },
|
||||||
|
.{
|
||||||
|
tokens_shape,
|
||||||
|
token_idx_shape,
|
||||||
|
kv_cache_shape,
|
||||||
|
rng_shape,
|
||||||
|
},
|
||||||
|
ts,
|
||||||
|
platform,
|
||||||
|
});
|
||||||
|
|
||||||
|
log.info("\tLoading Llama weights from {?s}...", .{res.args.weights});
|
||||||
|
var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
|
||||||
defer zml.aio.unloadBuffers(&llama_weights);
|
defer zml.aio.unloadBuffers(&llama_weights);
|
||||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
log.info("✅\tLoaded weights in {}", .{std.fmt.fmtDuration(start.read())});
|
||||||
|
|
||||||
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
||||||
defer llama_module_prefill.deinit();
|
defer llama_module_prefill.deinit();
|
||||||
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
||||||
defer llama_module.deinit();
|
defer llama_module.deinit();
|
||||||
log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
log.info("✅\tCompiled model in {}", .{std.fmt.fmtDuration(start.read())});
|
||||||
|
|
||||||
const prompt = cli_args.prompt orelse "Once upon a time, there was a little girl named Lily.";
|
log.info("Creating KvCache", .{});
|
||||||
|
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
|
||||||
|
|
||||||
|
var tokenizer = blk: {
|
||||||
|
if (res.args.tokenizer) |tok| {
|
||||||
|
log.info("Loading tokenizer from {s}", .{tok});
|
||||||
|
var timer = try stdx.time.Timer.start();
|
||||||
|
defer log.info("Loaded tokenizer from {s} [{}]", .{ tok, timer.read() });
|
||||||
|
|
||||||
|
break :blk try zml.tokenizer.Tokenizer.from_file(model_arena.allocator(), tok);
|
||||||
|
} else {
|
||||||
|
log.err("Missing --tokenizer", .{});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
errdefer tokenizer.deinit();
|
||||||
|
|
||||||
|
const prompt = res.args.prompt orelse "What is the capital of France?";
|
||||||
log.info("✅\tPrompt: {s}", .{prompt});
|
log.info("✅\tPrompt: {s}", .{prompt});
|
||||||
|
|
||||||
const seed = cli_args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
const seed = res.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
||||||
const story = try generateText(llama, llama_module_prefill, llama_module, tokenizer, allocator, seed, prompt);
|
const skip_llama3_encoding = res.args.@"no-llama3" orelse false;
|
||||||
defer allocator.free(story);
|
const generated_text = try generateText(config, model_instance, llama_module_prefill, llama_module, kv_cache, tokenizer, allocator, seed, prompt[0..], skip_llama3_encoding);
|
||||||
|
// generated text will be printed token by token.
|
||||||
|
defer allocator.free(generated_text);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,9 +16,10 @@ pub fn main() !void {
|
|||||||
pub fn asyncMain() !void {
|
pub fn asyncMain() !void {
|
||||||
const CliArgs = struct {
|
const CliArgs = struct {
|
||||||
pub const help =
|
pub const help =
|
||||||
\\ test-implementation --model=llama3.8B.safetensors --reference=activation.safetensors
|
\\ test-implementation --weights=llama3.8B.safetensors --config=config.json --reference=activation.safetensors
|
||||||
;
|
;
|
||||||
model: []const u8,
|
weights: []const u8,
|
||||||
|
config: []const u8,
|
||||||
reference: []const u8,
|
reference: []const u8,
|
||||||
num_heads: ?i64 = null,
|
num_heads: ?i64 = null,
|
||||||
num_kv_heads: ?i64 = null,
|
num_kv_heads: ?i64 = null,
|
||||||
@ -38,7 +39,7 @@ pub fn asyncMain() !void {
|
|||||||
// Parse program args
|
// Parse program args
|
||||||
var args = std.process.args();
|
var args = std.process.args();
|
||||||
const cli_args = flags.parse(&args, CliArgs);
|
const cli_args = flags.parse(&args, CliArgs);
|
||||||
const model_file = cli_args.model;
|
const model_file = cli_args.weights;
|
||||||
|
|
||||||
// Memory arena dedicated to model shapes and weights
|
// Memory arena dedicated to model shapes and weights
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||||
@ -61,6 +62,16 @@ pub fn asyncMain() !void {
|
|||||||
else
|
else
|
||||||
.sequential;
|
.sequential;
|
||||||
|
|
||||||
|
const config = blk: {
|
||||||
|
var config_json_file = try asynk.File.open(cli_args.config, .{ .mode = .read_only });
|
||||||
|
defer config_json_file.close() catch unreachable;
|
||||||
|
var reader = std.json.reader(allocator, config_json_file.reader());
|
||||||
|
defer reader.deinit();
|
||||||
|
const config_obj = try std.json.parseFromTokenSourceLeaky(LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
|
||||||
|
break :blk config_obj;
|
||||||
|
};
|
||||||
|
std.log.info("Parsed llama config: {}", .{config});
|
||||||
|
|
||||||
const llama_options: llama_mod.LlamaOptions = .{
|
const llama_options: llama_mod.LlamaOptions = .{
|
||||||
.max_seq_len = 256,
|
.max_seq_len = 256,
|
||||||
.num_kv_heads = num_kv_heads,
|
.num_kv_heads = num_kv_heads,
|
||||||
@ -101,26 +112,4 @@ fn testImplementation(
|
|||||||
try zml.testing.testLayer(platform, buffer_store, "layers.0.mlp", llama.model.layers[0].mlp, llama_weights.model.layers[0].mlp, 1e-2);
|
try zml.testing.testLayer(platform, buffer_store, "layers.0.mlp", llama.model.layers[0].mlp, llama_weights.model.layers[0].mlp, 1e-2);
|
||||||
try zml.testing.testLayer(platform, buffer_store, "layers.0.input_layernorm", llama.model.layers[0].input_layernorm, llama_weights.model.layers[0].input_layernorm, 1e-2);
|
try zml.testing.testLayer(platform, buffer_store, "layers.0.input_layernorm", llama.model.layers[0].input_layernorm, llama_weights.model.layers[0].input_layernorm, 1e-2);
|
||||||
try zml.testing.testLayer(platform, buffer_store, "layers.0.post_attention_layernorm", llama.model.layers[0].post_attention_layernorm, llama_weights.model.layers[0].post_attention_layernorm, 1e-2);
|
try zml.testing.testLayer(platform, buffer_store, "layers.0.post_attention_layernorm", llama.model.layers[0].post_attention_layernorm, llama_weights.model.layers[0].post_attention_layernorm, 1e-2);
|
||||||
|
|
||||||
{
|
|
||||||
const test_case = "layers.0.self_attn";
|
|
||||||
std.log.info("Testing {s}", .{test_case});
|
|
||||||
// Small wrapper to explicitly tag the input, and ignore the extra arguments used in HF implementation.
|
|
||||||
const SelfAttnPrefill = struct {
|
|
||||||
inner: llama_mod.SelfAttn,
|
|
||||||
|
|
||||||
pub fn forward(self: @This(), x_: Tensor) struct { Tensor, llama_mod.KvCache } {
|
|
||||||
return self.inner.forward(x_.withTags(.{ .b, .s, .d }), null, null);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
try zml.testing.testLayer(
|
|
||||||
platform,
|
|
||||||
buffer_store,
|
|
||||||
"layers.0.self_attn",
|
|
||||||
SelfAttnPrefill{ .inner = llama.model.layers[0].self_attn },
|
|
||||||
.{ .inner = llama_weights.model.layers[0].self_attn },
|
|
||||||
1e-3,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
9
examples/third_party/com_github_hejsil_clap/clap.bazel
vendored
Normal file
9
examples/third_party/com_github_hejsil_clap/clap.bazel
vendored
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "clap",
|
||||||
|
import_name = "clap",
|
||||||
|
srcs = glob(["clap/*.zig"]),
|
||||||
|
main = "clap.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
20
examples/third_party/non_module_deps.bzl
vendored
Normal file
20
examples/third_party/non_module_deps.bzl
vendored
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
|
||||||
|
|
||||||
|
def _non_module_deps_impl(mctx):
|
||||||
|
|
||||||
|
new_git_repository(
|
||||||
|
name = "com_github_hejsil_clap",
|
||||||
|
remote = "https://github.com/Hejsil/zig-clap.git",
|
||||||
|
commit = "d71cc39a94f3e6ccbad00c25d350c9147de4df9f",
|
||||||
|
build_file = "//:third_party/com_github_hejsil_clap/clap.bazel",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
non_module_deps = module_extension(
|
||||||
|
implementation = _non_module_deps_impl,
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue
Block a user