From 4abdd32f0d0d7f6c6423c35f2206043dadadb4bb Mon Sep 17 00:00:00 2001
From: Foke Singh
Date: Tue, 12 Sep 2023 15:40:21 +0000
Subject: [PATCH] examples/llama: replace native_binary model wrappers with cc_binary targets

---
 examples/llama/BUILD.bazel | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/examples/llama/BUILD.bazel b/examples/llama/BUILD.bazel
index e982eb1..f9c4258 100644
--- a/examples/llama/BUILD.bazel
+++ b/examples/llama/BUILD.bazel
@@ -1,6 +1,5 @@
 load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
 load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
-load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
 load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
 load("@zml//bazel:zig.bzl", "zig_cc_binary")
 