From e1ee340306a495a507029e34cfb4baf43ac560e1 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 8 Jul 2025 09:25:25 +0000 Subject: [PATCH] runtimes/cuda: implement zmlxcuda in Zig --- runtimes/cuda/BUILD.bazel | 13 +++++--- runtimes/cuda/cuda.bzl | 43 +++++++++++--------------- runtimes/cuda/libpjrt_cuda.BUILD.bazel | 36 ++++++++++----------- runtimes/cuda/zmlxcuda.c | 40 ------------------------ runtimes/cuda/zmlxcuda.zig | 28 +++++++++++++++++ 5 files changed, 72 insertions(+), 88 deletions(-) delete mode 100644 runtimes/cuda/zmlxcuda.c create mode 100644 runtimes/cuda/zmlxcuda.zig diff --git a/runtimes/cuda/BUILD.bazel b/runtimes/cuda/BUILD.bazel index 0913c05..38904d4 100644 --- a/runtimes/cuda/BUILD.bazel +++ b/runtimes/cuda/BUILD.bazel @@ -1,9 +1,14 @@ load("@rules_cc//cc:cc_library.bzl", "cc_library") -load("@rules_zig//zig:defs.bzl", "zig_library") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_shared_library") -cc_library( - name = "zmlxcuda_lib", - srcs = ["zmlxcuda.c"], +zig_shared_library( + name = "zmlxcuda", + main = "zmlxcuda.zig", + # Use Clang's compiler-rt, but disable stack checking + # to avoid requiring on the _zig_probe_stack symbol. + copts = ["-fno-stack-check"], + shared_lib_name = "libzmlxcuda.so.0", + deps = ["//stdx"], visibility = ["@libpjrt_cuda//:__subpackages__"], ) diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index 2d27705..3380694 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -31,12 +31,12 @@ CUDA_PACKAGES = { ), #TODO: Remove me as soon we use the Driver API in tracer.zig packages.filegroup( - name = "so_files", + name = "cuda_cudart", srcs = ["lib/libcudart.so.12"], ), ]), "cuda_cupti": packages.filegroup( - name = "so_files", + name = "cuda_cupti", srcs = ["lib/libcupti.so.12"], ), "cuda_nvtx": "\n".join([ @@ -46,42 +46,35 @@ CUDA_PACKAGES = { # visibility = ["//visibility:public"], # ), packages.filegroup( - name = "so_files", + name = "cuda_nvtx", srcs = ["lib/libnvToolsExt.so.1"], ), ]), "libcufft": packages.filegroup( - name = "so_files", + name = "libcufft", srcs = ["lib/libcufft.so.11"], ), "libcusolver": packages.filegroup( - name = "so_files", + name = "libcusolver", srcs = ["lib/libcusolver.so.11"], ), "libcusparse": packages.filegroup( - name = "so_files", + name = "libcusparse", srcs = ["lib/libcusparse.so.12"], ), "libnvjitlink": packages.filegroup( - name = "so_files", + name = "libnvjitlink", srcs = ["lib/libnvJitLink.so.12"], ), "cuda_nvcc": "\n".join([ packages.filegroup( - name = "ptxas", - srcs = ["bin/ptxas"], - ), - packages.filegroup( - name = "nvlink", - srcs = ["bin/nvlink"], - ), - packages.filegroup( - name = "libdevice", - srcs = ["nvvm/libdevice/libdevice.10.bc"], - ), - packages.filegroup( - name = "so_files", - srcs = ["nvvm/lib64/libnvvm.so.4"], + name = "cuda_nvcc", + srcs = [ + "bin/ptxas", + "bin/nvlink", + "nvvm/libdevice/libdevice.10.bc", + "nvvm/lib64/libnvvm.so.4", + ], ), packages.cc_import( name = "nvptxcompiler", @@ -90,7 +83,7 @@ CUDA_PACKAGES = { ]), "cuda_nvrtc": "\n".join([ packages.filegroup( - name = "so_files", + name = "cuda_nvrtc", srcs = [ "lib/libnvrtc.so.12", "lib/libnvrtc-builtins.so.12.8", @@ -99,7 +92,7 @@ CUDA_PACKAGES = { ]), "libcublas": "\n".join([ packages.filegroup( - name = "so_files", + name = "libcublas", srcs = [ "lib/libcublasLt.so.12", "lib/libcublas.so.12", @@ -111,7 +104,7 @@ CUDA_PACKAGES = { CUDNN_PACKAGES = { "cudnn": "\n".join([ packages.filegroup( - name = "so_files", + name = "cudnn", srcs = [ "lib/libcudnn.so.9", "lib/libcudnn_adv.so.9", @@ -193,7 +186,7 @@ def _cuda_impl(mctx): type = "zip", sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup( - name = "so_files", + name = "nccl", srcs = ["nvidia/nccl/lib/libnccl.so.2"], ), ) diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel index 6cb5d74..031ec50 100644 --- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel +++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel @@ -9,7 +9,7 @@ cc_shared_library( ) patchelf( - name = "libpjrt_cuda.patchelf", + name = "libpjrt_cuda_so", src = "libpjrt_cuda.so", add_needed = [ "libzmlxcuda.so.0", @@ -23,30 +23,28 @@ patchelf( copy_to_directory( name = "sandbox", srcs = [ - ":zmlxcuda_so", - ":libpjrt_cuda.patchelf", - "@cuda_nvcc//:libdevice", - "@cuda_nvcc//:ptxas", - "@cuda_nvcc//:nvlink", - "@cuda_cupti//:so_files", - "@cuda_nvtx//:so_files", - "@cuda_nvcc//:so_files", - "@cuda_nvrtc//:so_files", - "@cuda_cudart//:so_files", - "@cudnn//:so_files", - "@libcublas//:so_files", - "@libcufft//:so_files", - "@libcusolver//:so_files", - "@libcusparse//:so_files", - "@libnvjitlink//:so_files", - "@nccl//:so_files", + ":libpjrt_cuda_so", + "@cuda_cudart", + "@cuda_cupti", + "@cuda_nvcc", + "@cuda_nvrtc", + "@cuda_nvtx", + "@cudnn", + "@libcublas", + "@libcufft", + "@libcusolver", + "@libcusparse", + "@libnvjitlink", + "@nccl", "@zlib1g", + "@zml//runtimes/cuda:zmlxcuda", ], replace_prefixes = { "nvidia/nccl/lib": "lib", "nvvm/lib64": "lib", - "libpjrt_cuda.patchelf": "lib", + "libpjrt_cuda_so": "lib", "lib/x86_64-linux-gnu": "lib", + "runtimes/cuda": "lib", }, add_directory_to_runfiles = False, include_external_repositories = ["**"], diff --git a/runtimes/cuda/zmlxcuda.c b/runtimes/cuda/zmlxcuda.c deleted file mode 100644 index 2bd1d32..0000000 --- a/runtimes/cuda/zmlxcuda.c +++ /dev/null @@ -1,40 +0,0 @@ -#include -#include - -void *zmlxcuda_dlopen(const char *filename, int flags) -{ - if (filename != NULL) - { - char *replacements[] = { - "libcublas.so", - "libcublas.so.12", - "libcublasLt.so", - "libcublasLt.so.12", - "libcudart.so", - "libcudart.so.12", - "libcudnn.so", - "libcudnn.so.9", - "libcufft.so", - "libcufft.so.11", - "libcupti.so", - "libcupti.so.12", - "libcusolver.so", - "libcusolver.so.11", - "libcusparse.so", - "libcusparse.so.12", - "libnccl.so", - "libnccl.so.2", - NULL, - NULL, - }; - for (int i = 0; replacements[i] != NULL; i += 2) - { - if (strcmp(filename, replacements[i]) == 0) - { - filename = replacements[i + 1]; - break; - } - } - } - return dlopen(filename, flags); -} diff --git a/runtimes/cuda/zmlxcuda.zig b/runtimes/cuda/zmlxcuda.zig new file mode 100644 index 0000000..1ccd4ba --- /dev/null +++ b/runtimes/cuda/zmlxcuda.zig @@ -0,0 +1,28 @@ +const std = @import("std"); + +const stdx = @import("stdx"); + +pub export fn zmlxcuda_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque { + const replacements: std.StaticStringMap([:0]const u8) = .initComptime(.{ + .{ "libcublas.so", "libcublas.so.12" }, + .{ "libcublasLt.so", "libcublasLt.so.12" }, + .{ "libcudart.so", "libcudart.so.12" }, + .{ "libcudnn.so", "libcudnn.so.9" }, + .{ "libcufft.so", "libcufft.so.11" }, + .{ "libcupti.so", "libcupti.so.12" }, + .{ "libcusolver.so", "libcusolver.so.11" }, + .{ "libcusparse.so", "libcusparse.so.12" }, + .{ "libnccl.so", "libnccl.so.2" }, + }); + + var buf: [std.fs.max_path_bytes]u8 = undefined; + const new_filename: [*c]const u8 = if (filename) |f| blk: { + const replacement = replacements.get(std.fs.path.basename(std.mem.span(f))) orelse break :blk f; + break :blk stdx.fs.path.bufJoinZ(&buf, &.{ + stdx.fs.selfSharedObjectDirPath(), + replacement, + }) catch unreachable; + } else null; + + return std.c.dlopen(new_filename, @bitCast(flags)); +}