runtimes/cuda: sandbox CUDA dependencies by removing them from the leaf binary, sandboxing the dependency graph, marking dlopen direct dependencies as NEEDED, setting RPATH to the sandbox, loading the PJRT plugin from the sandbox, and enabling weak CUDA symbols without direct linking.
This commit is contained in:
parent
a5420068b1
commit
2d321d232d
@ -86,9 +86,6 @@ cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
||||
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
|
||||
|
||||
cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
|
||||
|
||||
inject_repo(cuda, "zlib1g")
|
||||
|
||||
use_repo(cuda, "libpjrt_cuda")
|
||||
|
||||
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
|
||||
@ -159,6 +156,12 @@ apt.install(
|
||||
manifest = "//runtimes/common:packages.yaml",
|
||||
)
|
||||
use_repo(apt, "apt_common")
|
||||
apt.install(
|
||||
name = "apt_cuda",
|
||||
lock = "//runtimes/cuda:packages.lock.json",
|
||||
manifest = "//runtimes/cuda:packages.yaml",
|
||||
)
|
||||
use_repo(apt, "apt_cuda")
|
||||
apt.install(
|
||||
name = "apt_rocm",
|
||||
lock = "//runtimes/rocm:packages.lock.json",
|
||||
|
||||
@ -77,7 +77,8 @@ pub const Api = struct {
|
||||
pub fn loadFrom(library: [:0]const u8) !*const Api {
|
||||
var lib: std.DynLib = switch (builtin.os.tag) {
|
||||
.linux => blk: {
|
||||
const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = false, .NODELETE = true }) orelse {
|
||||
// We use RTLD_GLOBAL so that symbols from NEEDED libraries are available in the global namespace.
|
||||
const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = true, .NODELETE = true }) orelse {
|
||||
log.err("Unable to dlopen plugin: {s}", .{library});
|
||||
return error.FileNotFound;
|
||||
};
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
exports_files(
|
||||
["packages.lock.json"],
|
||||
visibility = ["//runtimes:__subpackages__"],
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
|
||||
load("//runtimes/common:packages.bzl", "packages")
|
||||
|
||||
_BUILD_FILE_DEFAULT_VISIBILITY = """\
|
||||
@ -16,47 +17,54 @@ CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redis
|
||||
CUDNN_VERSION = "9.8.0"
|
||||
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
||||
|
||||
_UBUNTU_PACKAGES = {
|
||||
"zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]),
|
||||
}
|
||||
|
||||
CUDA_PACKAGES = {
|
||||
"cuda_cudart": "\n".join([
|
||||
# Driver API only
|
||||
packages.cc_library(
|
||||
name = "cudart",
|
||||
name = "cuda",
|
||||
hdrs = ["include/cuda.h"],
|
||||
includes = ["include"],
|
||||
deps = [":cudart_so", ":cuda_so"],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudart_so",
|
||||
shared_library = "lib/libcudart.so.12",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cuda_so",
|
||||
shared_library = "lib/stubs/libcuda.so",
|
||||
#TODO: Remove me as soon we use the Driver API in tracer.zig
|
||||
packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libcudart.so.12"],
|
||||
),
|
||||
]),
|
||||
"cuda_cupti": packages.cc_import(
|
||||
name = "cupti",
|
||||
shared_library = "lib/libcupti.so.12",
|
||||
"cuda_cupti": packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libcupti.so.12"],
|
||||
),
|
||||
"cuda_nvtx": packages.cc_import_glob_hdrs(
|
||||
name = "nvtx",
|
||||
hdrs_glob = ["include/nvtx3/**/*.h"],
|
||||
shared_library = "lib/libnvToolsExt.so.1",
|
||||
"cuda_nvtx": "\n".join([
|
||||
# packages.cc_library(
|
||||
# name = "nvtx",
|
||||
# hdrs = glob(["include/nvtx3/**/*.h"]),
|
||||
# visibility = ["//visibility:public"],
|
||||
# ),
|
||||
packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libnvToolsExt.so.1"],
|
||||
),
|
||||
]),
|
||||
"libcufft": packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libcufft.so.11"],
|
||||
),
|
||||
"libcufft": packages.cc_import(
|
||||
name = "cufft",
|
||||
shared_library = "lib/libcufft.so.11",
|
||||
"libcusolver": packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libcusolver.so.11"],
|
||||
),
|
||||
"libcusolver": packages.cc_import(
|
||||
name = "cusolver",
|
||||
shared_library = "lib/libcusolver.so.11",
|
||||
"libcusparse": packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libcusparse.so.12"],
|
||||
),
|
||||
"libcusparse": packages.cc_import(
|
||||
name = "cusparse",
|
||||
shared_library = "lib/libcusparse.so.12",
|
||||
),
|
||||
"libnvjitlink": packages.cc_import(
|
||||
name = "nvjitlink",
|
||||
shared_library = "lib/libnvJitLink.so.12",
|
||||
"libnvjitlink": packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["lib/libnvJitLink.so.12"],
|
||||
),
|
||||
"cuda_nvcc": "\n".join([
|
||||
packages.filegroup(
|
||||
@ -71,9 +79,9 @@ CUDA_PACKAGES = {
|
||||
name = "libdevice",
|
||||
srcs = ["nvvm/libdevice/libdevice.10.bc"],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "nvvm",
|
||||
shared_library = "nvvm/lib64/libnvvm.so.4",
|
||||
packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["nvvm/lib64/libnvvm.so.4"],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "nvptxcompiler",
|
||||
@ -81,73 +89,40 @@ CUDA_PACKAGES = {
|
||||
),
|
||||
]),
|
||||
"cuda_nvrtc": "\n".join([
|
||||
packages.cc_import(
|
||||
name = "nvrtc",
|
||||
shared_library = "lib/libnvrtc.so.12",
|
||||
deps = [":nvrtc_builtins"],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "nvrtc_builtins",
|
||||
shared_library = "lib/libnvrtc-builtins.so.12.8",
|
||||
packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = [
|
||||
"lib/libnvrtc.so.12",
|
||||
"lib/libnvrtc-builtins.so.12.8",
|
||||
],
|
||||
),
|
||||
]),
|
||||
"libcublas": "\n".join([
|
||||
packages.cc_import(
|
||||
name = "cublasLt",
|
||||
shared_library = "lib/libcublasLt.so.12",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cublas",
|
||||
shared_library = "lib/libcublas.so.12",
|
||||
deps = [":cublasLt"],
|
||||
packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = [
|
||||
"lib/libcublasLt.so.12",
|
||||
"lib/libcublas.so.12",
|
||||
],
|
||||
),
|
||||
]),
|
||||
}
|
||||
|
||||
CUDNN_PACKAGES = {
|
||||
"cudnn": "\n".join([
|
||||
packages.cc_import(
|
||||
name = "cudnn",
|
||||
shared_library = "lib/libcudnn.so.9",
|
||||
deps = [
|
||||
":cudnn_adv",
|
||||
":cudnn_ops",
|
||||
":cudnn_cnn",
|
||||
":cudnn_graph",
|
||||
":cudnn_engines_precompiled",
|
||||
":cudnn_engines_runtime_compiled",
|
||||
":cudnn_heuristic",
|
||||
packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = [
|
||||
"lib/libcudnn.so.9",
|
||||
"lib/libcudnn_adv.so.9",
|
||||
"lib/libcudnn_ops.so.9",
|
||||
"lib/libcudnn_cnn.so.9",
|
||||
"lib/libcudnn_graph.so.9",
|
||||
"lib/libcudnn_engines_precompiled.so.9",
|
||||
"lib/libcudnn_engines_runtime_compiled.so.9",
|
||||
"lib/libcudnn_heuristic.so.9",
|
||||
],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_adv",
|
||||
shared_library = "lib/libcudnn_adv.so.9",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_ops",
|
||||
shared_library = "lib/libcudnn_ops.so.9",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_cnn",
|
||||
shared_library = "lib/libcudnn_cnn.so.9",
|
||||
deps = [":cudnn_ops"],
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_graph",
|
||||
shared_library = "lib/libcudnn_graph.so.9",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_engines_precompiled",
|
||||
shared_library = "lib/libcudnn_engines_precompiled.so.9",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_engines_runtime_compiled",
|
||||
shared_library = "lib/libcudnn_engines_runtime_compiled.so.9",
|
||||
),
|
||||
packages.cc_import(
|
||||
name = "cudnn_heuristic",
|
||||
shared_library = "lib/libcudnn_heuristic.so.9",
|
||||
),
|
||||
]),
|
||||
}
|
||||
|
||||
@ -161,6 +136,9 @@ def _read_redist_json(mctx, url, sha256):
|
||||
return json.decode(mctx.read(fname))
|
||||
|
||||
def _cuda_impl(mctx):
|
||||
loaded_packages = packages.read(mctx, [
|
||||
"@zml//runtimes/cuda:packages.lock.json",
|
||||
])
|
||||
CUDA_REDIST = _read_redist_json(
|
||||
mctx,
|
||||
url = CUDA_REDIST_PREFIX + "redistrib_{}.json".format(CUDA_VERSION),
|
||||
@ -173,6 +151,15 @@ def _cuda_impl(mctx):
|
||||
sha256 = CUDNN_REDIST_JSON_SHA256,
|
||||
)
|
||||
|
||||
for pkg_name, build_file_content in _UBUNTU_PACKAGES.items():
|
||||
pkg = loaded_packages[pkg_name]
|
||||
http_deb_archive(
|
||||
name = pkg_name,
|
||||
urls = pkg["urls"],
|
||||
sha256 = pkg["sha256"],
|
||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
|
||||
)
|
||||
|
||||
for pkg, build_file_content in CUDA_PACKAGES.items():
|
||||
pkg_data = CUDA_REDIST[pkg]
|
||||
arch_data = pkg_data.get(ARCH)
|
||||
@ -205,9 +192,9 @@ def _cuda_impl(mctx):
|
||||
urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
|
||||
type = "zip",
|
||||
sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
|
||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.cc_import(
|
||||
name = "nccl",
|
||||
shared_library = "nvidia/nccl/lib/libnccl.so.2",
|
||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup(
|
||||
name = "so_files",
|
||||
srcs = ["nvidia/nccl/lib/libnccl.so.2"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -31,20 +31,9 @@ fn hasCudaPathInLDPath() bool {
|
||||
return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
|
||||
}
|
||||
|
||||
fn setupXlaGpuCudaDirFlag() !void {
|
||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
||||
defer arena.deinit();
|
||||
|
||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||
log.warn("Unable to find CUDA directory. Using system defaults.", .{});
|
||||
return;
|
||||
};
|
||||
|
||||
const source_repo = bazel_builtin.current_repository;
|
||||
const r = r_.withSourceRepo(source_repo);
|
||||
const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?;
|
||||
const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch "";
|
||||
const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, cuda_data_dir });
|
||||
fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void {
|
||||
const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch "";
|
||||
const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox });
|
||||
|
||||
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
|
||||
}
|
||||
@ -63,9 +52,38 @@ pub fn load() !*const pjrt.Api {
|
||||
log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
|
||||
}
|
||||
|
||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
||||
defer arena.deinit();
|
||||
|
||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||
stdx.debug.panic("Unable to find runfiles", .{});
|
||||
};
|
||||
|
||||
const source_repo = bazel_builtin.current_repository;
|
||||
const r = r_.withSourceRepo(source_repo);
|
||||
|
||||
var path_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
const sandbox_path = try r.rlocation("libpjrt_cuda/sandbox", &path_buf) orelse {
|
||||
log.err("Failed to find sandbox path for CUDA runtime", .{});
|
||||
return error.FileNotFound;
|
||||
};
|
||||
|
||||
// CUDA path has to be set _before_ loading the PJRT plugin.
|
||||
// See https://github.com/openxla/xla/issues/21428
|
||||
try setupXlaGpuCudaDirFlag();
|
||||
try setupXlaGpuCudaDirFlag(arena.allocator(), sandbox_path);
|
||||
|
||||
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
|
||||
{
|
||||
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libnvToolsExt.so.1" });
|
||||
_ = std.c.dlopen(path, .{ .NOW = true, .GLOBAL = true }) orelse {
|
||||
log.err("Unable to dlopen libnvToolsExt.so.1: {s}", .{std.c.dlerror().?});
|
||||
return error.DlError;
|
||||
};
|
||||
}
|
||||
|
||||
return blk: {
|
||||
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_cuda.so" });
|
||||
break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path});
|
||||
};
|
||||
}
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||
load("@zml//bazel:patchelf.bzl", "patchelf")
|
||||
|
||||
cc_shared_library(
|
||||
name = "zmlxcuda_so",
|
||||
shared_lib_name = "libzmlxcuda.so.0",
|
||||
shared_lib_name = "lib/libzmlxcuda.so.0",
|
||||
deps = ["@zml//runtimes/cuda:zmlxcuda_lib"],
|
||||
)
|
||||
|
||||
@ -12,40 +13,65 @@ cc_import(
|
||||
shared_library = ":zmlxcuda_so",
|
||||
)
|
||||
|
||||
copy_to_directory(
|
||||
name = "sandbox",
|
||||
srcs = [
|
||||
"@cuda_nvcc//:libdevice",
|
||||
"@cuda_nvcc//:ptxas",
|
||||
"@cuda_nvcc//:nvlink",
|
||||
],
|
||||
include_external_repositories = ["**"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "libpjrt_cuda",
|
||||
data = [":sandbox"],
|
||||
patchelf(
|
||||
name = "libpjrt_cuda.patchelf",
|
||||
shared_library = "libpjrt_cuda.so",
|
||||
add_needed = ["libzmlxcuda.so.0"],
|
||||
add_needed = [
|
||||
"libzmlxcuda.so.0",
|
||||
],
|
||||
rename_dynamic_symbols = {
|
||||
"dlopen": "zmlxcuda_dlopen",
|
||||
},
|
||||
visibility = ["@zml//runtimes/cuda:__subpackages__"],
|
||||
deps = [
|
||||
":zmlxcuda",
|
||||
"@cuda_cudart//:cudart",
|
||||
"@cuda_cupti//:cupti",
|
||||
"@cuda_nvtx//:nvtx",
|
||||
"@cuda_nvcc//:nvptxcompiler",
|
||||
"@cuda_nvcc//:nvvm",
|
||||
"@cuda_nvrtc//:nvrtc",
|
||||
"@cudnn//:cudnn",
|
||||
"@libcublas//:cublas",
|
||||
"@libcufft//:cufft",
|
||||
"@libcusolver//:cusolver",
|
||||
"@libcusparse//:cusparse",
|
||||
"@libnvjitlink//:nvjitlink",
|
||||
"@nccl",
|
||||
set_rpath = "$ORIGIN",
|
||||
)
|
||||
|
||||
copy_to_directory(
|
||||
name = "sandbox",
|
||||
srcs = [
|
||||
":zmlxcuda_so",
|
||||
":libpjrt_cuda.patchelf",
|
||||
"@cuda_nvcc//:libdevice",
|
||||
"@cuda_nvcc//:ptxas",
|
||||
"@cuda_nvcc//:nvlink",
|
||||
"@cuda_cupti//:so_files",
|
||||
"@cuda_nvtx//:so_files",
|
||||
"@cuda_nvcc//:so_files",
|
||||
"@cuda_nvrtc//:so_files",
|
||||
"@cuda_cudart//:so_files",
|
||||
"@cudnn//:so_files",
|
||||
"@libcublas//:so_files",
|
||||
"@libcufft//:so_files",
|
||||
"@libcusolver//:so_files",
|
||||
"@libcusparse//:so_files",
|
||||
"@libnvjitlink//:so_files",
|
||||
"@nccl//:so_files",
|
||||
"@zlib1g",
|
||||
],
|
||||
replace_prefixes = {
|
||||
"nvidia/nccl/lib": "lib",
|
||||
"nvvm/lib64": "lib",
|
||||
"libpjrt_cuda.patchelf": "lib",
|
||||
"lib/x86_64-linux-gnu": "lib",
|
||||
},
|
||||
add_directory_to_runfiles = False,
|
||||
include_external_repositories = ["**"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "libpjrt_cuda",
|
||||
data = [":sandbox"],
|
||||
deps = [
|
||||
"@cuda_cudart//:cuda",
|
||||
],
|
||||
linkopts = [
|
||||
# Defer function call resolution until the function is called
|
||||
# (lazy loading) rather than at load time.
|
||||
#
|
||||
# This is required because we want to let downstream use weak CUDA symbols.
|
||||
#
|
||||
# We force it here because -z,now (which resolve all symbols at load time),
|
||||
# is the default in most bazel CC toolchains as well as in certain linkers.
|
||||
"-Wl,-z,lazy",
|
||||
],
|
||||
visibility = ["@zml//runtimes/cuda:__subpackages__"],
|
||||
)
|
||||
|
||||
65
runtimes/cuda/packages.lock.json
Executable file
65
runtimes/cuda/packages.lock.json
Executable file
@ -0,0 +1,65 @@
|
||||
{
|
||||
"packages": [
|
||||
{
|
||||
"arch": "amd64",
|
||||
"dependencies": [
|
||||
{
|
||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
||||
"name": "libc6",
|
||||
"version": "2.36-9+deb12u10"
|
||||
},
|
||||
{
|
||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
||||
"name": "libgcc-s1",
|
||||
"version": "12.2.0-14+deb12u1"
|
||||
},
|
||||
{
|
||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
||||
"name": "gcc-12-base",
|
||||
"version": "12.2.0-14+deb12u1"
|
||||
}
|
||||
],
|
||||
"key": "zlib1g_1-1.2.13.dfsg-1_amd64",
|
||||
"name": "zlib1g",
|
||||
"sha256": "d7dd1d1411fedf27f5e27650a6eff20ef294077b568f4c8c5e51466dc7c08ce4",
|
||||
"urls": [
|
||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/z/zlib/zlib1g_1.2.13.dfsg-1_amd64.deb"
|
||||
],
|
||||
"version": "1:1.2.13.dfsg-1"
|
||||
},
|
||||
{
|
||||
"arch": "amd64",
|
||||
"dependencies": [],
|
||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
||||
"name": "libc6",
|
||||
"sha256": "5dc83256f10ca4d0f2a53dd6583ffd0d0e319af30074ea6c82fb0ae77bd16365",
|
||||
"urls": [
|
||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/glibc/libc6_2.36-9+deb12u10_amd64.deb"
|
||||
],
|
||||
"version": "2.36-9+deb12u10"
|
||||
},
|
||||
{
|
||||
"arch": "amd64",
|
||||
"dependencies": [],
|
||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
||||
"name": "libgcc-s1",
|
||||
"sha256": "3016e62cb4b7cd8038822870601f5ed131befe942774d0f745622cc77d8a88f7",
|
||||
"urls": [
|
||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/libgcc-s1_12.2.0-14+deb12u1_amd64.deb"
|
||||
],
|
||||
"version": "12.2.0-14+deb12u1"
|
||||
},
|
||||
{
|
||||
"arch": "amd64",
|
||||
"dependencies": [],
|
||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
||||
"name": "gcc-12-base",
|
||||
"sha256": "1896a2aacf4ad681ff5eacc24a5b0ca4d5d9c9b9c9e4b6de5197bc1e116ea619",
|
||||
"urls": [
|
||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/gcc-12-base_12.2.0-14+deb12u1_amd64.deb"
|
||||
],
|
||||
"version": "12.2.0-14+deb12u1"
|
||||
}
|
||||
],
|
||||
"version": 1
|
||||
}
|
||||
14
runtimes/cuda/packages.yaml
Normal file
14
runtimes/cuda/packages.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
#
|
||||
# bazel run @apt_cuda//:lock
|
||||
#
|
||||
version: 1
|
||||
|
||||
sources:
|
||||
- channel: bookworm main
|
||||
url: https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z
|
||||
|
||||
archs:
|
||||
- "amd64"
|
||||
|
||||
packages:
|
||||
- "zlib1g"
|
||||
@ -7,6 +7,7 @@ zig_library(
|
||||
"debug.zig",
|
||||
"flags.zig",
|
||||
"fmt.zig",
|
||||
"fs.zig",
|
||||
"io.zig",
|
||||
"json.zig",
|
||||
"math.zig",
|
||||
|
||||
13
stdx/fs.zig
Normal file
13
stdx/fs.zig
Normal file
@ -0,0 +1,13 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub const path = struct {
|
||||
pub fn bufJoin(buf: []u8, paths: []const []const u8) ![]u8 {
|
||||
var fa: std.heap.FixedBufferAllocator = .init(buf);
|
||||
return try std.fs.path.join(fa.allocator(), paths);
|
||||
}
|
||||
|
||||
pub fn bufJoinZ(buf: []u8, paths: []const []const u8) ![:0]u8 {
|
||||
var fa: std.heap.FixedBufferAllocator = .init(buf);
|
||||
return try std.fs.path.joinZ(fa.allocator(), paths);
|
||||
}
|
||||
};
|
||||
@ -1,6 +1,7 @@
|
||||
pub const debug = @import("debug.zig");
|
||||
pub const flags = @import("flags.zig");
|
||||
pub const fmt = @import("fmt.zig");
|
||||
pub const fs = @import("fs.zig");
|
||||
pub const io = @import("io.zig");
|
||||
pub const json = @import("json.zig");
|
||||
pub const math = @import("math.zig");
|
||||
|
||||
@ -36,7 +36,6 @@ pub const Context = struct {
|
||||
inline for (comptime std.enums.values(runtimes.Platform)) |t| {
|
||||
if (runtimes.load(t)) |api| {
|
||||
Context.apis.set(t, api);
|
||||
if (t == .cuda) cuda.init();
|
||||
} else |_| {}
|
||||
}
|
||||
}
|
||||
@ -276,69 +275,3 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
|
||||
buffer_desc.data[0..buffer_shape.byteSize()],
|
||||
);
|
||||
}
|
||||
|
||||
pub const cuda = struct {
|
||||
pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
|
||||
pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
|
||||
var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
|
||||
var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);
|
||||
|
||||
pub const MemcpyKind = enum(c_int) {
|
||||
host_to_host = 0,
|
||||
host_to_device = 1,
|
||||
device_to_host = 2,
|
||||
device_to_device = 3,
|
||||
inferred = 4,
|
||||
};
|
||||
|
||||
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int;
|
||||
const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int;
|
||||
const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int;
|
||||
const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;
|
||||
|
||||
pub fn init() void {
|
||||
var cudart = std.DynLib.open("libcudart.so.12") catch {
|
||||
log.err("cudart not found, callback will segfault", .{});
|
||||
return;
|
||||
};
|
||||
defer cudart.close();
|
||||
|
||||
_memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
|
||||
@panic("cudaMemcpyAsync not found");
|
||||
};
|
||||
_memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
|
||||
@panic("cudaMemcpy not found");
|
||||
};
|
||||
streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
|
||||
@panic("cudaStreamSynchronize not found");
|
||||
};
|
||||
cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
|
||||
@panic("cudaLaunchHostFunc not found");
|
||||
};
|
||||
}
|
||||
|
||||
pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
|
||||
const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
|
||||
check(err);
|
||||
}
|
||||
|
||||
pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
|
||||
const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
|
||||
check(err);
|
||||
}
|
||||
|
||||
pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
|
||||
const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
|
||||
check(err);
|
||||
}
|
||||
|
||||
pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
|
||||
const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
|
||||
check(err);
|
||||
}
|
||||
|
||||
pub fn check(err: c_int) void {
|
||||
if (err == 0) return;
|
||||
stdx.debug.panic("CUDA error: {d}", .{err});
|
||||
}
|
||||
};
|
||||
|
||||
@ -8,22 +8,28 @@ pub const Tracer = switch (builtin.os.tag) {
|
||||
};
|
||||
|
||||
const CudaTracer = struct {
|
||||
extern fn cudaProfilerStart() c_int;
|
||||
extern fn cudaProfilerStop() c_int;
|
||||
|
||||
extern fn nvtxMarkA(message: [*:0]const u8) void;
|
||||
extern fn nvtxRangeStartA(message: [*:0]const u8) c_int;
|
||||
extern fn nvtxRangeEnd(id: c_int) void;
|
||||
// Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so
|
||||
// They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so).
|
||||
const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
|
||||
const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;
|
||||
|
||||
// Those symbols are defined in nvToolsExt.h which we don't want to provide.
|
||||
// However, we link with libnvToolsExt.so which provides them.
|
||||
// They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us).
|
||||
const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
|
||||
const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
|
||||
const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;
|
||||
|
||||
pub fn init(name: [:0]const u8) CudaTracer {
|
||||
_ = name;
|
||||
_ = cudaProfilerStart();
|
||||
_ = cuProfilerStart();
|
||||
return .{};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *const CudaTracer) void {
|
||||
_ = self;
|
||||
_ = cudaProfilerStop();
|
||||
_ = cuProfilerStop();
|
||||
}
|
||||
|
||||
pub fn event(self: *const CudaTracer, message: [:0]const u8) void {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user