From 2d321d232d6c172608e0551d5032c85777cf49b2 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 26 Mar 2025 11:18:29 +0000 Subject: [PATCH] runtimes/cuda: sandbox CUDA dependencies by removing them from the leaf binary, sandboxing the dependency graph, marking dlopen direct dependencies as NEEDED, setting RPATH to the sandbox, loading the PJRT plugin from the sandbox, and enabling weak CUDA symbols without direct linking. --- MODULE.bazel | 9 +- pjrt/pjrt.zig | 3 +- runtimes/common/BUILD.bazel | 2 + runtimes/cuda/cuda.bzl | 169 ++++++++++++------------- runtimes/cuda/cuda.zig | 50 +++++--- runtimes/cuda/libpjrt_cuda.BUILD.bazel | 88 ++++++++----- runtimes/cuda/packages.lock.json | 65 ++++++++++ runtimes/cuda/packages.yaml | 14 ++ stdx/BUILD.bazel | 1 + stdx/fs.zig | 13 ++ stdx/stdx.zig | 1 + zml/context.zig | 67 ---------- zml/tools/tracer.zig | 20 ++- 13 files changed, 286 insertions(+), 216 deletions(-) create mode 100755 runtimes/cuda/packages.lock.json create mode 100644 runtimes/cuda/packages.yaml create mode 100644 stdx/fs.zig diff --git a/MODULE.bazel b/MODULE.bazel index a9934bf..3a2b671 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -86,9 +86,6 @@ cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin") use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64") cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages") - -inject_repo(cuda, "zlib1g") - use_repo(cuda, "libpjrt_cuda") rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages") @@ -159,6 +156,12 @@ apt.install( manifest = "//runtimes/common:packages.yaml", ) use_repo(apt, "apt_common") +apt.install( + name = "apt_cuda", + lock = "//runtimes/cuda:packages.lock.json", + manifest = "//runtimes/cuda:packages.yaml", +) +use_repo(apt, "apt_cuda") apt.install( name = "apt_rocm", lock = "//runtimes/rocm:packages.lock.json", diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index a0be0b2..58d9ac2 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -77,7 +77,8 @@ pub const Api = struct { pub fn loadFrom(library: [:0]const u8) !*const Api { var lib: std.DynLib = switch (builtin.os.tag) { .linux => blk: { - const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = false, .NODELETE = true }) orelse { + // We use RTLD_GLOBAL so that symbols from NEEDED libraries are available in the global namespace. + const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = true, .NODELETE = true }) orelse { log.err("Unable to dlopen plugin: {s}", .{library}); return error.FileNotFound; }; diff --git a/runtimes/common/BUILD.bazel b/runtimes/common/BUILD.bazel index 6a5dd45..e22d428 100644 --- a/runtimes/common/BUILD.bazel +++ b/runtimes/common/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_zig//zig:defs.bzl", "zig_library") + exports_files( ["packages.lock.json"], visibility = ["//runtimes:__subpackages__"], diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index f8d96b5..2d27705 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -1,5 +1,6 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//bazel:http_deb_archive.bzl", "http_deb_archive") load("//runtimes/common:packages.bzl", "packages") _BUILD_FILE_DEFAULT_VISIBILITY = """\ @@ -16,47 +17,54 @@ CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redis CUDNN_VERSION = "9.8.0" CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843" +_UBUNTU_PACKAGES = { + "zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]), +} + CUDA_PACKAGES = { "cuda_cudart": "\n".join([ + # Driver API only packages.cc_library( - name = "cudart", + name = "cuda", hdrs = ["include/cuda.h"], includes = ["include"], - deps = [":cudart_so", ":cuda_so"], ), - packages.cc_import( - name = "cudart_so", - shared_library = "lib/libcudart.so.12", - ), - packages.cc_import( - name = "cuda_so", - shared_library = "lib/stubs/libcuda.so", + #TODO: Remove me as soon we use the Driver API in tracer.zig + packages.filegroup( + name = "so_files", + srcs = ["lib/libcudart.so.12"], ), ]), - "cuda_cupti": packages.cc_import( - name = "cupti", - shared_library = "lib/libcupti.so.12", + "cuda_cupti": packages.filegroup( + name = "so_files", + srcs = ["lib/libcupti.so.12"], ), - "cuda_nvtx": packages.cc_import_glob_hdrs( - name = "nvtx", - hdrs_glob = ["include/nvtx3/**/*.h"], - shared_library = "lib/libnvToolsExt.so.1", + "cuda_nvtx": "\n".join([ + # packages.cc_library( + # name = "nvtx", + # hdrs = glob(["include/nvtx3/**/*.h"]), + # visibility = ["//visibility:public"], + # ), + packages.filegroup( + name = "so_files", + srcs = ["lib/libnvToolsExt.so.1"], + ), + ]), + "libcufft": packages.filegroup( + name = "so_files", + srcs = ["lib/libcufft.so.11"], ), - "libcufft": packages.cc_import( - name = "cufft", - shared_library = "lib/libcufft.so.11", + "libcusolver": packages.filegroup( + name = "so_files", + srcs = ["lib/libcusolver.so.11"], ), - "libcusolver": packages.cc_import( - name = "cusolver", - shared_library = "lib/libcusolver.so.11", + "libcusparse": packages.filegroup( + name = "so_files", + srcs = ["lib/libcusparse.so.12"], ), - "libcusparse": packages.cc_import( - name = "cusparse", - shared_library = "lib/libcusparse.so.12", - ), - "libnvjitlink": packages.cc_import( - name = "nvjitlink", - shared_library = "lib/libnvJitLink.so.12", + "libnvjitlink": packages.filegroup( + name = "so_files", + srcs = ["lib/libnvJitLink.so.12"], ), "cuda_nvcc": "\n".join([ packages.filegroup( @@ -71,9 +79,9 @@ CUDA_PACKAGES = { name = "libdevice", srcs = ["nvvm/libdevice/libdevice.10.bc"], ), - packages.cc_import( - name = "nvvm", - shared_library = "nvvm/lib64/libnvvm.so.4", + packages.filegroup( + name = "so_files", + srcs = ["nvvm/lib64/libnvvm.so.4"], ), packages.cc_import( name = "nvptxcompiler", @@ -81,73 +89,40 @@ CUDA_PACKAGES = { ), ]), "cuda_nvrtc": "\n".join([ - packages.cc_import( - name = "nvrtc", - shared_library = "lib/libnvrtc.so.12", - deps = [":nvrtc_builtins"], - ), - packages.cc_import( - name = "nvrtc_builtins", - shared_library = "lib/libnvrtc-builtins.so.12.8", + packages.filegroup( + name = "so_files", + srcs = [ + "lib/libnvrtc.so.12", + "lib/libnvrtc-builtins.so.12.8", + ], ), ]), "libcublas": "\n".join([ - packages.cc_import( - name = "cublasLt", - shared_library = "lib/libcublasLt.so.12", - ), - packages.cc_import( - name = "cublas", - shared_library = "lib/libcublas.so.12", - deps = [":cublasLt"], + packages.filegroup( + name = "so_files", + srcs = [ + "lib/libcublasLt.so.12", + "lib/libcublas.so.12", + ], ), ]), } CUDNN_PACKAGES = { "cudnn": "\n".join([ - packages.cc_import( - name = "cudnn", - shared_library = "lib/libcudnn.so.9", - deps = [ - ":cudnn_adv", - ":cudnn_ops", - ":cudnn_cnn", - ":cudnn_graph", - ":cudnn_engines_precompiled", - ":cudnn_engines_runtime_compiled", - ":cudnn_heuristic", + packages.filegroup( + name = "so_files", + srcs = [ + "lib/libcudnn.so.9", + "lib/libcudnn_adv.so.9", + "lib/libcudnn_ops.so.9", + "lib/libcudnn_cnn.so.9", + "lib/libcudnn_graph.so.9", + "lib/libcudnn_engines_precompiled.so.9", + "lib/libcudnn_engines_runtime_compiled.so.9", + "lib/libcudnn_heuristic.so.9", ], ), - packages.cc_import( - name = "cudnn_adv", - shared_library = "lib/libcudnn_adv.so.9", - ), - packages.cc_import( - name = "cudnn_ops", - shared_library = "lib/libcudnn_ops.so.9", - ), - packages.cc_import( - name = "cudnn_cnn", - shared_library = "lib/libcudnn_cnn.so.9", - deps = [":cudnn_ops"], - ), - packages.cc_import( - name = "cudnn_graph", - shared_library = "lib/libcudnn_graph.so.9", - ), - packages.cc_import( - name = "cudnn_engines_precompiled", - shared_library = "lib/libcudnn_engines_precompiled.so.9", - ), - packages.cc_import( - name = "cudnn_engines_runtime_compiled", - shared_library = "lib/libcudnn_engines_runtime_compiled.so.9", - ), - packages.cc_import( - name = "cudnn_heuristic", - shared_library = "lib/libcudnn_heuristic.so.9", - ), ]), } @@ -161,6 +136,9 @@ def _read_redist_json(mctx, url, sha256): return json.decode(mctx.read(fname)) def _cuda_impl(mctx): + loaded_packages = packages.read(mctx, [ + "@zml//runtimes/cuda:packages.lock.json", + ]) CUDA_REDIST = _read_redist_json( mctx, url = CUDA_REDIST_PREFIX + "redistrib_{}.json".format(CUDA_VERSION), @@ -173,6 +151,15 @@ def _cuda_impl(mctx): sha256 = CUDNN_REDIST_JSON_SHA256, ) + for pkg_name, build_file_content in _UBUNTU_PACKAGES.items(): + pkg = loaded_packages[pkg_name] + http_deb_archive( + name = pkg_name, + urls = pkg["urls"], + sha256 = pkg["sha256"], + build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content, + ) + for pkg, build_file_content in CUDA_PACKAGES.items(): pkg_data = CUDA_REDIST[pkg] arch_data = pkg_data.get(ARCH) @@ -205,9 +192,9 @@ def _cuda_impl(mctx): urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"], type = "zip", sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a", - build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.cc_import( - name = "nccl", - shared_library = "nvidia/nccl/lib/libnccl.so.2", + build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup( + name = "so_files", + srcs = ["nvidia/nccl/lib/libnccl.so.2"], ), ) diff --git a/runtimes/cuda/cuda.zig b/runtimes/cuda/cuda.zig index 46eb69d..a7fc166 100644 --- a/runtimes/cuda/cuda.zig +++ b/runtimes/cuda/cuda.zig @@ -31,20 +31,9 @@ fn hasCudaPathInLDPath() bool { return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null; } -fn setupXlaGpuCudaDirFlag() !void { - var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); - defer arena.deinit(); - - var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { - log.warn("Unable to find CUDA directory. Using system defaults.", .{}); - return; - }; - - const source_repo = bazel_builtin.current_repository; - const r = r_.withSourceRepo(source_repo); - const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?; - const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch ""; - const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, cuda_data_dir }); +fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void { + const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch ""; + const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }); _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1); } @@ -63,9 +52,38 @@ pub fn load() !*const pjrt.Api { log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath}); } + var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + defer arena.deinit(); + + var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { + stdx.debug.panic("Unable to find runfiles", .{}); + }; + + const source_repo = bazel_builtin.current_repository; + const r = r_.withSourceRepo(source_repo); + + var path_buf: [std.fs.max_path_bytes]u8 = undefined; + const sandbox_path = try r.rlocation("libpjrt_cuda/sandbox", &path_buf) orelse { + log.err("Failed to find sandbox path for CUDA runtime", .{}); + return error.FileNotFound; + }; + // CUDA path has to be set _before_ loading the PJRT plugin. // See https://github.com/openxla/xla/issues/21428 - try setupXlaGpuCudaDirFlag(); + try setupXlaGpuCudaDirFlag(arena.allocator(), sandbox_path); - return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"}); + { + var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libnvToolsExt.so.1" }); + _ = std.c.dlopen(path, .{ .NOW = true, .GLOBAL = true }) orelse { + log.err("Unable to dlopen libnvToolsExt.so.1: {s}", .{std.c.dlerror().?}); + return error.DlError; + }; + } + + return blk: { + var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_cuda.so" }); + break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path}); + }; } diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel index 8881a01..2449d44 100644 --- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel +++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel @@ -1,9 +1,10 @@ load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") load("@zml//bazel:cc_import.bzl", "cc_import") +load("@zml//bazel:patchelf.bzl", "patchelf") cc_shared_library( name = "zmlxcuda_so", - shared_lib_name = "libzmlxcuda.so.0", + shared_lib_name = "lib/libzmlxcuda.so.0", deps = ["@zml//runtimes/cuda:zmlxcuda_lib"], ) @@ -12,40 +13,65 @@ cc_import( shared_library = ":zmlxcuda_so", ) -copy_to_directory( - name = "sandbox", - srcs = [ - "@cuda_nvcc//:libdevice", - "@cuda_nvcc//:ptxas", - "@cuda_nvcc//:nvlink", - ], - include_external_repositories = ["**"], -) - -cc_import( - name = "libpjrt_cuda", - data = [":sandbox"], +patchelf( + name = "libpjrt_cuda.patchelf", shared_library = "libpjrt_cuda.so", - add_needed = ["libzmlxcuda.so.0"], + add_needed = [ + "libzmlxcuda.so.0", + ], rename_dynamic_symbols = { "dlopen": "zmlxcuda_dlopen", }, - visibility = ["@zml//runtimes/cuda:__subpackages__"], - deps = [ - ":zmlxcuda", - "@cuda_cudart//:cudart", - "@cuda_cupti//:cupti", - "@cuda_nvtx//:nvtx", - "@cuda_nvcc//:nvptxcompiler", - "@cuda_nvcc//:nvvm", - "@cuda_nvrtc//:nvrtc", - "@cudnn//:cudnn", - "@libcublas//:cublas", - "@libcufft//:cufft", - "@libcusolver//:cusolver", - "@libcusparse//:cusparse", - "@libnvjitlink//:nvjitlink", - "@nccl", + set_rpath = "$ORIGIN", +) + +copy_to_directory( + name = "sandbox", + srcs = [ + ":zmlxcuda_so", + ":libpjrt_cuda.patchelf", + "@cuda_nvcc//:libdevice", + "@cuda_nvcc//:ptxas", + "@cuda_nvcc//:nvlink", + "@cuda_cupti//:so_files", + "@cuda_nvtx//:so_files", + "@cuda_nvcc//:so_files", + "@cuda_nvrtc//:so_files", + "@cuda_cudart//:so_files", + "@cudnn//:so_files", + "@libcublas//:so_files", + "@libcufft//:so_files", + "@libcusolver//:so_files", + "@libcusparse//:so_files", + "@libnvjitlink//:so_files", + "@nccl//:so_files", "@zlib1g", ], + replace_prefixes = { + "nvidia/nccl/lib": "lib", + "nvvm/lib64": "lib", + "libpjrt_cuda.patchelf": "lib", + "lib/x86_64-linux-gnu": "lib", + }, + add_directory_to_runfiles = False, + include_external_repositories = ["**"], +) + +cc_library( + name = "libpjrt_cuda", + data = [":sandbox"], + deps = [ + "@cuda_cudart//:cuda", + ], + linkopts = [ + # Defer function call resolution until the function is called + # (lazy loading) rather than at load time. + # + # This is required because we want to let downstream use weak CUDA symbols. + # + # We force it here because -z,now (which resolve all symbols at load time), + # is the default in most bazel CC toolchains as well as in certain linkers. + "-Wl,-z,lazy", + ], + visibility = ["@zml//runtimes/cuda:__subpackages__"], ) diff --git a/runtimes/cuda/packages.lock.json b/runtimes/cuda/packages.lock.json new file mode 100755 index 0000000..b80a2c3 --- /dev/null +++ b/runtimes/cuda/packages.lock.json @@ -0,0 +1,65 @@ +{ + "packages": [ + { + "arch": "amd64", + "dependencies": [ + { + "key": "libc6_2.36-9-p-deb12u10_amd64", + "name": "libc6", + "version": "2.36-9+deb12u10" + }, + { + "key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64", + "name": "libgcc-s1", + "version": "12.2.0-14+deb12u1" + }, + { + "key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64", + "name": "gcc-12-base", + "version": "12.2.0-14+deb12u1" + } + ], + "key": "zlib1g_1-1.2.13.dfsg-1_amd64", + "name": "zlib1g", + "sha256": "d7dd1d1411fedf27f5e27650a6eff20ef294077b568f4c8c5e51466dc7c08ce4", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/z/zlib/zlib1g_1.2.13.dfsg-1_amd64.deb" + ], + "version": "1:1.2.13.dfsg-1" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "libc6_2.36-9-p-deb12u10_amd64", + "name": "libc6", + "sha256": "5dc83256f10ca4d0f2a53dd6583ffd0d0e319af30074ea6c82fb0ae77bd16365", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/glibc/libc6_2.36-9+deb12u10_amd64.deb" + ], + "version": "2.36-9+deb12u10" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64", + "name": "libgcc-s1", + "sha256": "3016e62cb4b7cd8038822870601f5ed131befe942774d0f745622cc77d8a88f7", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/libgcc-s1_12.2.0-14+deb12u1_amd64.deb" + ], + "version": "12.2.0-14+deb12u1" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64", + "name": "gcc-12-base", + "sha256": "1896a2aacf4ad681ff5eacc24a5b0ca4d5d9c9b9c9e4b6de5197bc1e116ea619", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/gcc-12-base_12.2.0-14+deb12u1_amd64.deb" + ], + "version": "12.2.0-14+deb12u1" + } + ], + "version": 1 +} \ No newline at end of file diff --git a/runtimes/cuda/packages.yaml b/runtimes/cuda/packages.yaml new file mode 100644 index 0000000..4b3b04c --- /dev/null +++ b/runtimes/cuda/packages.yaml @@ -0,0 +1,14 @@ +# +# bazel run @apt_cuda//:lock +# +version: 1 + +sources: + - channel: bookworm main + url: https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z + +archs: + - "amd64" + +packages: + - "zlib1g" diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 59abb65..171d136 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -7,6 +7,7 @@ zig_library( "debug.zig", "flags.zig", "fmt.zig", + "fs.zig", "io.zig", "json.zig", "math.zig", diff --git a/stdx/fs.zig b/stdx/fs.zig new file mode 100644 index 0000000..3f278e3 --- /dev/null +++ b/stdx/fs.zig @@ -0,0 +1,13 @@ +const std = @import("std"); + +pub const path = struct { + pub fn bufJoin(buf: []u8, paths: []const []const u8) ![]u8 { + var fa: std.heap.FixedBufferAllocator = .init(buf); + return try std.fs.path.join(fa.allocator(), paths); + } + + pub fn bufJoinZ(buf: []u8, paths: []const []const u8) ![:0]u8 { + var fa: std.heap.FixedBufferAllocator = .init(buf); + return try std.fs.path.joinZ(fa.allocator(), paths); + } +}; diff --git a/stdx/stdx.zig b/stdx/stdx.zig index cda9635..87cc231 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -1,6 +1,7 @@ pub const debug = @import("debug.zig"); pub const flags = @import("flags.zig"); pub const fmt = @import("fmt.zig"); +pub const fs = @import("fs.zig"); pub const io = @import("io.zig"); pub const json = @import("json.zig"); pub const math = @import("math.zig"); diff --git a/zml/context.zig b/zml/context.zig index 0e11c9b..bde867f 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -36,7 +36,6 @@ pub const Context = struct { inline for (comptime std.enums.values(runtimes.Platform)) |t| { if (runtimes.load(t)) |api| { Context.apis.set(t, api); - if (t == .cuda) cuda.init(); } else |_| {} } } @@ -276,69 +275,3 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer { buffer_desc.data[0..buffer_shape.byteSize()], ); } - -pub const cuda = struct { - pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00); - pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00); - var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00); - var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00); - - pub const MemcpyKind = enum(c_int) { - host_to_host = 0, - host_to_device = 1, - device_to_host = 2, - device_to_device = 3, - inferred = 4, - }; - - const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int; - const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int; - const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int; - const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int; - - pub fn init() void { - var cudart = std.DynLib.open("libcudart.so.12") catch { - log.err("cudart not found, callback will segfault", .{}); - return; - }; - defer cudart.close(); - - _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse { - @panic("cudaMemcpyAsync not found"); - }; - _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse { - @panic("cudaMemcpy not found"); - }; - streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse { - @panic("cudaStreamSynchronize not found"); - }; - cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse { - @panic("cudaLaunchHostFunc not found"); - }; - } - - pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void { - const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host); - check(err); - } - - pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void { - const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device); - check(err); - } - - pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void { - const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream); - check(err); - } - - pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void { - const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream); - check(err); - } - - pub fn check(err: c_int) void { - if (err == 0) return; - stdx.debug.panic("CUDA error: {d}", .{err}); - } -}; diff --git a/zml/tools/tracer.zig b/zml/tools/tracer.zig index 0576956..06bcd6f 100644 --- a/zml/tools/tracer.zig +++ b/zml/tools/tracer.zig @@ -8,22 +8,28 @@ pub const Tracer = switch (builtin.os.tag) { }; const CudaTracer = struct { - extern fn cudaProfilerStart() c_int; - extern fn cudaProfilerStop() c_int; - extern fn nvtxMarkA(message: [*:0]const u8) void; - extern fn nvtxRangeStartA(message: [*:0]const u8) c_int; - extern fn nvtxRangeEnd(id: c_int) void; + // Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so + // They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so). + const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable; + const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable; + + // Those symbols are defined in nvToolsExt.h which we don't want to provide. + // However, we link with libnvToolsExt.so which provides them. + // They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us). + const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable; + const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable; + const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable; pub fn init(name: [:0]const u8) CudaTracer { _ = name; - _ = cudaProfilerStart(); + _ = cuProfilerStart(); return .{}; } pub fn deinit(self: *const CudaTracer) void { _ = self; - _ = cudaProfilerStop(); + _ = cuProfilerStop(); } pub fn event(self: *const CudaTracer, message: [:0]const u8) void {