runtimes/cuda: expose cuda.h in the C namespace for CUDA runtimes, enabling custom calls to CUDA functions.
This commit is contained in:
parent
3849eb10b7
commit
47a4eda5f6
@ -10,21 +10,19 @@ def _kwargs(**kwargs):
|
||||
def _cc_import(**kwargs):
|
||||
return """cc_import({})""".format(_kwargs(**kwargs))
|
||||
|
||||
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []):
|
||||
def _cc_library(**kwargs):
|
||||
return """cc_library({})""".format(_kwargs(**kwargs))
|
||||
|
||||
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = [], **kwargs):
|
||||
return """\
|
||||
filegroup(
|
||||
name = "{name}_files",
|
||||
srcs = glob(["{hdrs_glob}"]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
cc_import(
|
||||
name = "{name}",
|
||||
shared_library = {shared_library},
|
||||
hdrs = [":{name}_files"],
|
||||
hdrs = glob(["{hdrs_glob}"]),
|
||||
deps = {deps},
|
||||
visibility = ["//visibility:public"],
|
||||
{kwargs}
|
||||
)
|
||||
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps))
|
||||
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps), kwargs = _kwargs(**kwargs))
|
||||
|
||||
def _filegroup(**kwargs):
|
||||
return """filegroup({})""".format(_kwargs(**kwargs))
|
||||
@ -76,6 +74,7 @@ packages = struct(
|
||||
read = _read,
|
||||
cc_import = _cc_import,
|
||||
cc_import_glob_hdrs = _cc_import_glob_hdrs,
|
||||
cc_library = _cc_library,
|
||||
filegroup = _filegroup,
|
||||
load_ = _load,
|
||||
)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
cc_library(
|
||||
@ -12,6 +13,7 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "libpjrt_cuda",
|
||||
hdrs = ["libpjrt_cuda.h"],
|
||||
defines = ["ZML_RUNTIME_CUDA"],
|
||||
deps = ["@libpjrt_cuda"],
|
||||
)
|
||||
|
||||
@ -17,10 +17,22 @@ CUDNN_VERSION = "9.8.0"
|
||||
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
||||
|
||||
CUDA_PACKAGES = {
|
||||
"cuda_cudart": packages.cc_import(
|
||||
name = "cudart",
|
||||
shared_library = "lib/libcudart.so.12",
|
||||
),
|
||||
"cuda_cudart": "\n".join([
|
||||
packages.cc_library(
|
||||
name = "cudart",
|
||||
hdrs = ["include/cuda.h"],
|
||||
includes = ["include"],
|
||||
deps = [":cudart_so", ":cuda_so"],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudart_so",
|
||||
shared_library = "lib/libcudart.so.12",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cuda_so",
|
||||
shared_library = "lib/stubs/libcuda.so",
|
||||
),
|
||||
]),
|
||||
"cuda_cupti": packages.cc_import(
|
||||
name = "cupti",
|
||||
shared_library = "lib/libcupti.so.12",
|
||||
|
||||
1
runtimes/cuda/libpjrt_cuda.h
Normal file
1
runtimes/cuda/libpjrt_cuda.h
Normal file
@ -0,0 +1 @@
|
||||
#include <cuda.h>
|
||||
Loading…
Reference in New Issue
Block a user