runtimes/cuda: implement zmlxcuda in Zig

This commit is contained in:
Tarry Singh 2025-07-08 09:25:25 +00:00
parent c488b634fc
commit e1ee340306
5 changed files with 72 additions and 88 deletions

View File

@ -1,9 +1,14 @@
load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library") load("@rules_zig//zig:defs.bzl", "zig_library", "zig_shared_library")
cc_library( zig_shared_library(
name = "zmlxcuda_lib", name = "zmlxcuda",
srcs = ["zmlxcuda.c"], main = "zmlxcuda.zig",
# Use Clang's compiler-rt, but disable stack checking
# to avoid requiring on the _zig_probe_stack symbol.
copts = ["-fno-stack-check"],
shared_lib_name = "libzmlxcuda.so.0",
deps = ["//stdx"],
visibility = ["@libpjrt_cuda//:__subpackages__"], visibility = ["@libpjrt_cuda//:__subpackages__"],
) )

View File

@ -31,12 +31,12 @@ CUDA_PACKAGES = {
), ),
#TODO: Remove me as soon we use the Driver API in tracer.zig #TODO: Remove me as soon we use the Driver API in tracer.zig
packages.filegroup( packages.filegroup(
name = "so_files", name = "cuda_cudart",
srcs = ["lib/libcudart.so.12"], srcs = ["lib/libcudart.so.12"],
), ),
]), ]),
"cuda_cupti": packages.filegroup( "cuda_cupti": packages.filegroup(
name = "so_files", name = "cuda_cupti",
srcs = ["lib/libcupti.so.12"], srcs = ["lib/libcupti.so.12"],
), ),
"cuda_nvtx": "\n".join([ "cuda_nvtx": "\n".join([
@ -46,42 +46,35 @@ CUDA_PACKAGES = {
# visibility = ["//visibility:public"], # visibility = ["//visibility:public"],
# ), # ),
packages.filegroup( packages.filegroup(
name = "so_files", name = "cuda_nvtx",
srcs = ["lib/libnvToolsExt.so.1"], srcs = ["lib/libnvToolsExt.so.1"],
), ),
]), ]),
"libcufft": packages.filegroup( "libcufft": packages.filegroup(
name = "so_files", name = "libcufft",
srcs = ["lib/libcufft.so.11"], srcs = ["lib/libcufft.so.11"],
), ),
"libcusolver": packages.filegroup( "libcusolver": packages.filegroup(
name = "so_files", name = "libcusolver",
srcs = ["lib/libcusolver.so.11"], srcs = ["lib/libcusolver.so.11"],
), ),
"libcusparse": packages.filegroup( "libcusparse": packages.filegroup(
name = "so_files", name = "libcusparse",
srcs = ["lib/libcusparse.so.12"], srcs = ["lib/libcusparse.so.12"],
), ),
"libnvjitlink": packages.filegroup( "libnvjitlink": packages.filegroup(
name = "so_files", name = "libnvjitlink",
srcs = ["lib/libnvJitLink.so.12"], srcs = ["lib/libnvJitLink.so.12"],
), ),
"cuda_nvcc": "\n".join([ "cuda_nvcc": "\n".join([
packages.filegroup( packages.filegroup(
name = "ptxas", name = "cuda_nvcc",
srcs = ["bin/ptxas"], srcs = [
), "bin/ptxas",
packages.filegroup( "bin/nvlink",
name = "nvlink", "nvvm/libdevice/libdevice.10.bc",
srcs = ["bin/nvlink"], "nvvm/lib64/libnvvm.so.4",
), ],
packages.filegroup(
name = "libdevice",
srcs = ["nvvm/libdevice/libdevice.10.bc"],
),
packages.filegroup(
name = "so_files",
srcs = ["nvvm/lib64/libnvvm.so.4"],
), ),
packages.cc_import( packages.cc_import(
name = "nvptxcompiler", name = "nvptxcompiler",
@ -90,7 +83,7 @@ CUDA_PACKAGES = {
]), ]),
"cuda_nvrtc": "\n".join([ "cuda_nvrtc": "\n".join([
packages.filegroup( packages.filegroup(
name = "so_files", name = "cuda_nvrtc",
srcs = [ srcs = [
"lib/libnvrtc.so.12", "lib/libnvrtc.so.12",
"lib/libnvrtc-builtins.so.12.8", "lib/libnvrtc-builtins.so.12.8",
@ -99,7 +92,7 @@ CUDA_PACKAGES = {
]), ]),
"libcublas": "\n".join([ "libcublas": "\n".join([
packages.filegroup( packages.filegroup(
name = "so_files", name = "libcublas",
srcs = [ srcs = [
"lib/libcublasLt.so.12", "lib/libcublasLt.so.12",
"lib/libcublas.so.12", "lib/libcublas.so.12",
@ -111,7 +104,7 @@ CUDA_PACKAGES = {
CUDNN_PACKAGES = { CUDNN_PACKAGES = {
"cudnn": "\n".join([ "cudnn": "\n".join([
packages.filegroup( packages.filegroup(
name = "so_files", name = "cudnn",
srcs = [ srcs = [
"lib/libcudnn.so.9", "lib/libcudnn.so.9",
"lib/libcudnn_adv.so.9", "lib/libcudnn_adv.so.9",
@ -193,7 +186,7 @@ def _cuda_impl(mctx):
type = "zip", type = "zip",
sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a", sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup( build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup(
name = "so_files", name = "nccl",
srcs = ["nvidia/nccl/lib/libnccl.so.2"], srcs = ["nvidia/nccl/lib/libnccl.so.2"],
), ),
) )

View File

@ -9,7 +9,7 @@ cc_shared_library(
) )
patchelf( patchelf(
name = "libpjrt_cuda.patchelf", name = "libpjrt_cuda_so",
src = "libpjrt_cuda.so", src = "libpjrt_cuda.so",
add_needed = [ add_needed = [
"libzmlxcuda.so.0", "libzmlxcuda.so.0",
@ -23,30 +23,28 @@ patchelf(
copy_to_directory( copy_to_directory(
name = "sandbox", name = "sandbox",
srcs = [ srcs = [
":zmlxcuda_so", ":libpjrt_cuda_so",
":libpjrt_cuda.patchelf", "@cuda_cudart",
"@cuda_nvcc//:libdevice", "@cuda_cupti",
"@cuda_nvcc//:ptxas", "@cuda_nvcc",
"@cuda_nvcc//:nvlink", "@cuda_nvrtc",
"@cuda_cupti//:so_files", "@cuda_nvtx",
"@cuda_nvtx//:so_files", "@cudnn",
"@cuda_nvcc//:so_files", "@libcublas",
"@cuda_nvrtc//:so_files", "@libcufft",
"@cuda_cudart//:so_files", "@libcusolver",
"@cudnn//:so_files", "@libcusparse",
"@libcublas//:so_files", "@libnvjitlink",
"@libcufft//:so_files", "@nccl",
"@libcusolver//:so_files",
"@libcusparse//:so_files",
"@libnvjitlink//:so_files",
"@nccl//:so_files",
"@zlib1g", "@zlib1g",
"@zml//runtimes/cuda:zmlxcuda",
], ],
replace_prefixes = { replace_prefixes = {
"nvidia/nccl/lib": "lib", "nvidia/nccl/lib": "lib",
"nvvm/lib64": "lib", "nvvm/lib64": "lib",
"libpjrt_cuda.patchelf": "lib", "libpjrt_cuda_so": "lib",
"lib/x86_64-linux-gnu": "lib", "lib/x86_64-linux-gnu": "lib",
"runtimes/cuda": "lib",
}, },
add_directory_to_runfiles = False, add_directory_to_runfiles = False,
include_external_repositories = ["**"], include_external_repositories = ["**"],

View File

@ -1,40 +0,0 @@
#include <dlfcn.h>
#include <string.h>
void *zmlxcuda_dlopen(const char *filename, int flags)
{
if (filename != NULL)
{
char *replacements[] = {
"libcublas.so",
"libcublas.so.12",
"libcublasLt.so",
"libcublasLt.so.12",
"libcudart.so",
"libcudart.so.12",
"libcudnn.so",
"libcudnn.so.9",
"libcufft.so",
"libcufft.so.11",
"libcupti.so",
"libcupti.so.12",
"libcusolver.so",
"libcusolver.so.11",
"libcusparse.so",
"libcusparse.so.12",
"libnccl.so",
"libnccl.so.2",
NULL,
NULL,
};
for (int i = 0; replacements[i] != NULL; i += 2)
{
if (strcmp(filename, replacements[i]) == 0)
{
filename = replacements[i + 1];
break;
}
}
}
return dlopen(filename, flags);
}

View File

@ -0,0 +1,28 @@
const std = @import("std");
const stdx = @import("stdx");
pub export fn zmlxcuda_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque {
const replacements: std.StaticStringMap([:0]const u8) = .initComptime(.{
.{ "libcublas.so", "libcublas.so.12" },
.{ "libcublasLt.so", "libcublasLt.so.12" },
.{ "libcudart.so", "libcudart.so.12" },
.{ "libcudnn.so", "libcudnn.so.9" },
.{ "libcufft.so", "libcufft.so.11" },
.{ "libcupti.so", "libcupti.so.12" },
.{ "libcusolver.so", "libcusolver.so.11" },
.{ "libcusparse.so", "libcusparse.so.12" },
.{ "libnccl.so", "libnccl.so.2" },
});
var buf: [std.fs.max_path_bytes]u8 = undefined;
const new_filename: [*c]const u8 = if (filename) |f| blk: {
const replacement = replacements.get(std.fs.path.basename(std.mem.span(f))) orelse break :blk f;
break :blk stdx.fs.path.bufJoinZ(&buf, &.{
stdx.fs.selfSharedObjectDirPath(),
replacement,
}) catch unreachable;
} else null;
return std.c.dlopen(new_filename, @bitCast(flags));
}