245 lines
7.9 KiB
Python
245 lines
7.9 KiB
Python
load("@bazel_skylib//lib:paths.bzl", "paths")
|
|
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
|
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
|
|
load("//runtimes/common:packages.bzl", "packages")
|
|
|
|
# Header prepended to every generated BUILD file in this extension so that all
# targets declared in these external repositories are publicly visible.
_BUILD_FILE_DEFAULT_VISIBILITY = "package(default_visibility = [\"//visibility:public\"])\n"
|
|
|
|
# Platform key looked up in every package entry of the NVIDIA redist JSON
# manifests; only archives for this platform are fetched.
ARCH = "linux-x86_64"

# Every NVIDIA redistributable below lives under this common download root,
# in a "<product>/redist/" subdirectory that also hosts the
# "redistrib_<version>.json" manifest for each release.
_NVIDIA_COMPUTE_ROOT = "https://developer.download.nvidia.com/compute/"

# CUDA toolkit components.
CUDA_REDIST_PREFIX = _NVIDIA_COMPUTE_ROOT + "cuda/redist/"
CUDA_VERSION = "12.8.1"
CUDA_REDIST_JSON_SHA256 = "249e28a83008d711d5f72880541c8be6253f6d61608461de4fcb715554a6cf17"

# cuDNN.
CUDNN_REDIST_PREFIX = _NVIDIA_COMPUTE_ROOT + "cudnn/redist/"
CUDNN_VERSION = "9.8.0"
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"

# NVSHMEM.
NVSHMEM_REDIST_PREFIX = _NVIDIA_COMPUTE_ROOT + "nvshmem/redist/"
NVSHMEM_VERSION = "3.2.5"
NVSHMEM_REDIST_JSON_SHA256 = "6945425d3bfd24de23c045996f93ec720c010379bfd6f0860ac5f2716659442d"
|
|
|
|
# BUILD file content for Debian/Ubuntu packages, keyed by the package name as
# it appears in the packages lock file read by the extension (the key is also
# used as the repository name).
_UBUNTU_PACKAGES = dict(
    zlib1g = packages.filegroup(
        name = "zlib1g",
        srcs = ["lib/x86_64-linux-gnu/libz.so.1"],
    ),
)
|
|
|
|
# "<major>.<minor>" of CUDA_VERSION (e.g. "12.8"). libnvrtc-builtins carries
# this in its file name, so deriving it keeps the path below in sync when
# CUDA_VERSION is bumped instead of silently pointing at a missing file.
_CUDA_MAJOR_MINOR = ".".join(CUDA_VERSION.split(".")[:2])

# BUILD file content for each CUDA redist package, keyed by the package name
# used in the CUDA redistrib JSON manifest (the key is also used as the
# repository name).
CUDA_PACKAGES = {
    "cuda_cudart": "\n".join([
        # Driver API only
        packages.cc_library(
            name = "cuda",
            hdrs = ["include/cuda.h"],
            includes = ["include"],
        ),
        # TODO: Remove this target as soon as tracer.zig uses the Driver API.
        packages.filegroup(
            name = "cuda_cudart",
            srcs = ["lib/libcudart.so.12"],
        ),
    ]),
    "cuda_cupti": packages.filegroup(
        name = "cuda_cupti",
        srcs = ["lib/libcupti.so.12"],
    ),
    "cuda_nvtx": "\n".join([
        # TODO: expose the NVTX headers (include/nvtx3/**/*.h) through a
        # packages.cc_library target named "nvtx" once they are needed.
        packages.filegroup(
            name = "cuda_nvtx",
            srcs = ["lib/libnvToolsExt.so.1"],
        ),
    ]),
    "libcufft": packages.filegroup(
        name = "libcufft",
        srcs = ["lib/libcufft.so.11"],
    ),
    "libcusolver": packages.filegroup(
        name = "libcusolver",
        srcs = ["lib/libcusolver.so.11"],
    ),
    "libcusparse": packages.filegroup(
        name = "libcusparse",
        srcs = ["lib/libcusparse.so.12"],
    ),
    "libnvjitlink": packages.filegroup(
        name = "libnvjitlink",
        srcs = ["lib/libnvJitLink.so.12"],
    ),
    "cuda_nvcc": "\n".join([
        packages.filegroup(
            name = "cuda_nvcc",
            srcs = [
                "bin/ptxas",
                "bin/nvlink",
                "nvvm/libdevice/libdevice.10.bc",
                "nvvm/lib64/libnvvm.so.4",
            ],
        ),
        packages.cc_import(
            name = "nvptxcompiler",
            static_library = "lib/libnvptxcompiler_static.a",
        ),
    ]),
    "cuda_nvrtc": "\n".join([
        packages.filegroup(
            name = "cuda_nvrtc",
            srcs = [
                "lib/libnvrtc.so.12",
                "lib/libnvrtc-builtins.so." + _CUDA_MAJOR_MINOR,
            ],
        ),
    ]),
    "libcublas": "\n".join([
        packages.filegroup(
            name = "libcublas",
            srcs = [
                "lib/libcublasLt.so.12",
                "lib/libcublas.so.12",
            ],
        ),
    ]),
}
|
|
|
|
# BUILD file content for each cuDNN redist package, keyed by the package name
# used in the cuDNN redistrib JSON manifest.
CUDNN_PACKAGES = {
    "cudnn": packages.filegroup(
        name = "cudnn",
        # The core libcudnn plus its sub-libraries, all with soname version 9,
        # listed in a fixed order.
        srcs = [
            "lib/libcudnn{}.so.9".format(suffix)
            for suffix in [
                "",
                "_adv",
                "_ops",
                "_cnn",
                "_graph",
                "_engines_precompiled",
                "_engines_runtime_compiled",
                "_heuristic",
            ]
        ],
    ),
}
|
|
|
|
# BUILD file content for each NVSHMEM redist package, keyed by the package
# name used in the NVSHMEM redistrib JSON manifest.
NVSHMEM_PACKAGES = {
    "libnvshmem": packages.filegroup(
        name = "libnvshmem",
        # Host library plus the bootstrap/transport plugins, all ".so.3".
        srcs = ["lib/{}.so.3".format(lib_name) for lib_name in [
            "libnvshmem_host",
            "nvshmem_bootstrap_uid",
            "nvshmem_transport_ibrc",
        ]],
    ),
}
|
|
|
|
def _read_redist_json(mctx, url, sha256):
    """Downloads a redist JSON manifest and returns its decoded contents.

    The manifest is saved under a relative path derived from its checksum,
    and the download is verified against `sha256`.

    Args:
        mctx: the module_ctx of the calling module extension.
        url: URL of the manifest to fetch.
        sha256: expected SHA-256 of the manifest file.

    Returns:
        The value produced by `json.decode` on the manifest text.
    """
    manifest_path = ".%s.json" % sha256
    mctx.download(url = url, output = manifest_path, sha256 = sha256)
    manifest_text = mctx.read(manifest_path)
    return json.decode(manifest_text)
|
|
|
|
def _declare_redist_archives(redist, build_files, url_prefix, variant = None):
    """Declares one `http_archive` per package listed in `build_files`.

    Args:
        redist: decoded redistrib JSON manifest, indexed by package name.
        build_files: dict mapping package name -> BUILD file content for that
            package's repository. The package name is also the repository name.
        url_prefix: base URL that the manifest's `relative_path` values are
            relative to.
        variant: optional sub-key (e.g. "cuda12") to select inside a package's
            platform entry. If the entry has no such key, the platform entry
            itself is used.

    Packages whose manifest entry has no data for `ARCH` are skipped.
    """
    for pkg, build_file_content in build_files.items():
        if pkg not in redist:
            fail("package '{}' not found in the redist manifest under {}".format(pkg, url_prefix))
        arch_data = redist[pkg].get(ARCH)
        if not arch_data:
            # No build of this package for the target platform.
            continue
        if variant:
            arch_data = arch_data.get(variant, arch_data)
        relative_path = arch_data["relative_path"]
        http_archive(
            name = pkg,
            build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
            url = url_prefix + relative_path,
            sha256 = arch_data["sha256"],
            # The archive is expected to contain a single top-level directory
            # named like the archive file without its ".tar.xz" extension.
            # removesuffix only strips the trailing extension, unlike replace().
            strip_prefix = paths.basename(relative_path).removesuffix(".tar.xz"),
        )

def _cuda_impl(mctx):
    """Implementation of the `cuda_packages` module extension.

    Declares repositories for the Ubuntu packages pinned in the packages lock
    file, the CUDA / cuDNN / NVSHMEM redistributables listed in their NVIDIA
    redistrib JSON manifests, NCCL (from a PyPI wheel) and libpjrt_cuda.
    """
    loaded_packages = packages.read(mctx, [
        "@zml//runtimes/cuda:packages.lock.json",
    ])

    cuda_redist = _read_redist_json(
        mctx,
        url = CUDA_REDIST_PREFIX + "redistrib_{}.json".format(CUDA_VERSION),
        sha256 = CUDA_REDIST_JSON_SHA256,
    )
    nvshmem_redist = _read_redist_json(
        mctx,
        url = NVSHMEM_REDIST_PREFIX + "redistrib_{}.json".format(NVSHMEM_VERSION),
        sha256 = NVSHMEM_REDIST_JSON_SHA256,
    )
    cudnn_redist = _read_redist_json(
        mctx,
        url = CUDNN_REDIST_PREFIX + "redistrib_{}.json".format(CUDNN_VERSION),
        sha256 = CUDNN_REDIST_JSON_SHA256,
    )

    for pkg_name, build_file_content in _UBUNTU_PACKAGES.items():
        pkg = loaded_packages[pkg_name]
        http_deb_archive(
            name = pkg_name,
            urls = pkg["urls"],
            sha256 = pkg["sha256"],
            build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
        )

    # cuDNN and NVSHMEM platform entries may be split per CUDA major version
    # (e.g. a "cuda12" sub-key); pick the one matching CUDA_VERSION.
    cuda_variant = "cuda" + CUDA_VERSION.split(".")[0]

    _declare_redist_archives(cuda_redist, CUDA_PACKAGES, CUDA_REDIST_PREFIX)
    _declare_redist_archives(cudnn_redist, CUDNN_PACKAGES, CUDNN_REDIST_PREFIX, variant = cuda_variant)
    _declare_redist_archives(nvshmem_redist, NVSHMEM_PACKAGES, NVSHMEM_REDIST_PREFIX, variant = cuda_variant)

    http_archive(
        name = "nccl",
        urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
        # A wheel is a zip archive.
        type = "zip",
        sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
        build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup(
            name = "nccl",
            srcs = ["nvidia/nccl/lib/libnccl.so.2"],
        ),
    )

    http_archive(
        name = "libpjrt_cuda",
        build_file = "libpjrt_cuda.BUILD.bazel",
        url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cuda_linux-amd64.tar.gz",
        sha256 = "6cdac9bac6db904e4423c9745c61000cf3acaf3c7da8016ab0016f076869048a",
    )

    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = ["libpjrt_cuda"],
        root_module_direct_dev_deps = [],
    )
|
|
|
|
# Module extension exposing the CUDA runtime dependencies to the build.
cuda_packages = module_extension(
    implementation = _cuda_impl,
    doc = "Declares repositories for the CUDA, cuDNN and NVSHMEM redistributables, " +
          "NCCL, zlib1g and libpjrt_cuda, all fetched for linux-x86_64.",
)
|