runtimes/cuda: sandbox CUDA dependencies. Remove them from the leaf binary and sandbox the dependency graph; mark dlopen'ed direct dependencies as NEEDED, set the plugin's RPATH to the sandbox, load the PJRT plugin from the sandbox, and let downstream use weak CUDA symbols without linking CUDA directly.

This commit is contained in:
Tarry Singh 2025-03-26 11:18:29 +00:00
parent a5420068b1
commit 2d321d232d
13 changed files with 286 additions and 216 deletions

View File

@ -86,9 +86,6 @@ cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64") use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages") cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
inject_repo(cuda, "zlib1g")
use_repo(cuda, "libpjrt_cuda") use_repo(cuda, "libpjrt_cuda")
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages") rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
@ -159,6 +156,12 @@ apt.install(
manifest = "//runtimes/common:packages.yaml", manifest = "//runtimes/common:packages.yaml",
) )
use_repo(apt, "apt_common") use_repo(apt, "apt_common")
apt.install(
name = "apt_cuda",
lock = "//runtimes/cuda:packages.lock.json",
manifest = "//runtimes/cuda:packages.yaml",
)
use_repo(apt, "apt_cuda")
apt.install( apt.install(
name = "apt_rocm", name = "apt_rocm",
lock = "//runtimes/rocm:packages.lock.json", lock = "//runtimes/rocm:packages.lock.json",

View File

@ -77,7 +77,8 @@ pub const Api = struct {
pub fn loadFrom(library: [:0]const u8) !*const Api { pub fn loadFrom(library: [:0]const u8) !*const Api {
var lib: std.DynLib = switch (builtin.os.tag) { var lib: std.DynLib = switch (builtin.os.tag) {
.linux => blk: { .linux => blk: {
const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = false, .NODELETE = true }) orelse { // We use RTLD_GLOBAL so that symbols from NEEDED libraries are available in the global namespace.
const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = true, .NODELETE = true }) orelse {
log.err("Unable to dlopen plugin: {s}", .{library}); log.err("Unable to dlopen plugin: {s}", .{library});
return error.FileNotFound; return error.FileNotFound;
}; };

View File

@ -1,3 +1,5 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
exports_files( exports_files(
["packages.lock.json"], ["packages.lock.json"],
visibility = ["//runtimes:__subpackages__"], visibility = ["//runtimes:__subpackages__"],

View File

@ -1,5 +1,6 @@
load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
load("//runtimes/common:packages.bzl", "packages") load("//runtimes/common:packages.bzl", "packages")
_BUILD_FILE_DEFAULT_VISIBILITY = """\ _BUILD_FILE_DEFAULT_VISIBILITY = """\
@ -16,47 +17,54 @@ CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redis
CUDNN_VERSION = "9.8.0" CUDNN_VERSION = "9.8.0"
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843" CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
_UBUNTU_PACKAGES = {
"zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]),
}
CUDA_PACKAGES = { CUDA_PACKAGES = {
"cuda_cudart": "\n".join([ "cuda_cudart": "\n".join([
# Driver API only
packages.cc_library( packages.cc_library(
name = "cudart", name = "cuda",
hdrs = ["include/cuda.h"], hdrs = ["include/cuda.h"],
includes = ["include"], includes = ["include"],
deps = [":cudart_so", ":cuda_so"],
), ),
packages.cc_import( #TODO: Remove me as soon we use the Driver API in tracer.zig
name = "cudart_so", packages.filegroup(
shared_library = "lib/libcudart.so.12", name = "so_files",
), srcs = ["lib/libcudart.so.12"],
packages.cc_import(
name = "cuda_so",
shared_library = "lib/stubs/libcuda.so",
), ),
]), ]),
"cuda_cupti": packages.cc_import( "cuda_cupti": packages.filegroup(
name = "cupti", name = "so_files",
shared_library = "lib/libcupti.so.12", srcs = ["lib/libcupti.so.12"],
), ),
"cuda_nvtx": packages.cc_import_glob_hdrs( "cuda_nvtx": "\n".join([
name = "nvtx", # packages.cc_library(
hdrs_glob = ["include/nvtx3/**/*.h"], # name = "nvtx",
shared_library = "lib/libnvToolsExt.so.1", # hdrs = glob(["include/nvtx3/**/*.h"]),
# visibility = ["//visibility:public"],
# ),
packages.filegroup(
name = "so_files",
srcs = ["lib/libnvToolsExt.so.1"],
), ),
"libcufft": packages.cc_import( ]),
name = "cufft", "libcufft": packages.filegroup(
shared_library = "lib/libcufft.so.11", name = "so_files",
srcs = ["lib/libcufft.so.11"],
), ),
"libcusolver": packages.cc_import( "libcusolver": packages.filegroup(
name = "cusolver", name = "so_files",
shared_library = "lib/libcusolver.so.11", srcs = ["lib/libcusolver.so.11"],
), ),
"libcusparse": packages.cc_import( "libcusparse": packages.filegroup(
name = "cusparse", name = "so_files",
shared_library = "lib/libcusparse.so.12", srcs = ["lib/libcusparse.so.12"],
), ),
"libnvjitlink": packages.cc_import( "libnvjitlink": packages.filegroup(
name = "nvjitlink", name = "so_files",
shared_library = "lib/libnvJitLink.so.12", srcs = ["lib/libnvJitLink.so.12"],
), ),
"cuda_nvcc": "\n".join([ "cuda_nvcc": "\n".join([
packages.filegroup( packages.filegroup(
@ -71,9 +79,9 @@ CUDA_PACKAGES = {
name = "libdevice", name = "libdevice",
srcs = ["nvvm/libdevice/libdevice.10.bc"], srcs = ["nvvm/libdevice/libdevice.10.bc"],
), ),
packages.cc_import( packages.filegroup(
name = "nvvm", name = "so_files",
shared_library = "nvvm/lib64/libnvvm.so.4", srcs = ["nvvm/lib64/libnvvm.so.4"],
), ),
packages.cc_import( packages.cc_import(
name = "nvptxcompiler", name = "nvptxcompiler",
@ -81,73 +89,40 @@ CUDA_PACKAGES = {
), ),
]), ]),
"cuda_nvrtc": "\n".join([ "cuda_nvrtc": "\n".join([
packages.cc_import( packages.filegroup(
name = "nvrtc", name = "so_files",
shared_library = "lib/libnvrtc.so.12", srcs = [
deps = [":nvrtc_builtins"], "lib/libnvrtc.so.12",
), "lib/libnvrtc-builtins.so.12.8",
packages.cc_import( ],
name = "nvrtc_builtins",
shared_library = "lib/libnvrtc-builtins.so.12.8",
), ),
]), ]),
"libcublas": "\n".join([ "libcublas": "\n".join([
packages.cc_import( packages.filegroup(
name = "cublasLt", name = "so_files",
shared_library = "lib/libcublasLt.so.12", srcs = [
), "lib/libcublasLt.so.12",
packages.cc_import( "lib/libcublas.so.12",
name = "cublas", ],
shared_library = "lib/libcublas.so.12",
deps = [":cublasLt"],
), ),
]), ]),
} }
CUDNN_PACKAGES = { CUDNN_PACKAGES = {
"cudnn": "\n".join([ "cudnn": "\n".join([
packages.cc_import( packages.filegroup(
name = "cudnn", name = "so_files",
shared_library = "lib/libcudnn.so.9", srcs = [
deps = [ "lib/libcudnn.so.9",
":cudnn_adv", "lib/libcudnn_adv.so.9",
":cudnn_ops", "lib/libcudnn_ops.so.9",
":cudnn_cnn", "lib/libcudnn_cnn.so.9",
":cudnn_graph", "lib/libcudnn_graph.so.9",
":cudnn_engines_precompiled", "lib/libcudnn_engines_precompiled.so.9",
":cudnn_engines_runtime_compiled", "lib/libcudnn_engines_runtime_compiled.so.9",
":cudnn_heuristic", "lib/libcudnn_heuristic.so.9",
], ],
), ),
packages.cc_import(
name = "cudnn_adv",
shared_library = "lib/libcudnn_adv.so.9",
),
packages.cc_import(
name = "cudnn_ops",
shared_library = "lib/libcudnn_ops.so.9",
),
packages.cc_import(
name = "cudnn_cnn",
shared_library = "lib/libcudnn_cnn.so.9",
deps = [":cudnn_ops"],
),
packages.cc_import(
name = "cudnn_graph",
shared_library = "lib/libcudnn_graph.so.9",
),
packages.cc_import(
name = "cudnn_engines_precompiled",
shared_library = "lib/libcudnn_engines_precompiled.so.9",
),
packages.cc_import(
name = "cudnn_engines_runtime_compiled",
shared_library = "lib/libcudnn_engines_runtime_compiled.so.9",
),
packages.cc_import(
name = "cudnn_heuristic",
shared_library = "lib/libcudnn_heuristic.so.9",
),
]), ]),
} }
@ -161,6 +136,9 @@ def _read_redist_json(mctx, url, sha256):
return json.decode(mctx.read(fname)) return json.decode(mctx.read(fname))
def _cuda_impl(mctx): def _cuda_impl(mctx):
loaded_packages = packages.read(mctx, [
"@zml//runtimes/cuda:packages.lock.json",
])
CUDA_REDIST = _read_redist_json( CUDA_REDIST = _read_redist_json(
mctx, mctx,
url = CUDA_REDIST_PREFIX + "redistrib_{}.json".format(CUDA_VERSION), url = CUDA_REDIST_PREFIX + "redistrib_{}.json".format(CUDA_VERSION),
@ -173,6 +151,15 @@ def _cuda_impl(mctx):
sha256 = CUDNN_REDIST_JSON_SHA256, sha256 = CUDNN_REDIST_JSON_SHA256,
) )
for pkg_name, build_file_content in _UBUNTU_PACKAGES.items():
pkg = loaded_packages[pkg_name]
http_deb_archive(
name = pkg_name,
urls = pkg["urls"],
sha256 = pkg["sha256"],
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
)
for pkg, build_file_content in CUDA_PACKAGES.items(): for pkg, build_file_content in CUDA_PACKAGES.items():
pkg_data = CUDA_REDIST[pkg] pkg_data = CUDA_REDIST[pkg]
arch_data = pkg_data.get(ARCH) arch_data = pkg_data.get(ARCH)
@ -205,9 +192,9 @@ def _cuda_impl(mctx):
urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"], urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
type = "zip", type = "zip",
sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a", sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.cc_import( build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup(
name = "nccl", name = "so_files",
shared_library = "nvidia/nccl/lib/libnccl.so.2", srcs = ["nvidia/nccl/lib/libnccl.so.2"],
), ),
) )

View File

@ -31,20 +31,9 @@ fn hasCudaPathInLDPath() bool {
return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null; return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
} }
fn setupXlaGpuCudaDirFlag() !void { fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void {
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch "";
defer arena.deinit(); const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox });
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
log.warn("Unable to find CUDA directory. Using system defaults.", .{});
return;
};
const source_repo = bazel_builtin.current_repository;
const r = r_.withSourceRepo(source_repo);
const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?;
const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch "";
const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, cuda_data_dir });
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1); _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
} }
@ -63,9 +52,38 @@ pub fn load() !*const pjrt.Api {
log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath}); log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
} }
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
stdx.debug.panic("Unable to find runfiles", .{});
};
const source_repo = bazel_builtin.current_repository;
const r = r_.withSourceRepo(source_repo);
var path_buf: [std.fs.max_path_bytes]u8 = undefined;
const sandbox_path = try r.rlocation("libpjrt_cuda/sandbox", &path_buf) orelse {
log.err("Failed to find sandbox path for CUDA runtime", .{});
return error.FileNotFound;
};
// CUDA path has to be set _before_ loading the PJRT plugin. // CUDA path has to be set _before_ loading the PJRT plugin.
// See https://github.com/openxla/xla/issues/21428 // See https://github.com/openxla/xla/issues/21428
try setupXlaGpuCudaDirFlag(); try setupXlaGpuCudaDirFlag(arena.allocator(), sandbox_path);
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"}); {
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libnvToolsExt.so.1" });
_ = std.c.dlopen(path, .{ .NOW = true, .GLOBAL = true }) orelse {
log.err("Unable to dlopen libnvToolsExt.so.1: {s}", .{std.c.dlerror().?});
return error.DlError;
};
}
return blk: {
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_cuda.so" });
break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path});
};
} }

View File

@ -1,9 +1,10 @@
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
load("@zml//bazel:cc_import.bzl", "cc_import") load("@zml//bazel:cc_import.bzl", "cc_import")
load("@zml//bazel:patchelf.bzl", "patchelf")
cc_shared_library( cc_shared_library(
name = "zmlxcuda_so", name = "zmlxcuda_so",
shared_lib_name = "libzmlxcuda.so.0", shared_lib_name = "lib/libzmlxcuda.so.0",
deps = ["@zml//runtimes/cuda:zmlxcuda_lib"], deps = ["@zml//runtimes/cuda:zmlxcuda_lib"],
) )
@ -12,40 +13,65 @@ cc_import(
shared_library = ":zmlxcuda_so", shared_library = ":zmlxcuda_so",
) )
copy_to_directory( patchelf(
name = "sandbox", name = "libpjrt_cuda.patchelf",
srcs = [
"@cuda_nvcc//:libdevice",
"@cuda_nvcc//:ptxas",
"@cuda_nvcc//:nvlink",
],
include_external_repositories = ["**"],
)
cc_import(
name = "libpjrt_cuda",
data = [":sandbox"],
shared_library = "libpjrt_cuda.so", shared_library = "libpjrt_cuda.so",
add_needed = ["libzmlxcuda.so.0"], add_needed = [
"libzmlxcuda.so.0",
],
rename_dynamic_symbols = { rename_dynamic_symbols = {
"dlopen": "zmlxcuda_dlopen", "dlopen": "zmlxcuda_dlopen",
}, },
visibility = ["@zml//runtimes/cuda:__subpackages__"], set_rpath = "$ORIGIN",
deps = [ )
":zmlxcuda",
"@cuda_cudart//:cudart", copy_to_directory(
"@cuda_cupti//:cupti", name = "sandbox",
"@cuda_nvtx//:nvtx", srcs = [
"@cuda_nvcc//:nvptxcompiler", ":zmlxcuda_so",
"@cuda_nvcc//:nvvm", ":libpjrt_cuda.patchelf",
"@cuda_nvrtc//:nvrtc", "@cuda_nvcc//:libdevice",
"@cudnn//:cudnn", "@cuda_nvcc//:ptxas",
"@libcublas//:cublas", "@cuda_nvcc//:nvlink",
"@libcufft//:cufft", "@cuda_cupti//:so_files",
"@libcusolver//:cusolver", "@cuda_nvtx//:so_files",
"@libcusparse//:cusparse", "@cuda_nvcc//:so_files",
"@libnvjitlink//:nvjitlink", "@cuda_nvrtc//:so_files",
"@nccl", "@cuda_cudart//:so_files",
"@cudnn//:so_files",
"@libcublas//:so_files",
"@libcufft//:so_files",
"@libcusolver//:so_files",
"@libcusparse//:so_files",
"@libnvjitlink//:so_files",
"@nccl//:so_files",
"@zlib1g", "@zlib1g",
], ],
replace_prefixes = {
"nvidia/nccl/lib": "lib",
"nvvm/lib64": "lib",
"libpjrt_cuda.patchelf": "lib",
"lib/x86_64-linux-gnu": "lib",
},
add_directory_to_runfiles = False,
include_external_repositories = ["**"],
)
cc_library(
name = "libpjrt_cuda",
data = [":sandbox"],
deps = [
"@cuda_cudart//:cuda",
],
linkopts = [
# Defer function call resolution until the function is called
# (lazy loading) rather than at load time.
#
# This is required because we want to let downstream use weak CUDA symbols.
#
# We force it here because -z,now (which resolve all symbols at load time),
# is the default in most bazel CC toolchains as well as in certain linkers.
"-Wl,-z,lazy",
],
visibility = ["@zml//runtimes/cuda:__subpackages__"],
) )

View File

@ -0,0 +1,65 @@
{
"packages": [
{
"arch": "amd64",
"dependencies": [
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
}
],
"key": "zlib1g_1-1.2.13.dfsg-1_amd64",
"name": "zlib1g",
"sha256": "d7dd1d1411fedf27f5e27650a6eff20ef294077b568f4c8c5e51466dc7c08ce4",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/z/zlib/zlib1g_1.2.13.dfsg-1_amd64.deb"
],
"version": "1:1.2.13.dfsg-1"
},
{
"arch": "amd64",
"dependencies": [],
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"sha256": "5dc83256f10ca4d0f2a53dd6583ffd0d0e319af30074ea6c82fb0ae77bd16365",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/glibc/libc6_2.36-9+deb12u10_amd64.deb"
],
"version": "2.36-9+deb12u10"
},
{
"arch": "amd64",
"dependencies": [],
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"sha256": "3016e62cb4b7cd8038822870601f5ed131befe942774d0f745622cc77d8a88f7",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/libgcc-s1_12.2.0-14+deb12u1_amd64.deb"
],
"version": "12.2.0-14+deb12u1"
},
{
"arch": "amd64",
"dependencies": [],
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"sha256": "1896a2aacf4ad681ff5eacc24a5b0ca4d5d9c9b9c9e4b6de5197bc1e116ea619",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/gcc-12-base_12.2.0-14+deb12u1_amd64.deb"
],
"version": "12.2.0-14+deb12u1"
}
],
"version": 1
}

View File

@ -0,0 +1,14 @@
#
# bazel run @apt_cuda//:lock
#
version: 1
sources:
- channel: bookworm main
url: https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z
archs:
- "amd64"
packages:
- "zlib1g"

View File

@ -7,6 +7,7 @@ zig_library(
"debug.zig", "debug.zig",
"flags.zig", "flags.zig",
"fmt.zig", "fmt.zig",
"fs.zig",
"io.zig", "io.zig",
"json.zig", "json.zig",
"math.zig", "math.zig",

13
stdx/fs.zig Normal file
View File

@ -0,0 +1,13 @@
const std = @import("std");
pub const path = struct {
    /// Joins `paths` with the platform path separator, allocating out of `buf`.
    /// The returned slice points into `buf`; it is invalidated by the next call
    /// that reuses the same buffer. Fails with `error.OutOfMemory` if `buf` is
    /// too small to hold the joined path.
    pub fn bufJoin(buf: []u8, paths: []const []const u8) ![]u8 {
        var fba = std.heap.FixedBufferAllocator.init(buf);
        return std.fs.path.join(fba.allocator(), paths);
    }

    /// Same as `bufJoin`, but the result is NUL-terminated (for C interop,
    /// e.g. passing to `dlopen`).
    pub fn bufJoinZ(buf: []u8, paths: []const []const u8) ![:0]u8 {
        var fba = std.heap.FixedBufferAllocator.init(buf);
        return std.fs.path.joinZ(fba.allocator(), paths);
    }
};

View File

@ -1,6 +1,7 @@
pub const debug = @import("debug.zig"); pub const debug = @import("debug.zig");
pub const flags = @import("flags.zig"); pub const flags = @import("flags.zig");
pub const fmt = @import("fmt.zig"); pub const fmt = @import("fmt.zig");
pub const fs = @import("fs.zig");
pub const io = @import("io.zig"); pub const io = @import("io.zig");
pub const json = @import("json.zig"); pub const json = @import("json.zig");
pub const math = @import("math.zig"); pub const math = @import("math.zig");

View File

@ -36,7 +36,6 @@ pub const Context = struct {
inline for (comptime std.enums.values(runtimes.Platform)) |t| { inline for (comptime std.enums.values(runtimes.Platform)) |t| {
if (runtimes.load(t)) |api| { if (runtimes.load(t)) |api| {
Context.apis.set(t, api); Context.apis.set(t, api);
if (t == .cuda) cuda.init();
} else |_| {} } else |_| {}
} }
} }
@ -276,69 +275,3 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
buffer_desc.data[0..buffer_shape.byteSize()], buffer_desc.data[0..buffer_shape.byteSize()],
); );
} }
/// Runtime bindings to a handful of cudart entry points, resolved via dlopen
/// in `init()` rather than linked directly.
pub const cuda = struct {
// Function pointers default to a recognizable poison address so that a call
// made before `init()` faults at a distinctive value instead of silently
// misbehaving (see the "callback will segfault" log below).
pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);
// Mirrors the cudaMemcpyKind enum values passed to cudaMemcpy/cudaMemcpyAsync.
pub const MemcpyKind = enum(c_int) {
host_to_host = 0,
host_to_device = 1,
device_to_host = 2,
device_to_device = 3,
inferred = 4,
};
// C-ABI signatures of the cudart functions we look up by name.
// NOTE(review): callconv is spelled `.C` here but `.c` in CuLaunchHostFunc
// below — presumably equivalent in the targeted Zig version; confirm.
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int;
const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int;
const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int;
const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;
/// Opens libcudart.so.12 and resolves the function pointers above.
/// Best-effort: if the library is missing we log and return, leaving the
/// poison defaults in place. Panics if the library loads but a symbol is
/// absent (a broken installation).
pub fn init() void {
var cudart = std.DynLib.open("libcudart.so.12") catch {
log.err("cudart not found, callback will segfault", .{});
return;
};
// Closing the handle is safe here: the resolved addresses stay valid
// as long as the library remains mapped in the process — TODO confirm
// cudart is kept loaded elsewhere (e.g. by the PJRT plugin).
defer cudart.close();
_memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
@panic("cudaMemcpyAsync not found");
};
_memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
@panic("cudaMemcpy not found");
};
streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
@panic("cudaStreamSynchronize not found");
};
cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
@panic("cudaLaunchHostFunc not found");
};
}
/// Synchronous device→host copy of `dst.len` bytes; panics on CUDA error.
pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
check(err);
}
/// Synchronous host→device copy of `src.len` bytes; panics on CUDA error.
pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
check(err);
}
/// Async host→device copy on `stream`; panics on CUDA error.
pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
check(err);
}
/// Async device→host copy on `stream`; panics on CUDA error.
pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
check(err);
}
/// Panics with the raw CUDA error code on any nonzero status (0 == cudaSuccess).
pub fn check(err: c_int) void {
if (err == 0) return;
stdx.debug.panic("CUDA error: {d}", .{err});
}
};

View File

@ -8,22 +8,28 @@ pub const Tracer = switch (builtin.os.tag) {
}; };
const CudaTracer = struct { const CudaTracer = struct {
extern fn cudaProfilerStart() c_int;
extern fn cudaProfilerStop() c_int;
extern fn nvtxMarkA(message: [*:0]const u8) void; // Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so
extern fn nvtxRangeStartA(message: [*:0]const u8) c_int; // They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so).
extern fn nvtxRangeEnd(id: c_int) void; const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;
// Those symbols are defined in nvToolsExt.h which we don't want to provide.
// However, we link with libnvToolsExt.so which provides them.
// They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us).
const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;
pub fn init(name: [:0]const u8) CudaTracer { pub fn init(name: [:0]const u8) CudaTracer {
_ = name; _ = name;
_ = cudaProfilerStart(); _ = cuProfilerStart();
return .{}; return .{};
} }
pub fn deinit(self: *const CudaTracer) void { pub fn deinit(self: *const CudaTracer) void {
_ = self; _ = self;
_ = cudaProfilerStop(); _ = cuProfilerStop();
} }
pub fn event(self: *const CudaTracer, message: [:0]const u8) void { pub fn event(self: *const CudaTracer, message: [:0]const u8) void {