Revert CUDA PJRT plugin version to 0.4.38 to address performance regression on XLA master.

This commit is contained in:
Tarry Singh 2024-03-05 17:04:42 +00:00
parent 76e314db9b
commit 8a25b1eb74
5 changed files with 39 additions and 1072 deletions

View File

@ -208,8 +208,9 @@ def _cuda_impl(mctx):
http_archive( http_archive(
name = "libpjrt_cuda", name = "libpjrt_cuda",
build_file = "libpjrt_cuda.BUILD.bazel", build_file = "libpjrt_cuda.BUILD.bazel",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cuda_linux-amd64.tar.gz", url = "https://files.pythonhosted.org/packages/90/43/ac2c369e202e3e3e7e5aa7929b197801ba02eaf11868437adaa5341704e4/jax_cuda12_pjrt-0.4.38-py3-none-manylinux2014_x86_64.whl",
sha256 = "1c3ca76d887d112762d03ebb28f17a08beebf6338453c3044a36225e1678a113", type = "zip",
sha256 = "83be4c59fbcf30077a60085d98e7d59dc738b1c91e0d628e4ac1779fde15ac2b",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -1,10 +1,16 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
const nvidiaLibsPath = "/usr/local/cuda/lib64"; const asynk = @import("async");
const bazel_builtin = @import("bazel_builtin");
const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles");
const stdx = @import("stdx");
const nvidiaLibsPath = "/cuda/";
const log = std.log.scoped(.@"zml/runtime/cuda");
pub fn isEnabled() bool { pub fn isEnabled() bool {
return @hasDecl(c, "ZML_RUNTIME_CUDA"); return @hasDecl(c, "ZML_RUNTIME_CUDA");
@ -22,7 +28,24 @@ fn hasCudaPathInLDPath() bool {
return false; return false;
} }
return std.mem.indexOf(u8, std.mem.span(ldLibraryPath), nvidiaLibsPath) != null; return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
}
/// Appends `--xla_gpu_cuda_data_dir=<dir>` to the `XLA_FLAGS` environment
/// variable of the current process, where `<dir>` is the Bazel runfiles
/// location of `libpjrt_cuda/sandbox`.
///
/// Any flags already present in `XLA_FLAGS` are kept and the new flag is
/// appended after them.
///
/// Panics if the runfiles tree or the `libpjrt_cuda/sandbox` entry cannot be
/// located. Returns an error if allocation fails, if reading `XLA_FLAGS` fails
/// for a reason other than the variable being unset, or if `setenv` fails.
fn setupXlaGpuCudaDirFlag() !void {
    // Everything allocated here is scratch space: `setenv` copies its
    // arguments into the environment, so the arena can be released on return.
    var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    var r_ = try runfiles.Runfiles.create(.{ .allocator = allocator }) orelse {
        stdx.debug.panic("Unable to find CUDA directory", .{});
    };

    const source_repo = bazel_builtin.current_repository;
    const r = r_.withSourceRepo(source_repo);

    // Fail with an explicit message instead of an unchecked `.?` unwrap,
    // which is undefined behavior in ReleaseFast when the lookup returns null.
    const cuda_data_dir = (try r.rlocationAlloc(allocator, "libpjrt_cuda/sandbox")) orelse {
        stdx.debug.panic("Unable to find CUDA directory: runfile libpjrt_cuda/sandbox could not be resolved", .{});
    };

    // Only an unset variable is treated as "no existing flags"; any other
    // failure (e.g. out of memory) is propagated rather than silently ignored.
    const xla_flags: []const u8 = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch |err| switch (err) {
        error.EnvironmentVariableNotFound => "",
        else => |e| return e,
    };

    const new_xla_flagsZ = try std.fmt.allocPrintZ(
        allocator,
        "{s} --xla_gpu_cuda_data_dir={s}",
        .{ xla_flags, cuda_data_dir },
    );

    // The variable name is a valid constant, so running out of memory is the
    // only documented reason for `setenv` to fail here.
    if (c.setenv("XLA_FLAGS", new_xla_flagsZ, 1) != 0) {
        return error.OutOfMemory;
    }
}
pub fn load() !*const pjrt.Api { pub fn load() !*const pjrt.Api {
@ -36,8 +59,12 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
if (hasCudaPathInLDPath()) { if (hasCudaPathInLDPath()) {
std.log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath}); log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
} }
// CUDA path has to be set _before_ loading the PJRT plugin.
// See https://github.com/openxla/xla/issues/21428
try setupXlaGpuCudaDirFlag();
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"}); return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
} }

View File

@ -25,7 +25,8 @@ copy_to_directory(
cc_import( cc_import(
name = "libpjrt_cuda", name = "libpjrt_cuda",
data = [":sandbox"], data = [":sandbox"],
shared_library = "libpjrt_cuda.so", shared_library = "jax_plugins/xla_cuda12/xla_cuda_plugin.so",
soname = "libpjrt_cuda.so",
add_needed = ["libzmlxcuda.so.0"], add_needed = ["libzmlxcuda.so.0"],
rename_dynamic_symbols = { rename_dynamic_symbols = {
"dlopen": "zmlxcuda_dlopen", "dlopen": "zmlxcuda_dlopen",

View File

@ -900,7 +900,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
} }
} }
switch (platform.target) { switch (platform.target) {
.cuda => cuda_dir: { .cuda => {
// NVIDIA recommends these settings // NVIDIA recommends these settings
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables // https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
setFlag(&options, "xla_gpu_enable_triton_gemm", false); setFlag(&options, "xla_gpu_enable_triton_gemm", false);
@ -914,17 +914,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true); // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true);
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
log.warn("Bazel runfile not found !", .{});
break :cuda_dir;
};
defer r_.deinit(arena);
const source_repo = @import("bazel_builtin").current_repository;
const r = r_.withSourceRepo(source_repo);
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
log.debug("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
}, },
.rocm => { .rocm => {
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when // Disable Triton GEMM on ROCM. For some reason it's much, much slower when

File diff suppressed because it is too large Load Diff