From 7e6103d8763e454b7e53746bd560a6e62c1ec5b8 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 6 Feb 2024 09:31:48 +0000 Subject: [PATCH] Upgrade XLA to version 20250122.0-cc075be, switch to nvptx compiler and nvlink with nvjitlink support, add warning for CUDA path in LD_LIBRARY_PATH, and revert the previous CUDA sandbox fix. --- MODULE.bazel | 6 +-- runtimes/cpu/cpu.bzl | 8 +-- runtimes/cuda/cuda.bzl | 23 ++++++-- runtimes/cuda/cuda.zig | 34 +++++------- runtimes/cuda/libpjrt_cuda.BUILD.bazel | 5 +- runtimes/rocm/rocm.bzl | 4 +- .../llvm-raw/20250117.0-bf17016/MODULE.bazel | 11 ++++ .../20250117.0-bf17016/overlay/BUILD.bazel | 0 .../20250117.0-bf17016/overlay/MODULE.bazel | 11 ++++ .../overlay/utils/bazel/extension.bzl | 28 ++++++++++ .../llvm-raw/20250117.0-bf17016/source.json | 10 ++++ third_party/modules/llvm-raw/metadata.json | 1 + .../stablehlo/20250117.0-c125b32/MODULE.bazel | 15 ++++++ .../20250117.0-c125b32/overlay/MODULE.bazel | 15 ++++++ .../stablehlo/20250117.0-c125b32/source.json | 8 +++ third_party/modules/stablehlo/metadata.json | 1 + .../xla/20250122.0-cc075be/MODULE.bazel | 34 ++++++++++++ .../20250122.0-cc075be/overlay/MODULE.bazel | 34 ++++++++++++ .../xla/20250122.0-cc075be/overlay/tsl.bzl | 19 +++++++ .../20250122.0-cc075be/overlay/workspace.bzl | 52 +++++++++++++++++++ ...e-C-compliance-for-Profiler-Extensio.patch | 27 ++++++++++ .../xla/20250122.0-cc075be/source.json | 14 +++++ third_party/modules/xla/metadata.json | 1 + zml/module.zig | 16 +++++- 24 files changed, 340 insertions(+), 37 deletions(-) create mode 100644 third_party/modules/llvm-raw/20250117.0-bf17016/MODULE.bazel create mode 100644 third_party/modules/llvm-raw/20250117.0-bf17016/overlay/BUILD.bazel create mode 100644 third_party/modules/llvm-raw/20250117.0-bf17016/overlay/MODULE.bazel create mode 100644 third_party/modules/llvm-raw/20250117.0-bf17016/overlay/utils/bazel/extension.bzl create mode 100644 third_party/modules/llvm-raw/20250117.0-bf17016/source.json create mode 100644 third_party/modules/stablehlo/20250117.0-c125b32/MODULE.bazel create mode 100644 third_party/modules/stablehlo/20250117.0-c125b32/overlay/MODULE.bazel create mode 100644 third_party/modules/stablehlo/20250117.0-c125b32/source.json create mode 100644 third_party/modules/xla/20250122.0-cc075be/MODULE.bazel create mode 100644 third_party/modules/xla/20250122.0-cc075be/overlay/MODULE.bazel create mode 100644 third_party/modules/xla/20250122.0-cc075be/overlay/tsl.bzl create mode 100644 third_party/modules/xla/20250122.0-cc075be/overlay/workspace.bzl create mode 100644 third_party/modules/xla/20250122.0-cc075be/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch create mode 100644 third_party/modules/xla/20250122.0-cc075be/source.json diff --git a/MODULE.bazel b/MODULE.bazel index d6e7c13..cfd4028 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -85,7 +85,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux") register_toolchains("//third_party/zls:all") bazel_dep(name = "libxev", version = "20241208.2-db6a52b") -bazel_dep(name = "llvm-raw", version = "20250102.0-f739aa4") +bazel_dep(name = "llvm-raw", version = "20250117.0-bf17016") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") llvm.configure( @@ -97,8 +97,8 @@ llvm.configure( ) use_repo(llvm, "llvm-project") -bazel_dep(name = "stablehlo", version = "20241220.0-38bb2f9") -bazel_dep(name = "xla", version = "20250103.0-5f1fe6a") +bazel_dep(name = "stablehlo", version = "20250117.0-c125b32") +bazel_dep(name = "xla", version = "20250122.0-cc075be") tsl = use_extension("@xla//:tsl.bzl", "tsl") use_repo(tsl, "tsl") diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index 832ffae..31dfafe 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -12,15 +12,15 @@ def _cpu_pjrt_plugin_impl(mctx): http_archive( name = "libpjrt_cpu_linux_amd64", build_file_content = _BUILD.format(ext = "so"), - sha256 = "35b6aefa0359317ae2892f846d6da892bee2116d8c6722e397ef0120cf572183", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v4.0.0/pjrt-cpu_linux-amd64.tar.gz", + sha256 = "0f2cb204015e062df5d1cbd39d8c01c076ab2b004d0f4f37f6d5e120d3cd7087", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cpu_linux-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_arm64", build_file_content = _BUILD.format(ext = "dylib"), - sha256 = "a532a2e1511f91ec6d6adc60290f6bc4d88d2521508661e90b9824061ebabb3a", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v4.0.0/pjrt-cpu_darwin-arm64.tar.gz", + sha256 = "2ddb66a93c8a913e3bc8f291e01df59aa297592cc91e05aab2dd4813884098cb", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cpu_darwin-arm64.tar.gz", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index 4736bb4..af7d2a1 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -26,6 +26,16 @@ cc_import( ) """.format(name = repr(name), shared_library = repr(shared_library), deps = repr(deps)) +def _cc_import_static(name, static_library, deps = []): + return """\ +cc_import( + name = {name}, + static_library = {static_library}, + deps = {deps}, + visibility = ["@libpjrt_cuda//:__subpackages__"], +) +""".format(name = repr(name), static_library = repr(static_library), deps = repr(deps)) + CUDA_PACKAGES = { "cuda_cudart": _cc_import( name = "cudart", @@ -56,6 +66,10 @@ CUDA_PACKAGES = { name = "ptxas", srcs = ["bin/ptxas"], ), + _filegroup( + name = "nvlink", + srcs = ["bin/nvlink"], + ), _filegroup( name = "libdevice", srcs = ["nvvm/libdevice/libdevice.10.bc"], @@ -64,6 +78,10 @@ CUDA_PACKAGES = { name = "nvvm", shared_library = "nvvm/lib64/libnvvm.so.4", ), + _cc_import_static( + name = "nvptxcompiler", + static_library = "lib/libnvptxcompiler_static.a", + ), ]), "cuda_nvrtc": "\n".join([ _cc_import( @@ -190,9 +208,8 @@ def _cuda_impl(mctx): http_archive( name = "libpjrt_cuda", build_file = "libpjrt_cuda.BUILD.bazel", - url = "https://files.pythonhosted.org/packages/90/43/ac2c369e202e3e3e7e5aa7929b197801ba02eaf11868437adaa5341704e4/jax_cuda12_pjrt-0.4.38-py3-none-manylinux2014_x86_64.whl", - type = "zip", - sha256 = "83be4c59fbcf30077a60085d98e7d59dc738b1c91e0d628e4ac1779fde15ac2b", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-cuda_linux-amd64.tar.gz", + sha256 = "1c3ca76d887d112762d03ebb28f17a08beebf6338453c3044a36225e1678a113", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.zig b/runtimes/cuda/cuda.zig index 4c2bc80..f832ce3 100644 --- a/runtimes/cuda/cuda.zig +++ b/runtimes/cuda/cuda.zig @@ -1,12 +1,10 @@ const builtin = @import("builtin"); const std = @import("std"); - const asynk = @import("async"); -const bazel_builtin = @import("bazel_builtin"); -const c = @import("c"); const pjrt = @import("pjrt"); -const runfiles = @import("runfiles"); -const stdx = @import("stdx"); +const c = @import("c"); + +const nvidiaLibsPath = "/usr/local/cuda/lib64"; pub fn isEnabled() bool { return @hasDecl(c, "ZML_RUNTIME_CUDA"); @@ -17,21 +15,14 @@ fn hasNvidiaDevice() bool { return true; } -fn setupXlaGpuCudaDirFlag() !void { - var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); - defer arena.deinit(); +fn hasCudaPathInLDPath() bool { + const ldLibraryPath = c.getenv("LD_LIBRARY_PATH"); - var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { - stdx.debug.panic("Unable to find CUDA directory", .{}); - }; + if (ldLibraryPath == null) { + return false; + } - const source_repo = bazel_builtin.current_repository; - const r = r_.withSourceRepo(source_repo); - const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?; - const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch ""; - const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "--xla_gpu_cuda_data_dir={s} {s}", .{ cuda_data_dir, xla_flags }); - - _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1); + return std.mem.indexOf(u8, std.mem.span(ldLibraryPath), nvidiaLibsPath) != null; } pub fn load() !*const pjrt.Api { @@ -44,10 +35,9 @@ pub fn load() !*const pjrt.Api { if (!hasNvidiaDevice()) { return error.Unavailable; } - - // CUDA path has to be set _before_ loading the PJRT plugin. - // See https://github.com/openxla/xla/issues/21428 - try setupXlaGpuCudaDirFlag(); + if (hasCudaPathInLDPath()) { + std.log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath}); + } return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"}); } diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel index ab2ee83..95ea199 100644 --- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel +++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel @@ -17,6 +17,7 @@ copy_to_directory( srcs = [ "@cuda_nvcc//:libdevice", "@cuda_nvcc//:ptxas", + "@cuda_nvcc//:nvlink", ], include_external_repositories = ["**"], ) @@ -24,8 +25,7 @@ copy_to_directory( cc_import( name = "libpjrt_cuda", data = [":sandbox"], - shared_library = "jax_plugins/xla_cuda12/xla_cuda_plugin.so", - soname = "libpjrt_cuda.so", + shared_library = "libpjrt_cuda.so", add_needed = ["libzmlxcuda.so.0"], rename_dynamic_symbols = { "dlopen": "zmlxcuda_dlopen", @@ -35,6 +35,7 @@ cc_import( ":zmlxcuda", "@cuda_cudart//:cudart", "@cuda_cupti//:cupti", + "@cuda_nvcc//:nvptxcompiler", "@cuda_nvcc//:nvvm", "@cuda_nvrtc//:nvrtc", "@cudnn//:cudnn", diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index 07242c7..0eacc9e 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -215,8 +215,8 @@ def _rocm_impl(mctx): http_archive( name = "libpjrt_rocm", build_file = "libpjrt_rocm.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v4.0.0/pjrt-rocm_linux-amd64.tar.gz", - sha256 = "75c2baf2efba0b2c6fe2513d06e542ed3f3a966e43498cc1d932465f646ca34d", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v5.0.0/pjrt-rocm_linux-amd64.tar.gz", + sha256 = "2c7a687827f63987caa117cd5b56a6e20291681ae1c51edd54241a1181e91d2d", ) return mctx.extension_metadata( diff --git a/third_party/modules/llvm-raw/20250117.0-bf17016/MODULE.bazel b/third_party/modules/llvm-raw/20250117.0-bf17016/MODULE.bazel new file mode 100644 index 0000000..b1171a0 --- /dev/null +++ b/third_party/modules/llvm-raw/20250117.0-bf17016/MODULE.bazel @@ -0,0 +1,11 @@ +module( + name = "llvm-raw", + version = "20250117.0-bf17016", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") +bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd") +bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib") +bazel_dep(name = "rules_python", version = "0.29.0") diff --git a/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/BUILD.bazel b/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/BUILD.bazel new file mode 100644 index 0000000..e69de29 diff --git a/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/MODULE.bazel b/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/MODULE.bazel new file mode 100644 index 0000000..b1171a0 --- /dev/null +++ b/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/MODULE.bazel @@ -0,0 +1,11 @@ +module( + name = "llvm-raw", + version = "20250117.0-bf17016", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") +bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd") +bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib") +bazel_dep(name = "rules_python", version = "0.29.0") diff --git a/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/utils/bazel/extension.bzl b/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/utils/bazel/extension.bzl new file mode 100644 index 0000000..f247ed4 --- /dev/null +++ b/third_party/modules/llvm-raw/20250117.0-bf17016/overlay/utils/bazel/extension.bzl @@ -0,0 +1,28 @@ +load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure") + +def _llvm_impl(mctx): + _targets = {} + for mod in mctx.modules: + for conf in mod.tags.configure: + for target in conf.targets: + _targets[target] = True + _llvm_configure( + name = "llvm-project", + targets = _targets.keys(), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +llvm = module_extension( + implementation = _llvm_impl, + tag_classes = { + "configure": tag_class( + attrs = { + "targets": attr.string_list(mandatory = True), + }, + ), + }, +) diff --git a/third_party/modules/llvm-raw/20250117.0-bf17016/source.json b/third_party/modules/llvm-raw/20250117.0-bf17016/source.json new file mode 100644 index 0000000..851e85a --- /dev/null +++ b/third_party/modules/llvm-raw/20250117.0-bf17016/source.json @@ -0,0 +1,10 @@ +{ + "strip_prefix": "llvm-project-bf17016a92bc8a23d2cdd2b51355dd4eb5019c68", + "url": "https://github.com/llvm/llvm-project/archive/bf17016a92bc8a23d2cdd2b51355dd4eb5019c68.tar.gz", + "integrity": "sha256-ugnxLlAZ9aylMbFzMnXwoQsYHW+JTesaRhDgF/drFyo=", + "overlay": { + "BUILD.bazel": "", + "MODULE.bazel": "", + "utils/bazel/extension.bzl": "" + } +} diff --git a/third_party/modules/llvm-raw/metadata.json b/third_party/modules/llvm-raw/metadata.json index 1028f27..16857b3 100644 --- a/third_party/modules/llvm-raw/metadata.json +++ b/third_party/modules/llvm-raw/metadata.json @@ -15,6 +15,7 @@ "20240919.0-94c024a", "20241022.0-6c4267f", "20250102.0-f739aa4", + "20250117.0-bf17016", ], "yanked_versions": {} } diff --git a/third_party/modules/stablehlo/20250117.0-c125b32/MODULE.bazel b/third_party/modules/stablehlo/20250117.0-c125b32/MODULE.bazel new file mode 100644 index 0000000..2aabd32 --- /dev/null +++ b/third_party/modules/stablehlo/20250117.0-c125b32/MODULE.bazel @@ -0,0 +1,15 @@ +module( + name = "stablehlo", + version = "20250117.0-c125b32", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "llvm-raw", version = "20250117.0-bf17016") + +llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") +llvm.configure( + targets = ["AArch64", "X86", "NVPTX"], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/stablehlo/20250117.0-c125b32/overlay/MODULE.bazel b/third_party/modules/stablehlo/20250117.0-c125b32/overlay/MODULE.bazel new file mode 100644 index 0000000..2aabd32 --- /dev/null +++ b/third_party/modules/stablehlo/20250117.0-c125b32/overlay/MODULE.bazel @@ -0,0 +1,15 @@ +module( + name = "stablehlo", + version = "20250117.0-c125b32", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "llvm-raw", version = "20250117.0-bf17016") + +llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") +llvm.configure( + targets = ["AArch64", "X86", "NVPTX"], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/stablehlo/20250117.0-c125b32/source.json b/third_party/modules/stablehlo/20250117.0-c125b32/source.json new file mode 100644 index 0000000..fd8426b --- /dev/null +++ b/third_party/modules/stablehlo/20250117.0-c125b32/source.json @@ -0,0 +1,8 @@ +{ + "strip_prefix": "stablehlo-c125b3284819fec57120231cf4430657dab7b881", + "url": "https://github.com/openxla/stablehlo/archive/c125b3284819fec57120231cf4430657dab7b881.tar.gz", + "integrity": "sha256-h4iBDLVbK2JZRDY7qdmNdns25AoVjpLekKdmsETEDf8=", + "overlay": { + "MODULE.bazel": "" + } +} diff --git a/third_party/modules/stablehlo/metadata.json b/third_party/modules/stablehlo/metadata.json index 2e6a693..72a2e77 100644 --- a/third_party/modules/stablehlo/metadata.json +++ b/third_party/modules/stablehlo/metadata.json @@ -15,6 +15,7 @@ "20240917.0-78c753a", "20241021.0-1c0b606", "20241220.0-38bb2f9", + "20250117.0-c125b32", ], "yanked_versions": {} } diff --git a/third_party/modules/xla/20250122.0-cc075be/MODULE.bazel b/third_party/modules/xla/20250122.0-cc075be/MODULE.bazel new file mode 100644 index 0000000..fd99777 --- /dev/null +++ b/third_party/modules/xla/20250122.0-cc075be/MODULE.bazel @@ -0,0 +1,34 @@ +module( + name = "xla", + version = "20250122.0-cc075be", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") + +bazel_dep(name = "stablehlo", version = "20250117.0-c125b32") + +tsl = use_extension("//:tsl.bzl", "tsl") +use_repo(tsl, "tsl") + +xla_workspace = use_extension("//:workspace.bzl", "xla_workspace") +use_repo( + xla_workspace, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", +) diff --git a/third_party/modules/xla/20250122.0-cc075be/overlay/MODULE.bazel b/third_party/modules/xla/20250122.0-cc075be/overlay/MODULE.bazel new file mode 100644 index 0000000..fd99777 --- /dev/null +++ b/third_party/modules/xla/20250122.0-cc075be/overlay/MODULE.bazel @@ -0,0 +1,34 @@ +module( + name = "xla", + version = "20250122.0-cc075be", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") + +bazel_dep(name = "stablehlo", version = "20250117.0-c125b32") + +tsl = use_extension("//:tsl.bzl", "tsl") +use_repo(tsl, "tsl") + +xla_workspace = use_extension("//:workspace.bzl", "xla_workspace") +use_repo( + xla_workspace, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", +) diff --git a/third_party/modules/xla/20250122.0-cc075be/overlay/tsl.bzl b/third_party/modules/xla/20250122.0-cc075be/overlay/tsl.bzl new file mode 100644 index 0000000..dc25c9d --- /dev/null +++ b/third_party/modules/xla/20250122.0-cc075be/overlay/tsl.bzl @@ -0,0 +1,19 @@ +load("//third_party:repo.bzl", "tf_vendored") +load("//third_party/py:python_init_repositories.bzl", "python_init_repositories") + +def _tsl_impl(mctx): + python_init_repositories( + requirements = { + "3.11": "//:requirements_lock_3_11.txt", + }, + ) + tf_vendored(name = "tsl", relpath = "third_party/tsl") + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = ["tsl"], + root_module_direct_dev_deps = [], + ) + +tsl = module_extension( + implementation = _tsl_impl, +) diff --git a/third_party/modules/xla/20250122.0-cc075be/overlay/workspace.bzl b/third_party/modules/xla/20250122.0-cc075be/overlay/workspace.bzl new file mode 100644 index 0000000..87c035b --- /dev/null +++ b/third_party/modules/xla/20250122.0-cc075be/overlay/workspace.bzl @@ -0,0 +1,52 @@ +load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") +load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure") +load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") +load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure") + +def _xla_workspace_impl(mctx): + cuda_configure(name = "local_config_cuda") + remote_execution_configure(name = "local_config_remote_execution") + rocm_configure(name = "local_config_rocm") + tensorrt_configure(name = "local_config_tensorrt") + tf_http_archive( + name = "com_github_grpc_grpc", + sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f", + strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd", + system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD", + patch_file = [ + "@tsl//third_party/grpc:generate_cc_env_fix.patch", + "@tsl//third_party/grpc:register_go_toolchain.patch", + ], + system_link_files = { + "@tsl//third_party/systemlibs:BUILD": "bazel/BUILD", + "@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD", + "@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl", + }, + urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"), + ) + tf_http_archive( + name = "com_google_protobuf", + patch_file = ["@tsl//third_party/protobuf:protobuf.patch"], + sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1", + strip_prefix = "protobuf-3.21.9", + system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD", + system_link_files = { + "@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", + "@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl", + }, + urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +xla_workspace = module_extension( + implementation = _xla_workspace_impl, +) diff --git a/third_party/modules/xla/20250122.0-cc075be/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch b/third_party/modules/xla/20250122.0-cc075be/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch new file mode 100644 index 0000000..0f58f10 --- /dev/null +++ b/third_party/modules/xla/20250122.0-cc075be/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch @@ -0,0 +1,27 @@ +From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001 +From: Steeve Morin +Date: Mon, 18 Mar 2024 17:17:34 +0100 +Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension + +--- + xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++ + 1 file changed, 2 insertions(+) + +diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h +index c821916ad..89a596123 100644 +--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h ++++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h +@@ -16,8 +16,10 @@ limitations under the License. + #ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_ + #define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_ + ++#ifdef __cplusplus + #include + #include ++#endif + + #include "xla/backends/profiler/plugin/profiler_c_api.h" + #include "xla/pjrt/c/pjrt_c_api.h" +-- +2.39.3 (Apple Git-146) + diff --git a/third_party/modules/xla/20250122.0-cc075be/source.json b/third_party/modules/xla/20250122.0-cc075be/source.json new file mode 100644 index 0000000..4946d80 --- /dev/null +++ b/third_party/modules/xla/20250122.0-cc075be/source.json @@ -0,0 +1,14 @@ +{ + "strip_prefix": "xla-cc075beb6148c2777da2b6749c63830856ee6c2a", + "url": "https://github.com/openxla/xla/archive/cc075beb6148c2777da2b6749c63830856ee6c2a.tar.gz", + "integrity": "sha256-oB8S38WZKEXBtZ6rARd0oL9SAtSmfsPM2xTuj3ylexc=", + "overlay": { + "tsl.bzl": "", + "workspace.bzl": "", + "MODULE.bazel": "" + }, + "patch_strip": 1, + "patches": { + "0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": "" + } +} diff --git a/third_party/modules/xla/metadata.json b/third_party/modules/xla/metadata.json index fd7372a..3b1e47a 100644 --- a/third_party/modules/xla/metadata.json +++ b/third_party/modules/xla/metadata.json @@ -15,6 +15,7 @@ "20240919.0-1b18dd6", "20241025.0-4663f04", "20250103.0-5f1fe6a", + "20250122.0-cc075be", ], "yanked_versions": {} } diff --git a/zml/module.zig b/zml/module.zig index 73569cb..58d4c8b 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -2,6 +2,7 @@ const std = @import("std"); const asynk = @import("async"); const dialect = @import("mlir/dialects"); +const runfiles = @import("runfiles"); const stdx = @import("stdx"); const xla_pb = @import("//xla:xla_proto"); @@ -901,11 +902,13 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m } } switch (platform.target) { - .cuda => { + .cuda => cuda_dir: { // NVIDIA recommends these settings // https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables setFlag(&options, "xla_gpu_enable_triton_gemm", false); setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true); + setFlag(&options, "xla_gpu_enable_llvm_module_compilation_parallelism", true); + setFlag(&options, "xla_gpu_enable_libnvptxcompiler", true); // setFlag(&options, "xla_gpu_enable_cudnn_fmha", true); // setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true); // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); @@ -913,6 +916,17 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true); + + var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse { + log.warn("Bazel runfile not found !", .{}); + break :cuda_dir; + }; + defer r_.deinit(arena); + const source_repo = @import("bazel_builtin").current_repository; + const r = r_.withSourceRepo(source_repo); + const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?; + log.debug("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir}); + setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir); }, .rocm => { // Disable Triton GEMM on ROCM. For some reason it's much, much slower when