Upgrade XLA to version 20250122.0-cc075be, switch to nvptx compiler and nvlink with nvjitlink support, add warning for CUDA path in LD_LIBRARY_PATH, and revert the previous CUDA sandbox fix.
This commit is contained in:
parent
b8a0aaee5a
commit
7e6103d876
@ -85,7 +85,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
|
|||||||
register_toolchains("//third_party/zls:all")
|
register_toolchains("//third_party/zls:all")
|
||||||
|
|
||||||
bazel_dep(name = "libxev", version = "20241208.2-db6a52b")
|
bazel_dep(name = "libxev", version = "20241208.2-db6a52b")
|
||||||
bazel_dep(name = "llvm-raw", version = "20250102.0-f739aa4")
|
bazel_dep(name = "llvm-raw", version = "20250117.0-bf17016")
|
||||||
|
|
||||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
llvm.configure(
|
llvm.configure(
|
||||||
@ -97,8 +97,8 @@ llvm.configure(
|
|||||||
)
|
)
|
||||||
use_repo(llvm, "llvm-project")
|
use_repo(llvm, "llvm-project")
|
||||||
|
|
||||||
bazel_dep(name = "stablehlo", version = "20241220.0-38bb2f9")
|
bazel_dep(name = "stablehlo", version = "20250117.0-c125b32")
|
||||||
bazel_dep(name = "xla", version = "20250103.0-5f1fe6a")
|
bazel_dep(name = "xla", version = "20250122.0-cc075be")
|
||||||
|
|
||||||
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
||||||
use_repo(tsl, "tsl")
|
use_repo(tsl, "tsl")
|
||||||
|
|||||||
@ -12,15 +12,15 @@ def _cpu_pjrt_plugin_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_linux_amd64",
|
name = "libpjrt_cpu_linux_amd64",
|
||||||
build_file_content = _BUILD.format(ext = "so"),
|
build_file_content = _BUILD.format(ext = "so"),
|
||||||
sha256 = "35b6aefa0359317ae2892f846d6da892bee2116d8c6722e397ef0120cf572183",
|
sha256 = "0f2cb204015e062df5d1cbd39d8c01c076ab2b004d0f4f37f6d5e120d3cd7087",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v4.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_arm64",
|
name = "libpjrt_cpu_darwin_arm64",
|
||||||
build_file_content = _BUILD.format(ext = "dylib"),
|
build_file_content = _BUILD.format(ext = "dylib"),
|
||||||
sha256 = "a532a2e1511f91ec6d6adc60290f6bc4d88d2521508661e90b9824061ebabb3a",
|
sha256 = "2ddb66a93c8a913e3bc8f291e01df59aa297592cc91e05aab2dd4813884098cb",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v4.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -26,6 +26,16 @@ cc_import(
|
|||||||
)
|
)
|
||||||
""".format(name = repr(name), shared_library = repr(shared_library), deps = repr(deps))
|
""".format(name = repr(name), shared_library = repr(shared_library), deps = repr(deps))
|
||||||
|
|
||||||
|
def _cc_import_static(name, static_library, deps = []):
|
||||||
|
return """\
|
||||||
|
cc_import(
|
||||||
|
name = {name},
|
||||||
|
static_library = {static_library},
|
||||||
|
deps = {deps},
|
||||||
|
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||||
|
)
|
||||||
|
""".format(name = repr(name), static_library = repr(static_library), deps = repr(deps))
|
||||||
|
|
||||||
CUDA_PACKAGES = {
|
CUDA_PACKAGES = {
|
||||||
"cuda_cudart": _cc_import(
|
"cuda_cudart": _cc_import(
|
||||||
name = "cudart",
|
name = "cudart",
|
||||||
@ -56,6 +66,10 @@ CUDA_PACKAGES = {
|
|||||||
name = "ptxas",
|
name = "ptxas",
|
||||||
srcs = ["bin/ptxas"],
|
srcs = ["bin/ptxas"],
|
||||||
),
|
),
|
||||||
|
_filegroup(
|
||||||
|
name = "nvlink",
|
||||||
|
srcs = ["bin/nvlink"],
|
||||||
|
),
|
||||||
_filegroup(
|
_filegroup(
|
||||||
name = "libdevice",
|
name = "libdevice",
|
||||||
srcs = ["nvvm/libdevice/libdevice.10.bc"],
|
srcs = ["nvvm/libdevice/libdevice.10.bc"],
|
||||||
@ -64,6 +78,10 @@ CUDA_PACKAGES = {
|
|||||||
name = "nvvm",
|
name = "nvvm",
|
||||||
shared_library = "nvvm/lib64/libnvvm.so.4",
|
shared_library = "nvvm/lib64/libnvvm.so.4",
|
||||||
),
|
),
|
||||||
|
_cc_import_static(
|
||||||
|
name = "nvptxcompiler",
|
||||||
|
static_library = "lib/libnvptxcompiler_static.a",
|
||||||
|
),
|
||||||
]),
|
]),
|
||||||
"cuda_nvrtc": "\n".join([
|
"cuda_nvrtc": "\n".join([
|
||||||
_cc_import(
|
_cc_import(
|
||||||
@ -190,9 +208,8 @@ def _cuda_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||||
url = "https://files.pythonhosted.org/packages/90/43/ac2c369e202e3e3e7e5aa7929b197801ba02eaf11868437adaa5341704e4/jax_cuda12_pjrt-0.4.38-py3-none-manylinux2014_x86_64.whl",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
||||||
type = "zip",
|
sha256 = "1c3ca76d887d112762d03ebb28f17a08beebf6338453c3044a36225e1678a113",
|
||||||
sha256 = "83be4c59fbcf30077a60085d98e7d59dc738b1c91e0d628e4ac1779fde15ac2b",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const bazel_builtin = @import("bazel_builtin");
|
|
||||||
const c = @import("c");
|
|
||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
const runfiles = @import("runfiles");
|
const c = @import("c");
|
||||||
const stdx = @import("stdx");
|
|
||||||
|
const nvidiaLibsPath = "/usr/local/cuda/lib64";
|
||||||
|
|
||||||
pub fn isEnabled() bool {
|
pub fn isEnabled() bool {
|
||||||
return @hasDecl(c, "ZML_RUNTIME_CUDA");
|
return @hasDecl(c, "ZML_RUNTIME_CUDA");
|
||||||
@ -17,21 +15,14 @@ fn hasNvidiaDevice() bool {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn setupXlaGpuCudaDirFlag() !void {
|
fn hasCudaPathInLDPath() bool {
|
||||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
const ldLibraryPath = c.getenv("LD_LIBRARY_PATH");
|
||||||
defer arena.deinit();
|
|
||||||
|
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
if (ldLibraryPath == null) {
|
||||||
stdx.debug.panic("Unable to find CUDA directory", .{});
|
return false;
|
||||||
};
|
}
|
||||||
|
|
||||||
const source_repo = bazel_builtin.current_repository;
|
return std.mem.indexOf(u8, std.mem.span(ldLibraryPath), nvidiaLibsPath) != null;
|
||||||
const r = r_.withSourceRepo(source_repo);
|
|
||||||
const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?;
|
|
||||||
const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch "";
|
|
||||||
const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "--xla_gpu_cuda_data_dir={s} {s}", .{ cuda_data_dir, xla_flags });
|
|
||||||
|
|
||||||
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load() !*const pjrt.Api {
|
pub fn load() !*const pjrt.Api {
|
||||||
@ -44,10 +35,9 @@ pub fn load() !*const pjrt.Api {
|
|||||||
if (!hasNvidiaDevice()) {
|
if (!hasNvidiaDevice()) {
|
||||||
return error.Unavailable;
|
return error.Unavailable;
|
||||||
}
|
}
|
||||||
|
if (hasCudaPathInLDPath()) {
|
||||||
// CUDA path has to be set _before_ loading the PJRT plugin.
|
std.log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
|
||||||
// See https://github.com/openxla/xla/issues/21428
|
}
|
||||||
try setupXlaGpuCudaDirFlag();
|
|
||||||
|
|
||||||
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
|
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,6 +17,7 @@ copy_to_directory(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"@cuda_nvcc//:libdevice",
|
"@cuda_nvcc//:libdevice",
|
||||||
"@cuda_nvcc//:ptxas",
|
"@cuda_nvcc//:ptxas",
|
||||||
|
"@cuda_nvcc//:nvlink",
|
||||||
],
|
],
|
||||||
include_external_repositories = ["**"],
|
include_external_repositories = ["**"],
|
||||||
)
|
)
|
||||||
@ -24,8 +25,7 @@ copy_to_directory(
|
|||||||
cc_import(
|
cc_import(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
data = [":sandbox"],
|
data = [":sandbox"],
|
||||||
shared_library = "jax_plugins/xla_cuda12/xla_cuda_plugin.so",
|
shared_library = "libpjrt_cuda.so",
|
||||||
soname = "libpjrt_cuda.so",
|
|
||||||
add_needed = ["libzmlxcuda.so.0"],
|
add_needed = ["libzmlxcuda.so.0"],
|
||||||
rename_dynamic_symbols = {
|
rename_dynamic_symbols = {
|
||||||
"dlopen": "zmlxcuda_dlopen",
|
"dlopen": "zmlxcuda_dlopen",
|
||||||
@ -35,6 +35,7 @@ cc_import(
|
|||||||
":zmlxcuda",
|
":zmlxcuda",
|
||||||
"@cuda_cudart//:cudart",
|
"@cuda_cudart//:cudart",
|
||||||
"@cuda_cupti//:cupti",
|
"@cuda_cupti//:cupti",
|
||||||
|
"@cuda_nvcc//:nvptxcompiler",
|
||||||
"@cuda_nvcc//:nvvm",
|
"@cuda_nvcc//:nvvm",
|
||||||
"@cuda_nvrtc//:nvrtc",
|
"@cuda_nvrtc//:nvrtc",
|
||||||
"@cudnn//:cudnn",
|
"@cudnn//:cudnn",
|
||||||
|
|||||||
@ -215,8 +215,8 @@ def _rocm_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_rocm",
|
name = "libpjrt_rocm",
|
||||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v4.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
||||||
sha256 = "75c2baf2efba0b2c6fe2513d06e542ed3f3a966e43498cc1d932465f646ca34d",
|
sha256 = "2c7a687827f63987caa117cd5b56a6e20291681ae1c51edd54241a1181e91d2d",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
11
third_party/modules/llvm-raw/20250117.0-bf17016/MODULE.bazel
vendored
Normal file
11
third_party/modules/llvm-raw/20250117.0-bf17016/MODULE.bazel
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
module(
|
||||||
|
name = "llvm-raw",
|
||||||
|
version = "20250117.0-bf17016",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||||
|
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
0
third_party/modules/llvm-raw/20250117.0-bf17016/overlay/BUILD.bazel
vendored
Normal file
0
third_party/modules/llvm-raw/20250117.0-bf17016/overlay/BUILD.bazel
vendored
Normal file
11
third_party/modules/llvm-raw/20250117.0-bf17016/overlay/MODULE.bazel
vendored
Normal file
11
third_party/modules/llvm-raw/20250117.0-bf17016/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
module(
|
||||||
|
name = "llvm-raw",
|
||||||
|
version = "20250117.0-bf17016",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||||
|
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
28
third_party/modules/llvm-raw/20250117.0-bf17016/overlay/utils/bazel/extension.bzl
vendored
Normal file
28
third_party/modules/llvm-raw/20250117.0-bf17016/overlay/utils/bazel/extension.bzl
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
|
||||||
|
|
||||||
|
def _llvm_impl(mctx):
|
||||||
|
_targets = {}
|
||||||
|
for mod in mctx.modules:
|
||||||
|
for conf in mod.tags.configure:
|
||||||
|
for target in conf.targets:
|
||||||
|
_targets[target] = True
|
||||||
|
_llvm_configure(
|
||||||
|
name = "llvm-project",
|
||||||
|
targets = _targets.keys(),
|
||||||
|
)
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm = module_extension(
|
||||||
|
implementation = _llvm_impl,
|
||||||
|
tag_classes = {
|
||||||
|
"configure": tag_class(
|
||||||
|
attrs = {
|
||||||
|
"targets": attr.string_list(mandatory = True),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
10
third_party/modules/llvm-raw/20250117.0-bf17016/source.json
vendored
Normal file
10
third_party/modules/llvm-raw/20250117.0-bf17016/source.json
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "llvm-project-bf17016a92bc8a23d2cdd2b51355dd4eb5019c68",
|
||||||
|
"url": "https://github.com/llvm/llvm-project/archive/bf17016a92bc8a23d2cdd2b51355dd4eb5019c68.tar.gz",
|
||||||
|
"integrity": "sha256-ugnxLlAZ9aylMbFzMnXwoQsYHW+JTesaRhDgF/drFyo=",
|
||||||
|
"overlay": {
|
||||||
|
"BUILD.bazel": "",
|
||||||
|
"MODULE.bazel": "",
|
||||||
|
"utils/bazel/extension.bzl": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
1
third_party/modules/llvm-raw/metadata.json
vendored
1
third_party/modules/llvm-raw/metadata.json
vendored
@ -15,6 +15,7 @@
|
|||||||
"20240919.0-94c024a",
|
"20240919.0-94c024a",
|
||||||
"20241022.0-6c4267f",
|
"20241022.0-6c4267f",
|
||||||
"20250102.0-f739aa4",
|
"20250102.0-f739aa4",
|
||||||
|
"20250117.0-bf17016",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
15
third_party/modules/stablehlo/20250117.0-c125b32/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20250117.0-c125b32/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module(
|
||||||
|
name = "stablehlo",
|
||||||
|
version = "20250117.0-c125b32",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "llvm-raw", version = "20250117.0-bf17016")
|
||||||
|
|
||||||
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
|
llvm.configure(
|
||||||
|
targets = ["AArch64", "X86", "NVPTX"],
|
||||||
|
)
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
15
third_party/modules/stablehlo/20250117.0-c125b32/overlay/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20250117.0-c125b32/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module(
|
||||||
|
name = "stablehlo",
|
||||||
|
version = "20250117.0-c125b32",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "llvm-raw", version = "20250117.0-bf17016")
|
||||||
|
|
||||||
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
|
llvm.configure(
|
||||||
|
targets = ["AArch64", "X86", "NVPTX"],
|
||||||
|
)
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
8
third_party/modules/stablehlo/20250117.0-c125b32/source.json
vendored
Normal file
8
third_party/modules/stablehlo/20250117.0-c125b32/source.json
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "stablehlo-c125b3284819fec57120231cf4430657dab7b881",
|
||||||
|
"url": "https://github.com/openxla/stablehlo/archive/c125b3284819fec57120231cf4430657dab7b881.tar.gz",
|
||||||
|
"integrity": "sha256-h4iBDLVbK2JZRDY7qdmNdns25AoVjpLekKdmsETEDf8=",
|
||||||
|
"overlay": {
|
||||||
|
"MODULE.bazel": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
1
third_party/modules/stablehlo/metadata.json
vendored
1
third_party/modules/stablehlo/metadata.json
vendored
@ -15,6 +15,7 @@
|
|||||||
"20240917.0-78c753a",
|
"20240917.0-78c753a",
|
||||||
"20241021.0-1c0b606",
|
"20241021.0-1c0b606",
|
||||||
"20241220.0-38bb2f9",
|
"20241220.0-38bb2f9",
|
||||||
|
"20250117.0-c125b32",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
34
third_party/modules/xla/20250122.0-cc075be/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20250122.0-cc075be/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20250122.0-cc075be",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
|
||||||
|
bazel_dep(name = "stablehlo", version = "20250117.0-c125b32")
|
||||||
|
|
||||||
|
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||||
|
use_repo(tsl, "tsl")
|
||||||
|
|
||||||
|
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||||
|
use_repo(
|
||||||
|
xla_workspace,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
)
|
||||||
34
third_party/modules/xla/20250122.0-cc075be/overlay/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20250122.0-cc075be/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20250122.0-cc075be",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
|
||||||
|
bazel_dep(name = "stablehlo", version = "20250117.0-c125b32")
|
||||||
|
|
||||||
|
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||||
|
use_repo(tsl, "tsl")
|
||||||
|
|
||||||
|
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||||
|
use_repo(
|
||||||
|
xla_workspace,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
)
|
||||||
19
third_party/modules/xla/20250122.0-cc075be/overlay/tsl.bzl
vendored
Normal file
19
third_party/modules/xla/20250122.0-cc075be/overlay/tsl.bzl
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
load("//third_party:repo.bzl", "tf_vendored")
|
||||||
|
load("//third_party/py:python_init_repositories.bzl", "python_init_repositories")
|
||||||
|
|
||||||
|
def _tsl_impl(mctx):
|
||||||
|
python_init_repositories(
|
||||||
|
requirements = {
|
||||||
|
"3.11": "//:requirements_lock_3_11.txt",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = ["tsl"],
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
tsl = module_extension(
|
||||||
|
implementation = _tsl_impl,
|
||||||
|
)
|
||||||
52
third_party/modules/xla/20250122.0-cc075be/overlay/workspace.bzl
vendored
Normal file
52
third_party/modules/xla/20250122.0-cc075be/overlay/workspace.bzl
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||||
|
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||||
|
load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||||
|
load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||||
|
load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||||
|
|
||||||
|
def _xla_workspace_impl(mctx):
|
||||||
|
cuda_configure(name = "local_config_cuda")
|
||||||
|
remote_execution_configure(name = "local_config_remote_execution")
|
||||||
|
rocm_configure(name = "local_config_rocm")
|
||||||
|
tensorrt_configure(name = "local_config_tensorrt")
|
||||||
|
tf_http_archive(
|
||||||
|
name = "com_github_grpc_grpc",
|
||||||
|
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
|
||||||
|
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
|
||||||
|
system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD",
|
||||||
|
patch_file = [
|
||||||
|
"@tsl//third_party/grpc:generate_cc_env_fix.patch",
|
||||||
|
"@tsl//third_party/grpc:register_go_toolchain.patch",
|
||||||
|
],
|
||||||
|
system_link_files = {
|
||||||
|
"@tsl//third_party/systemlibs:BUILD": "bazel/BUILD",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
|
||||||
|
},
|
||||||
|
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
|
||||||
|
)
|
||||||
|
tf_http_archive(
|
||||||
|
name = "com_google_protobuf",
|
||||||
|
patch_file = ["@tsl//third_party/protobuf:protobuf.patch"],
|
||||||
|
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
|
||||||
|
strip_prefix = "protobuf-3.21.9",
|
||||||
|
system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD",
|
||||||
|
system_link_files = {
|
||||||
|
"@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
|
||||||
|
},
|
||||||
|
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
|
||||||
|
)
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
xla_workspace = module_extension(
|
||||||
|
implementation = _xla_workspace_impl,
|
||||||
|
)
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Steeve Morin <steeve.morin@gmail.com>
|
||||||
|
Date: Mon, 18 Mar 2024 17:17:34 +0100
|
||||||
|
Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension
|
||||||
|
|
||||||
|
---
|
||||||
|
xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++
|
||||||
|
1 file changed, 2 insertions(+)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||||
|
index c821916ad..89a596123 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||||
|
@@ -16,8 +16,10 @@ limitations under the License.
|
||||||
|
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||||
|
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||||
|
|
||||||
|
+#ifdef __cplusplus
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
+#endif
|
||||||
|
|
||||||
|
#include "xla/backends/profiler/plugin/profiler_c_api.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||||
|
--
|
||||||
|
2.39.3 (Apple Git-146)
|
||||||
|
|
||||||
14
third_party/modules/xla/20250122.0-cc075be/source.json
vendored
Normal file
14
third_party/modules/xla/20250122.0-cc075be/source.json
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "xla-cc075beb6148c2777da2b6749c63830856ee6c2a",
|
||||||
|
"url": "https://github.com/openxla/xla/archive/cc075beb6148c2777da2b6749c63830856ee6c2a.tar.gz",
|
||||||
|
"integrity": "sha256-oB8S38WZKEXBtZ6rARd0oL9SAtSmfsPM2xTuj3ylexc=",
|
||||||
|
"overlay": {
|
||||||
|
"tsl.bzl": "",
|
||||||
|
"workspace.bzl": "",
|
||||||
|
"MODULE.bazel": ""
|
||||||
|
},
|
||||||
|
"patch_strip": 1,
|
||||||
|
"patches": {
|
||||||
|
"0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
1
third_party/modules/xla/metadata.json
vendored
1
third_party/modules/xla/metadata.json
vendored
@ -15,6 +15,7 @@
|
|||||||
"20240919.0-1b18dd6",
|
"20240919.0-1b18dd6",
|
||||||
"20241025.0-4663f04",
|
"20241025.0-4663f04",
|
||||||
"20250103.0-5f1fe6a",
|
"20250103.0-5f1fe6a",
|
||||||
|
"20250122.0-cc075be",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,7 @@ const std = @import("std");
|
|||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const dialect = @import("mlir/dialects");
|
const dialect = @import("mlir/dialects");
|
||||||
|
const runfiles = @import("runfiles");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
const xla_pb = @import("//xla:xla_proto");
|
const xla_pb = @import("//xla:xla_proto");
|
||||||
|
|
||||||
@ -901,11 +902,13 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch (platform.target) {
|
switch (platform.target) {
|
||||||
.cuda => {
|
.cuda => cuda_dir: {
|
||||||
// NVIDIA recommends these settings
|
// NVIDIA recommends these settings
|
||||||
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
||||||
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||||
setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
|
setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
|
||||||
|
setFlag(&options, "xla_gpu_enable_llvm_module_compilation_parallelism", true);
|
||||||
|
setFlag(&options, "xla_gpu_enable_libnvptxcompiler", true);
|
||||||
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
|
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
|
||||||
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
|
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
|
||||||
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
|
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
|
||||||
@ -913,6 +916,17 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
||||||
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
||||||
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
||||||
|
|
||||||
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
|
||||||
|
log.warn("Bazel runfile not found !", .{});
|
||||||
|
break :cuda_dir;
|
||||||
|
};
|
||||||
|
defer r_.deinit(arena);
|
||||||
|
const source_repo = @import("bazel_builtin").current_repository;
|
||||||
|
const r = r_.withSourceRepo(source_repo);
|
||||||
|
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
|
||||||
|
log.debug("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
|
||||||
|
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
|
||||||
},
|
},
|
||||||
.rocm => {
|
.rocm => {
|
||||||
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user