runtimes/cuda: expose cuda.h in the C namespace for CUDA runtimes, enabling custom calls to CUDA functions.

This commit is contained in:
Tarry Singh 2024-11-01 13:27:24 +00:00
parent 3849eb10b7
commit 47a4eda5f6
4 changed files with 27 additions and 13 deletions

View File

@ -10,21 +10,19 @@ def _kwargs(**kwargs):
def _cc_import(**kwargs): def _cc_import(**kwargs):
return """cc_import({})""".format(_kwargs(**kwargs)) return """cc_import({})""".format(_kwargs(**kwargs))
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []): def _cc_library(**kwargs):
return """cc_library({})""".format(_kwargs(**kwargs))
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = [], **kwargs):
return """\ return """\
filegroup(
name = "{name}_files",
srcs = glob(["{hdrs_glob}"]),
visibility = ["//visibility:public"],
)
cc_import( cc_import(
name = "{name}", name = "{name}",
shared_library = {shared_library}, shared_library = {shared_library},
hdrs = [":{name}_files"], hdrs = glob(["{hdrs_glob}"]),
deps = {deps}, deps = {deps},
visibility = ["//visibility:public"], {kwargs}
) )
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps)) """.format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps), kwargs = _kwargs(**kwargs))
def _filegroup(**kwargs): def _filegroup(**kwargs):
return """filegroup({})""".format(_kwargs(**kwargs)) return """filegroup({})""".format(_kwargs(**kwargs))
@ -76,6 +74,7 @@ packages = struct(
read = _read, read = _read,
cc_import = _cc_import, cc_import = _cc_import,
cc_import_glob_hdrs = _cc_import_glob_hdrs, cc_import_glob_hdrs = _cc_import_glob_hdrs,
cc_library = _cc_library,
filegroup = _filegroup, filegroup = _filegroup,
load_ = _load, load_ = _load,
) )

View File

@ -1,3 +1,4 @@
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library") load("@rules_zig//zig:defs.bzl", "zig_library")
cc_library( cc_library(
@ -12,6 +13,7 @@ cc_library(
cc_library( cc_library(
name = "libpjrt_cuda", name = "libpjrt_cuda",
hdrs = ["libpjrt_cuda.h"],
defines = ["ZML_RUNTIME_CUDA"], defines = ["ZML_RUNTIME_CUDA"],
deps = ["@libpjrt_cuda"], deps = ["@libpjrt_cuda"],
) )

View File

@ -17,10 +17,22 @@ CUDNN_VERSION = "9.8.0"
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843" CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
CUDA_PACKAGES = { CUDA_PACKAGES = {
"cuda_cudart": packages.cc_import( "cuda_cudart": "\n".join([
packages.cc_library(
name = "cudart", name = "cudart",
hdrs = ["include/cuda.h"],
includes = ["include"],
deps = [":cudart_so", ":cuda_so"],
),
packages.cc_import(
name = "cudart_so",
shared_library = "lib/libcudart.so.12", shared_library = "lib/libcudart.so.12",
), ),
packages.cc_import(
name = "cuda_so",
shared_library = "lib/stubs/libcuda.so",
),
]),
"cuda_cupti": packages.cc_import( "cuda_cupti": packages.cc_import(
name = "cupti", name = "cupti",
shared_library = "lib/libcupti.so.12", shared_library = "lib/libcupti.so.12",

View File

@ -0,0 +1 @@
#include <cuda.h>