52 lines
1.2 KiB
Plaintext
52 lines
1.2 KiB
Plaintext
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
|
load("@zml//bazel:cc_import.bzl", "cc_import")
|
|
|
|
# Shared-library target built from @zml//runtimes/cuda:zmlxcuda_lib.
# The output file name is pinned to "libzmlxcuda.so.0"; the :libpjrt_cuda
# target in this file lists that same file name in its `add_needed` attribute,
# so the two strings must be kept in sync.
cc_shared_library(
    name = "zmlxcuda_so",
    shared_lib_name = "libzmlxcuda.so.0",
    deps = [
        "@zml//runtimes/cuda:zmlxcuda_lib",
    ],
)
|
|
|
|
# Re-exports the :zmlxcuda_so shared library through the cc_import rule loaded
# from @zml//bazel:cc_import.bzl, so that other targets (see :libpjrt_cuda)
# can list it in `deps`.
cc_import(
    name = "zmlxcuda",
    shared_library = ":zmlxcuda_so",
)
|
|
|
|
# Collects the listed @cuda_nvcc artifacts (libdevice, ptxas, nvlink) into a
# single output directory named "sandbox". :libpjrt_cuda consumes it as `data`.
copy_to_directory(
    name = "sandbox",
    srcs = [
        "@cuda_nvcc//:libdevice",
        "@cuda_nvcc//:ptxas",
        "@cuda_nvcc//:nvlink",
    ],
    # Every source above comes from an external repository, so the "**"
    # pattern is needed to let files from any external repository be copied.
    # NOTE(review): see the copy_to_directory docs for the exact default.
    include_external_repositories = ["**"],
)
|
|
|
|
# Imports the prebuilt libpjrt_cuda.so and attaches the CUDA libraries it is
# declared to depend on. Only visible to @zml//runtimes/cuda and its
# subpackages.
cc_import(
    name = "libpjrt_cuda",
    # NOTE(review): `add_needed` and `rename_dynamic_symbols` are attributes of
    # the custom cc_import from @zml//bazel:cc_import.bzl, whose implementation
    # is not visible here. Presumably they add libzmlxcuda.so.0 as a needed
    # library of the imported .so and redirect its `dlopen` references to
    # `zmlxcuda_dlopen` -- confirm against that .bzl file.
    add_needed = ["libzmlxcuda.so.0"],
    # The :sandbox directory (libdevice, ptxas, nvlink) is attached as data.
    data = [":sandbox"],
    rename_dynamic_symbols = {
        "dlopen": "zmlxcuda_dlopen",
    },
    shared_library = "libpjrt_cuda.so",
    visibility = ["@zml//runtimes/cuda:__subpackages__"],
    deps = [
        # Local wrapper around libzmlxcuda.so.0 (see `add_needed` above).
        ":zmlxcuda",
        "@cuda_cudart//:cudart",
        "@cuda_cupti//:cupti",
        "@cuda_nvtx//:nvtx",
        "@cuda_nvcc//:nvptxcompiler",
        "@cuda_nvcc//:nvvm",
        "@cuda_nvrtc//:nvrtc",
        "@cudnn//:cudnn",
        "@libcublas//:cublas",
        "@libcufft//:cufft",
        "@libcusolver//:cusolver",
        "@libcusparse//:cusparse",
        "@libnvjitlink//:nvjitlink",
        "@nccl",
        "@zlib1g",
    ],
)
|