runtimes/cuda: expose cuda.h in the C namespace for CUDA runtimes, enabling custom calls to CUDA functions.

This commit is contained in:
Tarry Singh 2024-11-01 13:27:24 +00:00
parent 3849eb10b7
commit 47a4eda5f6
4 changed files with 27 additions and 13 deletions

View File

@ -10,21 +10,19 @@ def _kwargs(**kwargs):
def _cc_import(**kwargs):
return """cc_import({})""".format(_kwargs(**kwargs))
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []):
def _cc_library(**kwargs):
return """cc_library({})""".format(_kwargs(**kwargs))
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = [], **kwargs):
return """\
filegroup(
name = "{name}_files",
srcs = glob(["{hdrs_glob}"]),
visibility = ["//visibility:public"],
)
cc_import(
name = "{name}",
shared_library = {shared_library},
hdrs = [":{name}_files"],
hdrs = glob(["{hdrs_glob}"]),
deps = {deps},
visibility = ["//visibility:public"],
{kwargs}
)
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps))
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps), kwargs = _kwargs(**kwargs))
def _filegroup(**kwargs):
return """filegroup({})""".format(_kwargs(**kwargs))
@ -76,6 +74,7 @@ packages = struct(
read = _read,
cc_import = _cc_import,
cc_import_glob_hdrs = _cc_import_glob_hdrs,
cc_library = _cc_library,
filegroup = _filegroup,
load_ = _load,
)

View File

@ -1,3 +1,4 @@
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
cc_library(
@ -12,6 +13,7 @@ cc_library(
cc_library(
name = "libpjrt_cuda",
hdrs = ["libpjrt_cuda.h"],
defines = ["ZML_RUNTIME_CUDA"],
deps = ["@libpjrt_cuda"],
)

View File

@ -17,10 +17,22 @@ CUDNN_VERSION = "9.8.0"
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
CUDA_PACKAGES = {
"cuda_cudart": packages.cc_import(
name = "cudart",
shared_library = "lib/libcudart.so.12",
),
"cuda_cudart": "\n".join([
packages.cc_library(
name = "cudart",
hdrs = ["include/cuda.h"],
includes = ["include"],
deps = [":cudart_so", ":cuda_so"],
),
packages.cc_import(
name = "cudart_so",
shared_library = "lib/libcudart.so.12",
),
packages.cc_import(
name = "cuda_so",
shared_library = "lib/stubs/libcuda.so",
),
]),
"cuda_cupti": packages.cc_import(
name = "cupti",
shared_library = "lib/libcupti.so.12",

View File

@ -0,0 +1 @@
#include <cuda.h>