Revert CUDA PJRT plugin version to 0.4.38 to address performance regression on XLA master.
This commit is contained in:
parent
76e314db9b
commit
8a25b1eb74
@ -208,8 +208,9 @@ def _cuda_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_cuda",
|
||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
||||
sha256 = "1c3ca76d887d112762d03ebb28f17a08beebf6338453c3044a36225e1678a113",
|
||||
url = "https://files.pythonhosted.org/packages/90/43/ac2c369e202e3e3e7e5aa7929b197801ba02eaf11868437adaa5341704e4/jax_cuda12_pjrt-0.4.38-py3-none-manylinux2014_x86_64.whl",
|
||||
type = "zip",
|
||||
sha256 = "83be4c59fbcf30077a60085d98e7d59dc738b1c91e0d628e4ac1779fde15ac2b",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
@ -1,10 +1,16 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const asynk = @import("async");
|
||||
const pjrt = @import("pjrt");
|
||||
const c = @import("c");
|
||||
|
||||
const nvidiaLibsPath = "/usr/local/cuda/lib64";
|
||||
const asynk = @import("async");
|
||||
const bazel_builtin = @import("bazel_builtin");
|
||||
const c = @import("c");
|
||||
const pjrt = @import("pjrt");
|
||||
const runfiles = @import("runfiles");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const nvidiaLibsPath = "/cuda/";
|
||||
|
||||
const log = std.log.scoped(.@"zml/runtime/cuda");
|
||||
|
||||
pub fn isEnabled() bool {
|
||||
return @hasDecl(c, "ZML_RUNTIME_CUDA");
|
||||
@ -22,7 +28,24 @@ fn hasCudaPathInLDPath() bool {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std.mem.indexOf(u8, std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
|
||||
return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
|
||||
}
|
||||
|
||||
/// Appends `--xla_gpu_cuda_data_dir=<runfiles path>` to the `XLA_FLAGS`
/// environment variable so XLA can locate the bundled CUDA directory.
/// Panics if the Bazel runfiles tree cannot be discovered.
/// NOTE(review): assumes `c.setenv` copies its arguments — the backing arena
/// is freed on return; confirm against the libc binding.
fn setupXlaGpuCudaDirFlag() !void {
    // Every allocation below is scratch; release them all in one shot.
    var scratch = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer scratch.deinit();
    const alloc = scratch.allocator();

    var runfiles_root = try runfiles.Runfiles.create(.{ .allocator = alloc }) orelse {
        stdx.debug.panic("Unable to find CUDA directory", .{});
    };

    const repo = bazel_builtin.current_repository;
    const rf = runfiles_root.withSourceRepo(repo);
    const cuda_data_dir = (try rf.rlocationAlloc(alloc, "libpjrt_cuda/sandbox")).?;

    // Preserve whatever flags the user already exported; missing var => empty.
    const existing_flags = std.process.getEnvVarOwned(alloc, "XLA_FLAGS") catch "";
    const combined = try std.fmt.allocPrintZ(alloc, "{s} --xla_gpu_cuda_data_dir={s}", .{ existing_flags, cuda_data_dir });

    _ = c.setenv("XLA_FLAGS", combined, 1);
}
|
||||
|
||||
pub fn load() !*const pjrt.Api {
|
||||
@ -36,8 +59,12 @@ pub fn load() !*const pjrt.Api {
|
||||
return error.Unavailable;
|
||||
}
|
||||
if (hasCudaPathInLDPath()) {
|
||||
std.log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
|
||||
log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
|
||||
}
|
||||
|
||||
// CUDA path has to be set _before_ loading the PJRT plugin.
|
||||
// See https://github.com/openxla/xla/issues/21428
|
||||
try setupXlaGpuCudaDirFlag();
|
||||
|
||||
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
|
||||
}
|
||||
|
||||
@ -25,7 +25,8 @@ copy_to_directory(
|
||||
cc_import(
|
||||
name = "libpjrt_cuda",
|
||||
data = [":sandbox"],
|
||||
shared_library = "libpjrt_cuda.so",
|
||||
shared_library = "jax_plugins/xla_cuda12/xla_cuda_plugin.so",
|
||||
soname = "libpjrt_cuda.so",
|
||||
add_needed = ["libzmlxcuda.so.0"],
|
||||
rename_dynamic_symbols = {
|
||||
"dlopen": "zmlxcuda_dlopen",
|
||||
|
||||
@ -900,7 +900,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
}
|
||||
}
|
||||
switch (platform.target) {
|
||||
.cuda => cuda_dir: {
|
||||
.cuda => {
|
||||
// NVIDIA recommends these settings
|
||||
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
||||
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||
@ -914,17 +914,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
||||
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
||||
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
||||
|
||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
|
||||
log.warn("Bazel runfile not found !", .{});
|
||||
break :cuda_dir;
|
||||
};
|
||||
defer r_.deinit(arena);
|
||||
const source_repo = @import("bazel_builtin").current_repository;
|
||||
const r = r_.withSourceRepo(source_repo);
|
||||
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
|
||||
log.debug("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
|
||||
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
|
||||
},
|
||||
.rocm => {
|
||||
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
||||
|
||||
1051
zml/tokenizer.zig
1051
zml/tokenizer.zig
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user