runtimes/cuda: expose cuda.h in the C namespace for CUDA runtimes, enabling custom calls to CUDA functions.
This commit is contained in:
parent
3849eb10b7
commit
47a4eda5f6
@ -10,21 +10,19 @@ def _kwargs(**kwargs):
|
|||||||
def _cc_import(**kwargs):
|
def _cc_import(**kwargs):
|
||||||
return """cc_import({})""".format(_kwargs(**kwargs))
|
return """cc_import({})""".format(_kwargs(**kwargs))
|
||||||
|
|
||||||
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []):
|
def _cc_library(**kwargs):
|
||||||
|
return """cc_library({})""".format(_kwargs(**kwargs))
|
||||||
|
|
||||||
|
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = [], **kwargs):
|
||||||
return """\
|
return """\
|
||||||
filegroup(
|
|
||||||
name = "{name}_files",
|
|
||||||
srcs = glob(["{hdrs_glob}"]),
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
cc_import(
|
cc_import(
|
||||||
name = "{name}",
|
name = "{name}",
|
||||||
shared_library = {shared_library},
|
shared_library = {shared_library},
|
||||||
hdrs = [":{name}_files"],
|
hdrs = glob(["{hdrs_glob}"]),
|
||||||
deps = {deps},
|
deps = {deps},
|
||||||
visibility = ["//visibility:public"],
|
{kwargs}
|
||||||
)
|
)
|
||||||
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps))
|
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps), kwargs = _kwargs(**kwargs))
|
||||||
|
|
||||||
def _filegroup(**kwargs):
|
def _filegroup(**kwargs):
|
||||||
return """filegroup({})""".format(_kwargs(**kwargs))
|
return """filegroup({})""".format(_kwargs(**kwargs))
|
||||||
@ -76,6 +74,7 @@ packages = struct(
|
|||||||
read = _read,
|
read = _read,
|
||||||
cc_import = _cc_import,
|
cc_import = _cc_import,
|
||||||
cc_import_glob_hdrs = _cc_import_glob_hdrs,
|
cc_import_glob_hdrs = _cc_import_glob_hdrs,
|
||||||
|
cc_library = _cc_library,
|
||||||
filegroup = _filegroup,
|
filegroup = _filegroup,
|
||||||
load_ = _load,
|
load_ = _load,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
||||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -12,6 +13,7 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
|
hdrs = ["libpjrt_cuda.h"],
|
||||||
defines = ["ZML_RUNTIME_CUDA"],
|
defines = ["ZML_RUNTIME_CUDA"],
|
||||||
deps = ["@libpjrt_cuda"],
|
deps = ["@libpjrt_cuda"],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,10 +17,22 @@ CUDNN_VERSION = "9.8.0"
|
|||||||
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
||||||
|
|
||||||
CUDA_PACKAGES = {
|
CUDA_PACKAGES = {
|
||||||
"cuda_cudart": packages.cc_import(
|
"cuda_cudart": "\n".join([
|
||||||
name = "cudart",
|
packages.cc_library(
|
||||||
shared_library = "lib/libcudart.so.12",
|
name = "cudart",
|
||||||
),
|
hdrs = ["include/cuda.h"],
|
||||||
|
includes = ["include"],
|
||||||
|
deps = [":cudart_so", ":cuda_so"],
|
||||||
|
),
|
||||||
|
packages.cc_import(
|
||||||
|
name = "cudart_so",
|
||||||
|
shared_library = "lib/libcudart.so.12",
|
||||||
|
),
|
||||||
|
packages.cc_import(
|
||||||
|
name = "cuda_so",
|
||||||
|
shared_library = "lib/stubs/libcuda.so",
|
||||||
|
),
|
||||||
|
]),
|
||||||
"cuda_cupti": packages.cc_import(
|
"cuda_cupti": packages.cc_import(
|
||||||
name = "cupti",
|
name = "cupti",
|
||||||
shared_library = "lib/libcupti.so.12",
|
shared_library = "lib/libcupti.so.12",
|
||||||
|
|||||||
1
runtimes/cuda/libpjrt_cuda.h
Normal file
1
runtimes/cuda/libpjrt_cuda.h
Normal file
@ -0,0 +1 @@
|
|||||||
|
#include <cuda.h>
|
||||||
Loading…
Reference in New Issue
Block a user