diff --git a/runtimes/common/packages.bzl b/runtimes/common/packages.bzl index bf30258..fa4c19b 100644 --- a/runtimes/common/packages.bzl +++ b/runtimes/common/packages.bzl @@ -10,21 +10,19 @@ def _kwargs(**kwargs): def _cc_import(**kwargs): return """cc_import({})""".format(_kwargs(**kwargs)) -def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []): +def _cc_library(**kwargs): + return """cc_library({})""".format(_kwargs(**kwargs)) + +def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = [], **kwargs): return """\ -filegroup( - name = "{name}_files", - srcs = glob(["{hdrs_glob}"]), - visibility = ["//visibility:public"], -) cc_import( name = "{name}", shared_library = {shared_library}, - hdrs = [":{name}_files"], + hdrs = glob(["{hdrs_glob}"]), deps = {deps}, - visibility = ["//visibility:public"], + {kwargs} ) -""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps)) +""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps), kwargs = _kwargs(**kwargs)) def _filegroup(**kwargs): return """filegroup({})""".format(_kwargs(**kwargs)) @@ -76,6 +74,7 @@ packages = struct( read = _read, cc_import = _cc_import, cc_import_glob_hdrs = _cc_import_glob_hdrs, + cc_library = _cc_library, filegroup = _filegroup, load_ = _load, ) diff --git a/runtimes/cuda/BUILD.bazel b/runtimes/cuda/BUILD.bazel index a7f8bf9..0913c05 100644 --- a/runtimes/cuda/BUILD.bazel +++ b/runtimes/cuda/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_zig//zig:defs.bzl", "zig_library") cc_library( @@ -12,6 +13,7 @@ cc_library( cc_library( name = "libpjrt_cuda", + hdrs = ["libpjrt_cuda.h"], defines = ["ZML_RUNTIME_CUDA"], deps = ["@libpjrt_cuda"], ) diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index 82d1e44..6cd9bee 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -17,10 +17,22 @@ CUDNN_VERSION = "9.8.0" CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843" CUDA_PACKAGES = { - "cuda_cudart": packages.cc_import( - name = "cudart", - shared_library = "lib/libcudart.so.12", - ), + "cuda_cudart": "\n".join([ + packages.cc_library( + name = "cudart", + hdrs = ["include/cuda.h"], + includes = ["include"], + deps = [":cudart_so", ":cuda_so"], + ), + packages.cc_import( + name = "cudart_so", + shared_library = "lib/libcudart.so.12", + ), + packages.cc_import( + name = "cuda_so", + shared_library = "lib/stubs/libcuda.so", + ), + ]), "cuda_cupti": packages.cc_import( name = "cupti", shared_library = "lib/libcupti.so.12", diff --git a/runtimes/cuda/libpjrt_cuda.h b/runtimes/cuda/libpjrt_cuda.h new file mode 100644 index 0000000..c2d9f0d --- /dev/null +++ b/runtimes/cuda/libpjrt_cuda.h @@ -0,0 +1 @@ +#include