runtimes/cuda: implement zmlxcuda in Zig
This commit is contained in:
parent
c488b634fc
commit
e1ee340306
@ -1,9 +1,14 @@
|
|||||||
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
||||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
load("@rules_zig//zig:defs.bzl", "zig_library", "zig_shared_library")
|
||||||
|
|
||||||
cc_library(
|
zig_shared_library(
|
||||||
name = "zmlxcuda_lib",
|
name = "zmlxcuda",
|
||||||
srcs = ["zmlxcuda.c"],
|
main = "zmlxcuda.zig",
|
||||||
|
# Use Clang's compiler-rt, but disable stack checking
|
||||||
|
# to avoid depending on the _zig_probe_stack symbol.
|
||||||
|
copts = ["-fno-stack-check"],
|
||||||
|
shared_lib_name = "libzmlxcuda.so.0",
|
||||||
|
deps = ["//stdx"],
|
||||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -31,12 +31,12 @@ CUDA_PACKAGES = {
|
|||||||
),
|
),
|
||||||
#TODO: Remove me as soon as we use the Driver API in tracer.zig
|
#TODO: Remove me as soon as we use the Driver API in tracer.zig
|
||||||
packages.filegroup(
|
packages.filegroup(
|
||||||
name = "so_files",
|
name = "cuda_cudart",
|
||||||
srcs = ["lib/libcudart.so.12"],
|
srcs = ["lib/libcudart.so.12"],
|
||||||
),
|
),
|
||||||
]),
|
]),
|
||||||
"cuda_cupti": packages.filegroup(
|
"cuda_cupti": packages.filegroup(
|
||||||
name = "so_files",
|
name = "cuda_cupti",
|
||||||
srcs = ["lib/libcupti.so.12"],
|
srcs = ["lib/libcupti.so.12"],
|
||||||
),
|
),
|
||||||
"cuda_nvtx": "\n".join([
|
"cuda_nvtx": "\n".join([
|
||||||
@ -46,42 +46,35 @@ CUDA_PACKAGES = {
|
|||||||
# visibility = ["//visibility:public"],
|
# visibility = ["//visibility:public"],
|
||||||
# ),
|
# ),
|
||||||
packages.filegroup(
|
packages.filegroup(
|
||||||
name = "so_files",
|
name = "cuda_nvtx",
|
||||||
srcs = ["lib/libnvToolsExt.so.1"],
|
srcs = ["lib/libnvToolsExt.so.1"],
|
||||||
),
|
),
|
||||||
]),
|
]),
|
||||||
"libcufft": packages.filegroup(
|
"libcufft": packages.filegroup(
|
||||||
name = "so_files",
|
name = "libcufft",
|
||||||
srcs = ["lib/libcufft.so.11"],
|
srcs = ["lib/libcufft.so.11"],
|
||||||
),
|
),
|
||||||
"libcusolver": packages.filegroup(
|
"libcusolver": packages.filegroup(
|
||||||
name = "so_files",
|
name = "libcusolver",
|
||||||
srcs = ["lib/libcusolver.so.11"],
|
srcs = ["lib/libcusolver.so.11"],
|
||||||
),
|
),
|
||||||
"libcusparse": packages.filegroup(
|
"libcusparse": packages.filegroup(
|
||||||
name = "so_files",
|
name = "libcusparse",
|
||||||
srcs = ["lib/libcusparse.so.12"],
|
srcs = ["lib/libcusparse.so.12"],
|
||||||
),
|
),
|
||||||
"libnvjitlink": packages.filegroup(
|
"libnvjitlink": packages.filegroup(
|
||||||
name = "so_files",
|
name = "libnvjitlink",
|
||||||
srcs = ["lib/libnvJitLink.so.12"],
|
srcs = ["lib/libnvJitLink.so.12"],
|
||||||
),
|
),
|
||||||
"cuda_nvcc": "\n".join([
|
"cuda_nvcc": "\n".join([
|
||||||
packages.filegroup(
|
packages.filegroup(
|
||||||
name = "ptxas",
|
name = "cuda_nvcc",
|
||||||
srcs = ["bin/ptxas"],
|
srcs = [
|
||||||
),
|
"bin/ptxas",
|
||||||
packages.filegroup(
|
"bin/nvlink",
|
||||||
name = "nvlink",
|
"nvvm/libdevice/libdevice.10.bc",
|
||||||
srcs = ["bin/nvlink"],
|
"nvvm/lib64/libnvvm.so.4",
|
||||||
),
|
],
|
||||||
packages.filegroup(
|
|
||||||
name = "libdevice",
|
|
||||||
srcs = ["nvvm/libdevice/libdevice.10.bc"],
|
|
||||||
),
|
|
||||||
packages.filegroup(
|
|
||||||
name = "so_files",
|
|
||||||
srcs = ["nvvm/lib64/libnvvm.so.4"],
|
|
||||||
),
|
),
|
||||||
packages.cc_import(
|
packages.cc_import(
|
||||||
name = "nvptxcompiler",
|
name = "nvptxcompiler",
|
||||||
@ -90,7 +83,7 @@ CUDA_PACKAGES = {
|
|||||||
]),
|
]),
|
||||||
"cuda_nvrtc": "\n".join([
|
"cuda_nvrtc": "\n".join([
|
||||||
packages.filegroup(
|
packages.filegroup(
|
||||||
name = "so_files",
|
name = "cuda_nvrtc",
|
||||||
srcs = [
|
srcs = [
|
||||||
"lib/libnvrtc.so.12",
|
"lib/libnvrtc.so.12",
|
||||||
"lib/libnvrtc-builtins.so.12.8",
|
"lib/libnvrtc-builtins.so.12.8",
|
||||||
@ -99,7 +92,7 @@ CUDA_PACKAGES = {
|
|||||||
]),
|
]),
|
||||||
"libcublas": "\n".join([
|
"libcublas": "\n".join([
|
||||||
packages.filegroup(
|
packages.filegroup(
|
||||||
name = "so_files",
|
name = "libcublas",
|
||||||
srcs = [
|
srcs = [
|
||||||
"lib/libcublasLt.so.12",
|
"lib/libcublasLt.so.12",
|
||||||
"lib/libcublas.so.12",
|
"lib/libcublas.so.12",
|
||||||
@ -111,7 +104,7 @@ CUDA_PACKAGES = {
|
|||||||
CUDNN_PACKAGES = {
|
CUDNN_PACKAGES = {
|
||||||
"cudnn": "\n".join([
|
"cudnn": "\n".join([
|
||||||
packages.filegroup(
|
packages.filegroup(
|
||||||
name = "so_files",
|
name = "cudnn",
|
||||||
srcs = [
|
srcs = [
|
||||||
"lib/libcudnn.so.9",
|
"lib/libcudnn.so.9",
|
||||||
"lib/libcudnn_adv.so.9",
|
"lib/libcudnn_adv.so.9",
|
||||||
@ -193,7 +186,7 @@ def _cuda_impl(mctx):
|
|||||||
type = "zip",
|
type = "zip",
|
||||||
sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
|
sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup(
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup(
|
||||||
name = "so_files",
|
name = "nccl",
|
||||||
srcs = ["nvidia/nccl/lib/libnccl.so.2"],
|
srcs = ["nvidia/nccl/lib/libnccl.so.2"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ cc_shared_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
patchelf(
|
patchelf(
|
||||||
name = "libpjrt_cuda.patchelf",
|
name = "libpjrt_cuda_so",
|
||||||
src = "libpjrt_cuda.so",
|
src = "libpjrt_cuda.so",
|
||||||
add_needed = [
|
add_needed = [
|
||||||
"libzmlxcuda.so.0",
|
"libzmlxcuda.so.0",
|
||||||
@ -23,30 +23,28 @@ patchelf(
|
|||||||
copy_to_directory(
|
copy_to_directory(
|
||||||
name = "sandbox",
|
name = "sandbox",
|
||||||
srcs = [
|
srcs = [
|
||||||
":zmlxcuda_so",
|
":libpjrt_cuda_so",
|
||||||
":libpjrt_cuda.patchelf",
|
"@cuda_cudart",
|
||||||
"@cuda_nvcc//:libdevice",
|
"@cuda_cupti",
|
||||||
"@cuda_nvcc//:ptxas",
|
"@cuda_nvcc",
|
||||||
"@cuda_nvcc//:nvlink",
|
"@cuda_nvrtc",
|
||||||
"@cuda_cupti//:so_files",
|
"@cuda_nvtx",
|
||||||
"@cuda_nvtx//:so_files",
|
"@cudnn",
|
||||||
"@cuda_nvcc//:so_files",
|
"@libcublas",
|
||||||
"@cuda_nvrtc//:so_files",
|
"@libcufft",
|
||||||
"@cuda_cudart//:so_files",
|
"@libcusolver",
|
||||||
"@cudnn//:so_files",
|
"@libcusparse",
|
||||||
"@libcublas//:so_files",
|
"@libnvjitlink",
|
||||||
"@libcufft//:so_files",
|
"@nccl",
|
||||||
"@libcusolver//:so_files",
|
|
||||||
"@libcusparse//:so_files",
|
|
||||||
"@libnvjitlink//:so_files",
|
|
||||||
"@nccl//:so_files",
|
|
||||||
"@zlib1g",
|
"@zlib1g",
|
||||||
|
"@zml//runtimes/cuda:zmlxcuda",
|
||||||
],
|
],
|
||||||
replace_prefixes = {
|
replace_prefixes = {
|
||||||
"nvidia/nccl/lib": "lib",
|
"nvidia/nccl/lib": "lib",
|
||||||
"nvvm/lib64": "lib",
|
"nvvm/lib64": "lib",
|
||||||
"libpjrt_cuda.patchelf": "lib",
|
"libpjrt_cuda_so": "lib",
|
||||||
"lib/x86_64-linux-gnu": "lib",
|
"lib/x86_64-linux-gnu": "lib",
|
||||||
|
"runtimes/cuda": "lib",
|
||||||
},
|
},
|
||||||
add_directory_to_runfiles = False,
|
add_directory_to_runfiles = False,
|
||||||
include_external_repositories = ["**"],
|
include_external_repositories = ["**"],
|
||||||
|
|||||||
@ -1,40 +0,0 @@
|
|||||||
#include <dlfcn.h>
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
/*
 * dlopen() wrapper that rewrites unversioned CUDA library names to the
 * versioned names listed in the table below before loading them.
 *
 * filename: library name handed to dlopen(); may be NULL, in which case it
 *           is forwarded to dlopen() unchanged.
 * flags:    passed through to dlopen() untouched.
 *
 * Returns whatever dlopen() returns for the (possibly rewritten) name.
 * Only exact matches against the table are rewritten; any other name,
 * including one that carries a directory component, is forwarded as-is.
 */
void *zmlxcuda_dlopen(const char *filename, int flags)
{
    /* Built once and kept read-only: the entries are string literals. */
    static const struct
    {
        const char *from;
        const char *to;
    } replacements[] = {
        {"libcublas.so", "libcublas.so.12"},
        {"libcublasLt.so", "libcublasLt.so.12"},
        {"libcudart.so", "libcudart.so.12"},
        {"libcudnn.so", "libcudnn.so.9"},
        {"libcufft.so", "libcufft.so.11"},
        {"libcupti.so", "libcupti.so.12"},
        {"libcusolver.so", "libcusolver.so.11"},
        {"libcusparse.so", "libcusparse.so.12"},
        {"libnccl.so", "libnccl.so.2"},
    };

    if (filename != NULL)
    {
        for (size_t i = 0; i < sizeof(replacements) / sizeof(replacements[0]); i++)
        {
            if (strcmp(filename, replacements[i].from) == 0)
            {
                /* First exact match wins. */
                filename = replacements[i].to;
                break;
            }
        }
    }
    return dlopen(filename, flags);
}
|
|
||||||
28
runtimes/cuda/zmlxcuda.zig
Normal file
28
runtimes/cuda/zmlxcuda.zig
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
/// `dlopen` wrapper exported with the C ABI.
///
/// When the basename of `filename` is one of the unversioned library names in
/// the table below, the request is redirected to the matching versioned file
/// name joined onto the path returned by `stdx.fs.selfSharedObjectDirPath()`
/// (presumably the directory this shared object was loaded from — confirm
/// against stdx). Any other name, and a null `filename`, is forwarded to
/// `dlopen` unchanged.
///
/// `flags` is handed to `dlopen` as-is. Returns whatever `dlopen` returns.
pub export fn zmlxcuda_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque {
    // Unversioned library name -> versioned file name to load instead.
    const replacements: std.StaticStringMap([:0]const u8) = .initComptime(.{
        .{ "libcublas.so", "libcublas.so.12" },
        .{ "libcublasLt.so", "libcublasLt.so.12" },
        .{ "libcudart.so", "libcudart.so.12" },
        .{ "libcudnn.so", "libcudnn.so.9" },
        .{ "libcufft.so", "libcufft.so.11" },
        .{ "libcupti.so", "libcupti.so.12" },
        .{ "libcusolver.so", "libcusolver.so.11" },
        .{ "libcusparse.so", "libcusparse.so.12" },
        .{ "libnccl.so", "libnccl.so.2" },
    });

    // Backing storage for the rewritten path; it must outlive the dlopen call below.
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const new_filename: [*c]const u8 = if (filename) |f| blk: {
        // Only the basename is matched, so a name with a directory prefix is redirected too.
        const replacement = replacements.get(std.fs.path.basename(std.mem.span(f))) orelse break :blk f;
        // If the join fails (presumably because the result does not fit in `buf`),
        // fall back to the caller's name instead of `unreachable`, which would be
        // undefined behavior in unchecked release builds.
        break :blk stdx.fs.path.bufJoinZ(&buf, &.{
            stdx.fs.selfSharedObjectDirPath(),
            replacement,
        }) catch break :blk f;
    } else null;

    return std.c.dlopen(new_filename, @bitCast(flags));
}
|
||||||
Loading…
Reference in New Issue
Block a user