Revert CUDA PJRT plugin version to 0.4.38 to address performance regression on XLA master.
This commit is contained in:
parent
76e314db9b
commit
8a25b1eb74
@ -208,8 +208,9 @@ def _cuda_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
url = "https://files.pythonhosted.org/packages/90/43/ac2c369e202e3e3e7e5aa7929b197801ba02eaf11868437adaa5341704e4/jax_cuda12_pjrt-0.4.38-py3-none-manylinux2014_x86_64.whl",
|
||||||
sha256 = "1c3ca76d887d112762d03ebb28f17a08beebf6338453c3044a36225e1678a113",
|
type = "zip",
|
||||||
|
sha256 = "83be4c59fbcf30077a60085d98e7d59dc738b1c91e0d628e4ac1779fde15ac2b",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -1,10 +1,16 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const asynk = @import("async");
|
|
||||||
const pjrt = @import("pjrt");
|
|
||||||
const c = @import("c");
|
|
||||||
|
|
||||||
const nvidiaLibsPath = "/usr/local/cuda/lib64";
|
const asynk = @import("async");
|
||||||
|
const bazel_builtin = @import("bazel_builtin");
|
||||||
|
const c = @import("c");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const runfiles = @import("runfiles");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const nvidiaLibsPath = "/cuda/";
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"zml/runtime/cuda");
|
||||||
|
|
||||||
pub fn isEnabled() bool {
|
pub fn isEnabled() bool {
|
||||||
return @hasDecl(c, "ZML_RUNTIME_CUDA");
|
return @hasDecl(c, "ZML_RUNTIME_CUDA");
|
||||||
@ -22,7 +28,24 @@ fn hasCudaPathInLDPath() bool {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return std.mem.indexOf(u8, std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
|
return std.ascii.indexOfIgnoreCase(std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn setupXlaGpuCudaDirFlag() !void {
|
||||||
|
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
||||||
|
defer arena.deinit();
|
||||||
|
|
||||||
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||||
|
stdx.debug.panic("Unable to find CUDA directory", .{});
|
||||||
|
};
|
||||||
|
|
||||||
|
const source_repo = bazel_builtin.current_repository;
|
||||||
|
const r = r_.withSourceRepo(source_repo);
|
||||||
|
const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?;
|
||||||
|
const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch "";
|
||||||
|
const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, cuda_data_dir });
|
||||||
|
|
||||||
|
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load() !*const pjrt.Api {
|
pub fn load() !*const pjrt.Api {
|
||||||
@ -36,8 +59,12 @@ pub fn load() !*const pjrt.Api {
|
|||||||
return error.Unavailable;
|
return error.Unavailable;
|
||||||
}
|
}
|
||||||
if (hasCudaPathInLDPath()) {
|
if (hasCudaPathInLDPath()) {
|
||||||
std.log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
|
log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CUDA path has to be set _before_ loading the PJRT plugin.
|
||||||
|
// See https://github.com/openxla/xla/issues/21428
|
||||||
|
try setupXlaGpuCudaDirFlag();
|
||||||
|
|
||||||
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
|
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,7 +25,8 @@ copy_to_directory(
|
|||||||
cc_import(
|
cc_import(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
data = [":sandbox"],
|
data = [":sandbox"],
|
||||||
shared_library = "libpjrt_cuda.so",
|
shared_library = "jax_plugins/xla_cuda12/xla_cuda_plugin.so",
|
||||||
|
soname = "libpjrt_cuda.so",
|
||||||
add_needed = ["libzmlxcuda.so.0"],
|
add_needed = ["libzmlxcuda.so.0"],
|
||||||
rename_dynamic_symbols = {
|
rename_dynamic_symbols = {
|
||||||
"dlopen": "zmlxcuda_dlopen",
|
"dlopen": "zmlxcuda_dlopen",
|
||||||
|
|||||||
@ -900,7 +900,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch (platform.target) {
|
switch (platform.target) {
|
||||||
.cuda => cuda_dir: {
|
.cuda => {
|
||||||
// NVIDIA recommends these settings
|
// NVIDIA recommends these settings
|
||||||
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
||||||
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||||
@ -914,17 +914,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
||||||
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
||||||
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
||||||
|
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
|
|
||||||
log.warn("Bazel runfile not found !", .{});
|
|
||||||
break :cuda_dir;
|
|
||||||
};
|
|
||||||
defer r_.deinit(arena);
|
|
||||||
const source_repo = @import("bazel_builtin").current_repository;
|
|
||||||
const r = r_.withSourceRepo(source_repo);
|
|
||||||
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
|
|
||||||
log.debug("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
|
|
||||||
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
|
|
||||||
},
|
},
|
||||||
.rocm => {
|
.rocm => {
|
||||||
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
||||||
|
|||||||
1051
zml/tokenizer.zig
1051
zml/tokenizer.zig
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user