Update the llama example BUILD to use the jax-cuda-pjrt plugin and bump the CUDA (12.6.2) and cuDNN (9.5.1) versions.

This commit is contained in:
Foke Singh 2023-09-12 15:40:21 +00:00
parent c8c99d7d5a
commit 4abdd32f0d

View File

@@ -1,6 +1,5 @@
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup") load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push") load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
load("@zml//bazel:zig.bzl", "zig_cc_binary") load("@zml//bazel:zig.bzl", "zig_cc_binary")
@@ -18,9 +17,8 @@ zig_cc_binary(
], ],
) )
native_binary( cc_binary(
name = "Llama-3.1-8B-Instruct", name = "Llama-3.1-8B-Instruct",
src = ":llama",
args = [ args = [
"--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)", "--model=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
"--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)", "--tokenizer=$(location @Meta-Llama-3.1-8B-Instruct//:tokenizer)",
@@ -33,11 +31,11 @@ native_binary(
"@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.1-8B-Instruct//:tokenizer", "@Meta-Llama-3.1-8B-Instruct//:tokenizer",
], ],
deps = [":llama_lib"],
) )
native_binary( cc_binary(
name = "Llama-3.1-70B-Instruct", name = "Llama-3.1-70B-Instruct",
src = ":llama",
args = [ args = [
"--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)", "--model=$(location @Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json)",
"--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)", "--tokenizer=$(location @Meta-Llama-3.1-70B-Instruct//:tokenizer)",
@@ -50,11 +48,11 @@ native_binary(
"@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json", "@Meta-Llama-3.1-70B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.1-70B-Instruct//:tokenizer", "@Meta-Llama-3.1-70B-Instruct//:tokenizer",
], ],
deps = [":llama_lib"],
) )
native_binary( cc_binary(
name = "OpenLLaMA-3B", name = "OpenLLaMA-3B",
src = ":llama",
args = [ args = [
"--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)", "--model=$(location @OpenLM-Research-OpenLLaMA-3B//:model)",
"--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)", "--tokenizer=$(location @OpenLM-Research-OpenLLaMA-3B//:tokenizer)",
@@ -66,11 +64,11 @@ native_binary(
"@OpenLM-Research-OpenLLaMA-3B//:model", "@OpenLM-Research-OpenLLaMA-3B//:model",
"@OpenLM-Research-OpenLLaMA-3B//:tokenizer", "@OpenLM-Research-OpenLLaMA-3B//:tokenizer",
], ],
deps = [":llama_lib"],
) )
native_binary( cc_binary(
name = "TinyLlama-1.1B-Chat", name = "TinyLlama-1.1B-Chat",
src = ":llama",
args = [ args = [
"--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)", "--model=$(location @TinyLlama-1.1B-Chat-v1.0//:model.safetensors)",
"--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)", "--tokenizer=$(location @TinyLlama-1.1B-Chat-v1.0//:tokenizer)",
@@ -82,11 +80,11 @@ native_binary(
"@TinyLlama-1.1B-Chat-v1.0//:model.safetensors", "@TinyLlama-1.1B-Chat-v1.0//:model.safetensors",
"@TinyLlama-1.1B-Chat-v1.0//:tokenizer", "@TinyLlama-1.1B-Chat-v1.0//:tokenizer",
], ],
deps = [":llama_lib"],
) )
native_binary( cc_binary(
name = "TinyLlama-Stories-110M", name = "TinyLlama-Stories-110M",
src = ":llama",
args = [ args = [
"--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)", "--model=$(location @Karpathy-TinyLlama-Stories//:stories110M)",
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)", "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
@@ -95,11 +93,11 @@ native_binary(
"@Karpathy-TinyLlama-Stories//:stories110M", "@Karpathy-TinyLlama-Stories//:stories110M",
"@Karpathy-TinyLlama-Tokenizer//file", "@Karpathy-TinyLlama-Tokenizer//file",
], ],
deps = [":llama_lib"],
) )
native_binary( cc_binary(
name = "TinyLlama-Stories-15M", name = "TinyLlama-Stories-15M",
src = ":llama",
args = [ args = [
"--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)", "--model=$(location @Karpathy-TinyLlama-Stories//:stories15M)",
"--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)", "--tokenizer=$(location @Karpathy-TinyLlama-Tokenizer//file)",
@@ -108,6 +106,7 @@ native_binary(
"@Karpathy-TinyLlama-Stories//:stories15M", "@Karpathy-TinyLlama-Stories//:stories15M",
"@Karpathy-TinyLlama-Tokenizer//file", "@Karpathy-TinyLlama-Tokenizer//file",
], ],
deps = [":llama_lib"],
) )
zig_cc_binary( zig_cc_binary(