diff --git a/MODULE.bazel b/MODULE.bazel index a9934bf..3a2b671 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -86,9 +86,6 @@ cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin") use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64") cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages") - -inject_repo(cuda, "zlib1g") - use_repo(cuda, "libpjrt_cuda") rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages") @@ -159,6 +156,12 @@ apt.install( manifest = "//runtimes/common:packages.yaml", ) use_repo(apt, "apt_common") +apt.install( + name = "apt_cuda", + lock = "//runtimes/cuda:packages.lock.json", + manifest = "//runtimes/cuda:packages.yaml", +) +use_repo(apt, "apt_cuda") apt.install( name = "apt_rocm", lock = "//runtimes/rocm:packages.lock.json", diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index a0be0b2..58d9ac2 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -77,7 +77,8 @@ pub const Api = struct { pub fn loadFrom(library: [:0]const u8) !*const Api { var lib: std.DynLib = switch (builtin.os.tag) { .linux => blk: { - const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = false, .NODELETE = true }) orelse { + // We use RTLD_GLOBAL so that symbols from NEEDED libraries are available in the global namespace. + const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = true, .NODELETE = true }) orelse { log.err("Unable to dlopen plugin: {s}", .{library}); return error.FileNotFound; }; diff --git a/runtimes/common/BUILD.bazel b/runtimes/common/BUILD.bazel index 6a5dd45..e22d428 100644 --- a/runtimes/common/BUILD.bazel +++ b/runtimes/common/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_zig//zig:defs.bzl", "zig_library") + exports_files( ["packages.lock.json"], visibility = ["//runtimes:__subpackages__"], diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index f8d96b5..2d27705 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -1,5 +1,6 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//bazel:http_deb_archive.bzl", "http_deb_archive") load("//runtimes/common:packages.bzl", "packages") _BUILD_FILE_DEFAULT_VISIBILITY = """\ @@ -16,47 +17,54 @@ CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redis CUDNN_VERSION = "9.8.0" CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843" +_UBUNTU_PACKAGES = { + "zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]), +} + CUDA_PACKAGES = { "cuda_cudart": "\n".join([ + # Driver API only packages.cc_library( - name = "cudart", + name = "cuda", hdrs = ["include/cuda.h"], includes = ["include"], - deps = [":cudart_so", ":cuda_so"], ), - packages.cc_import( - name = "cudart_so", - shared_library = "lib/libcudart.so.12", - ), - packages.cc_import( - name = "cuda_so", - shared_library = "lib/stubs/libcuda.so", + #TODO: Remove me as soon we use the Driver API in tracer.zig + packages.filegroup( + name = "so_files", + srcs = ["lib/libcudart.so.12"], ), ]), - "cuda_cupti": packages.cc_import( - name = "cupti", - shared_library = "lib/libcupti.so.12", + "cuda_cupti": packages.filegroup( + name = "so_files", + srcs = ["lib/libcupti.so.12"], ), - "cuda_nvtx": packages.cc_import_glob_hdrs( - name = "nvtx", - hdrs_glob = ["include/nvtx3/**/*.h"], - shared_library = "lib/libnvToolsExt.so.1", + "cuda_nvtx": "\n".join([ + # packages.cc_library( + # name = "nvtx", + # hdrs = glob(["include/nvtx3/**/*.h"]), + # visibility = ["//visibility:public"], + # ), + packages.filegroup( + name = "so_files", + srcs = ["lib/libnvToolsExt.so.1"], + ), + ]), + "libcufft": packages.filegroup( + name = "so_files", + srcs = ["lib/libcufft.so.11"], ), - "libcufft": packages.cc_import( - name = "cufft", - shared_library = "lib/libcufft.so.11", + "libcusolver": packages.filegroup( + name = "so_files", + srcs = ["lib/libcusolver.so.11"], ), - "libcusolver": packages.cc_import( - name = "cusolver", - shared_library = "lib/libcusolver.so.11", + "libcusparse": packages.filegroup( + name = "so_files", + srcs = ["lib/libcusparse.so.12"], ), - "libcusparse": packages.cc_import( - name = "cusparse", - shared_library = "lib/libcusparse.so.12", - ), - "libnvjitlink": packages.cc_import( - name = "nvjitlink", - shared_library = "lib/libnvJitLink.so.12", + "libnvjitlink": packages.filegroup( + name = "so_files", + srcs = ["lib/libnvJitLink.so.12"], ), "cuda_nvcc": "\n".join([ packages.filegroup( @@ -71,9 +79,9 @@ CUDA_PACKAGES = { name = "libdevice", srcs = ["nvvm/libdevice/libdevice.10.bc"], ), - packages.cc_import( - name = "nvvm", - shared_library = "nvvm/lib64/libnvvm.so.4", + packages.filegroup( + name = "so_files", + srcs = ["nvvm/lib64/libnvvm.so.4"], ), packages.cc_import( name = "nvptxcompiler", @@ -81,73 +89,40 @@ CUDA_PACKAGES = { ), ]), "cuda_nvrtc": "\n".join([ - packages.cc_import( - name = "nvrtc", - shared_library = "lib/libnvrtc.so.12", - deps = [":nvrtc_builtins"], - ), - packages.cc_import( - name = "nvrtc_builtins", - shared_library = "lib/libnvrtc-builtins.so.12.8", + packages.filegroup( + name = "so_files", + srcs = [ + "lib/libnvrtc.so.12", + "lib/libnvrtc-builtins.so.12.8", + ], ), ]), "libcublas": "\n".join([ - packages.cc_import( - name = "cublasLt", - shared_library = "lib/libcublasLt.so.12", - ), - packages.cc_import( - name = "cublas", - shared_library = "lib/libcublas.so.12", - deps = [":cublasLt"], + packages.filegroup( + name = "so_files", + srcs = [ + "lib/libcublasLt.so.12", + "lib/libcublas.so.12", + ], ), ]), } CUDNN_PACKAGES = { "cudnn": "\n".join([ - packages.cc_import( - name = "cudnn", - shared_library = "lib/libcudnn.so.9", - deps = [ - ":cudnn_adv", - ":cudnn_ops", - ":cudnn_cnn", - ":cudnn_graph", - ":cudnn_engines_precompiled", - ":cudnn_engines_runtime_compiled", - ":cudnn_heuristic", + packages.filegroup( + name = "so_files", + srcs = [ + "lib/libcudnn.so.9", + "lib/libcudnn_adv.so.9", + "lib/libcudnn_ops.so.9", + "lib/libcudnn_cnn.so.9", + "lib/libcudnn_graph.so.9", + "lib/libcudnn_engines_precompiled.so.9", + "lib/libcudnn_engines_runtime_compiled.so.9", + "lib/libcudnn_heuristic.so.9", ], ), - packages.cc_import( - name = "cudnn_adv", - shared_library = "lib/libcudnn_adv.so.9", - ), - packages.cc_import( - name = "cudnn_ops", - shared_library = "lib/libcudnn_ops.so.9", - ), - packages.cc_import( - name = "cudnn_cnn", - shared_library = "lib/libcudnn_cnn.so.9", - deps = [":cudnn_ops"], - ), - packages.cc_import( - name = "cudnn_graph", - shared_library = "lib/libcudnn_graph.so.9", - ), - packages.cc_import( - name = "cudnn_engines_precompiled", - shared_library = "lib/libcudnn_engines_precompiled.so.9", - ), - packages.cc_import( - name = "cudnn_engines_runtime_compiled", - shared_library = "lib/libcudnn_engines_runtime_compiled.so.9", - ), - packages.cc_import( - name = "cudnn_heuristic", - shared_library = "lib/libcudnn_heuristic.so.9", - ), ]), } @@ -161,6 +136,9 @@ def _read_redist_json(mctx, url, sha256): return json.decode(mctx.read(fname)) def _cuda_impl(mctx): + loaded_packages = packages.read(mctx, [ + "@zml//runtimes/cuda:packages.lock.json", + ]) CUDA_REDIST = _read_redist_json( mctx, url = CUDA_REDIST_PREFIX + "redistrib_{}.json".format(CUDA_VERSION), @@ -173,6 +151,15 @@ def _cuda_impl(mctx): sha256 = CUDNN_REDIST_JSON_SHA256, ) + for pkg_name, build_file_content in _UBUNTU_PACKAGES.items(): + pkg = loaded_packages[pkg_name] + http_deb_archive( + name = pkg_name, + urls = pkg["urls"], + sha256 = pkg["sha256"], + build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content, + ) + for pkg, build_file_content in CUDA_PACKAGES.items(): pkg_data = CUDA_REDIST[pkg] arch_data = pkg_data.get(ARCH) @@ -205,9 +192,9 @@ def _cuda_impl(mctx): urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"], type = "zip", sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a", - build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.cc_import( - name = "nccl", - shared_library = "nvidia/nccl/lib/libnccl.so.2", + build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.filegroup( + name = "so_files", + srcs = ["nvidia/nccl/lib/libnccl.so.2"], ), ) diff --git a/runtimes/cuda/cuda.zig b/runtimes/cuda/cuda.zig index 46eb69d..a7fc166 100644 --- a/runtimes/cuda/cuda.zig +++ b/runtimes/cuda/cuda.zig @@ -31,20 +31,9 @@ fn hasCudaPathInLDPath() bool { return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null; } -fn setupXlaGpuCudaDirFlag() !void { - var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); - defer arena.deinit(); - - var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { - log.warn("Unable to find CUDA directory. Using system defaults.", .{}); - return; - }; - - const source_repo = bazel_builtin.current_repository; - const r = r_.withSourceRepo(source_repo); - const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?; - const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch ""; - const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, cuda_data_dir }); +fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void { + const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch ""; + const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }); _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1); } @@ -63,9 +52,38 @@ pub fn load() !*const pjrt.Api { log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath}); } + var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + defer arena.deinit(); + + var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { + stdx.debug.panic("Unable to find runfiles", .{}); + }; + + const source_repo = bazel_builtin.current_repository; + const r = r_.withSourceRepo(source_repo); + + var path_buf: [std.fs.max_path_bytes]u8 = undefined; + const sandbox_path = try r.rlocation("libpjrt_cuda/sandbox", &path_buf) orelse { + log.err("Failed to find sandbox path for CUDA runtime", .{}); + return error.FileNotFound; + }; + // CUDA path has to be set _before_ loading the PJRT plugin. // See https://github.com/openxla/xla/issues/21428 - try setupXlaGpuCudaDirFlag(); + try setupXlaGpuCudaDirFlag(arena.allocator(), sandbox_path); - return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"}); + { + var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libnvToolsExt.so.1" }); + _ = std.c.dlopen(path, .{ .NOW = true, .GLOBAL = true }) orelse { + log.err("Unable to dlopen libnvToolsExt.so.1: {s}", .{std.c.dlerror().?}); + return error.DlError; + }; + } + + return blk: { + var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_cuda.so" }); + break :blk asynk.callBlocking(pjrt.Api.loadFrom, .{path}); + }; } diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel index 8881a01..2449d44 100644 --- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel +++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel @@ -1,9 +1,10 @@ load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") load("@zml//bazel:cc_import.bzl", "cc_import") +load("@zml//bazel:patchelf.bzl", "patchelf") cc_shared_library( name = "zmlxcuda_so", - shared_lib_name = "libzmlxcuda.so.0", + shared_lib_name = "lib/libzmlxcuda.so.0", deps = ["@zml//runtimes/cuda:zmlxcuda_lib"], ) @@ -12,40 +13,65 @@ cc_import( shared_library = ":zmlxcuda_so", ) -copy_to_directory( - name = "sandbox", - srcs = [ - "@cuda_nvcc//:libdevice", - "@cuda_nvcc//:ptxas", - "@cuda_nvcc//:nvlink", - ], - include_external_repositories = ["**"], -) - -cc_import( - name = "libpjrt_cuda", - data = [":sandbox"], +patchelf( + name = "libpjrt_cuda.patchelf", shared_library = "libpjrt_cuda.so", - add_needed = ["libzmlxcuda.so.0"], + add_needed = [ + "libzmlxcuda.so.0", + ], rename_dynamic_symbols = { "dlopen": "zmlxcuda_dlopen", }, - visibility = ["@zml//runtimes/cuda:__subpackages__"], - deps = [ - ":zmlxcuda", - "@cuda_cudart//:cudart", - "@cuda_cupti//:cupti", - "@cuda_nvtx//:nvtx", - "@cuda_nvcc//:nvptxcompiler", - "@cuda_nvcc//:nvvm", - "@cuda_nvrtc//:nvrtc", - "@cudnn//:cudnn", - "@libcublas//:cublas", - "@libcufft//:cufft", - "@libcusolver//:cusolver", - "@libcusparse//:cusparse", - "@libnvjitlink//:nvjitlink", - "@nccl", + set_rpath = "$ORIGIN", +) + +copy_to_directory( + name = "sandbox", + srcs = [ + ":zmlxcuda_so", + ":libpjrt_cuda.patchelf", + "@cuda_nvcc//:libdevice", + "@cuda_nvcc//:ptxas", + "@cuda_nvcc//:nvlink", + "@cuda_cupti//:so_files", + "@cuda_nvtx//:so_files", + "@cuda_nvcc//:so_files", + "@cuda_nvrtc//:so_files", + "@cuda_cudart//:so_files", + "@cudnn//:so_files", + "@libcublas//:so_files", + "@libcufft//:so_files", + "@libcusolver//:so_files", + "@libcusparse//:so_files", + "@libnvjitlink//:so_files", + "@nccl//:so_files", "@zlib1g", ], + replace_prefixes = { + "nvidia/nccl/lib": "lib", + "nvvm/lib64": "lib", + "libpjrt_cuda.patchelf": "lib", + "lib/x86_64-linux-gnu": "lib", + }, + add_directory_to_runfiles = False, + include_external_repositories = ["**"], +) + +cc_library( + name = "libpjrt_cuda", + data = [":sandbox"], + deps = [ + "@cuda_cudart//:cuda", + ], + linkopts = [ + # Defer function call resolution until the function is called + # (lazy loading) rather than at load time. + # + # This is required because we want to let downstream use weak CUDA symbols. + # + # We force it here because -z,now (which resolve all symbols at load time), + # is the default in most bazel CC toolchains as well as in certain linkers. + "-Wl,-z,lazy", + ], + visibility = ["@zml//runtimes/cuda:__subpackages__"], ) diff --git a/runtimes/cuda/packages.lock.json b/runtimes/cuda/packages.lock.json new file mode 100755 index 0000000..b80a2c3 --- /dev/null +++ b/runtimes/cuda/packages.lock.json @@ -0,0 +1,65 @@ +{ + "packages": [ + { + "arch": "amd64", + "dependencies": [ + { + "key": "libc6_2.36-9-p-deb12u10_amd64", + "name": "libc6", + "version": "2.36-9+deb12u10" + }, + { + "key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64", + "name": "libgcc-s1", + "version": "12.2.0-14+deb12u1" + }, + { + "key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64", + "name": "gcc-12-base", + "version": "12.2.0-14+deb12u1" + } + ], + "key": "zlib1g_1-1.2.13.dfsg-1_amd64", + "name": "zlib1g", + "sha256": "d7dd1d1411fedf27f5e27650a6eff20ef294077b568f4c8c5e51466dc7c08ce4", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/z/zlib/zlib1g_1.2.13.dfsg-1_amd64.deb" + ], + "version": "1:1.2.13.dfsg-1" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "libc6_2.36-9-p-deb12u10_amd64", + "name": "libc6", + "sha256": "5dc83256f10ca4d0f2a53dd6583ffd0d0e319af30074ea6c82fb0ae77bd16365", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/glibc/libc6_2.36-9+deb12u10_amd64.deb" + ], + "version": "2.36-9+deb12u10" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64", + "name": "libgcc-s1", + "sha256": "3016e62cb4b7cd8038822870601f5ed131befe942774d0f745622cc77d8a88f7", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/libgcc-s1_12.2.0-14+deb12u1_amd64.deb" + ], + "version": "12.2.0-14+deb12u1" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64", + "name": "gcc-12-base", + "sha256": "1896a2aacf4ad681ff5eacc24a5b0ca4d5d9c9b9c9e4b6de5197bc1e116ea619", + "urls": [ + "https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z/pool/main/g/gcc-12/gcc-12-base_12.2.0-14+deb12u1_amd64.deb" + ], + "version": "12.2.0-14+deb12u1" + } + ], + "version": 1 +} \ No newline at end of file diff --git a/runtimes/cuda/packages.yaml b/runtimes/cuda/packages.yaml new file mode 100644 index 0000000..4b3b04c --- /dev/null +++ b/runtimes/cuda/packages.yaml @@ -0,0 +1,14 @@ +# +# bazel run @apt_cuda//:lock +# +version: 1 + +sources: + - channel: bookworm main + url: https://snapshot-cloudflare.debian.org/archive/debian/20250711T030400Z + +archs: + - "amd64" + +packages: + - "zlib1g" diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 59abb65..171d136 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -7,6 +7,7 @@ zig_library( "debug.zig", "flags.zig", "fmt.zig", + "fs.zig", "io.zig", "json.zig", "math.zig", diff --git a/stdx/fs.zig b/stdx/fs.zig new file mode 100644 index 0000000..3f278e3 --- /dev/null +++ b/stdx/fs.zig @@ -0,0 +1,13 @@ +const std = @import("std"); + +pub const path = struct { + pub fn bufJoin(buf: []u8, paths: []const []const u8) ![]u8 { + var fa: std.heap.FixedBufferAllocator = .init(buf); + return try std.fs.path.join(fa.allocator(), paths); + } + + pub fn bufJoinZ(buf: []u8, paths: []const []const u8) ![:0]u8 { + var fa: std.heap.FixedBufferAllocator = .init(buf); + return try std.fs.path.joinZ(fa.allocator(), paths); + } +}; diff --git a/stdx/stdx.zig b/stdx/stdx.zig index cda9635..87cc231 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -1,6 +1,7 @@ pub const debug = @import("debug.zig"); pub const flags = @import("flags.zig"); pub const fmt = @import("fmt.zig"); +pub const fs = @import("fs.zig"); pub const io = @import("io.zig"); pub const json = @import("json.zig"); pub const math = @import("math.zig"); diff --git a/zml/context.zig b/zml/context.zig index 0e11c9b..bde867f 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -36,7 +36,6 @@ pub const Context = struct { inline for (comptime std.enums.values(runtimes.Platform)) |t| { if (runtimes.load(t)) |api| { Context.apis.set(t, api); - if (t == .cuda) cuda.init(); } else |_| {} } } @@ -276,69 +275,3 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer { buffer_desc.data[0..buffer_shape.byteSize()], ); } - -pub const cuda = struct { - pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00); - pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00); - var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00); - var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00); - - pub const MemcpyKind = enum(c_int) { - host_to_host = 0, - host_to_device = 1, - device_to_host = 2, - device_to_device = 3, - inferred = 4, - }; - - const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int; - const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int; - const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int; - const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int; - - pub fn init() void { - var cudart = std.DynLib.open("libcudart.so.12") catch { - log.err("cudart not found, callback will segfault", .{}); - return; - }; - defer cudart.close(); - - _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse { - @panic("cudaMemcpyAsync not found"); - }; - _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse { - @panic("cudaMemcpy not found"); - }; - streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse { - @panic("cudaStreamSynchronize not found"); - }; - cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse { - @panic("cudaLaunchHostFunc not found"); - }; - } - - pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void { - const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host); - check(err); - } - - pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void { - const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device); - check(err); - } - - pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void { - const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream); - check(err); - } - - pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void { - const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream); - check(err); - } - - pub fn check(err: c_int) void { - if (err == 0) return; - stdx.debug.panic("CUDA error: {d}", .{err}); - } -}; diff --git a/zml/tools/tracer.zig b/zml/tools/tracer.zig index 0576956..06bcd6f 100644 --- a/zml/tools/tracer.zig +++ b/zml/tools/tracer.zig @@ -8,22 +8,28 @@ pub const Tracer = switch (builtin.os.tag) { }; const CudaTracer = struct { - extern fn cudaProfilerStart() c_int; - extern fn cudaProfilerStop() c_int; - extern fn nvtxMarkA(message: [*:0]const u8) void; - extern fn nvtxRangeStartA(message: [*:0]const u8) c_int; - extern fn nvtxRangeEnd(id: c_int) void; + // Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so + // They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so). + const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable; + const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable; + + // Those symbols are defined in nvToolsExt.h which we don't want to provide. + // However, we link with libnvToolsExt.so which provides them. + // They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us). + const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable; + const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable; + const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable; pub fn init(name: [:0]const u8) CudaTracer { _ = name; - _ = cudaProfilerStart(); + _ = cuProfilerStart(); return .{}; } pub fn deinit(self: *const CudaTracer) void { _ = self; - _ = cudaProfilerStop(); + _ = cuProfilerStop(); } pub fn event(self: *const CudaTracer, message: [:0]const u8) void {