From 01da2184feeef876292db5f14b5673c99a13b7cf Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 12 Aug 2025 13:32:18 +0000 Subject: [PATCH] xla: bump to commit b3fbfee, temporarily disable libnvptxcompiler due to missing support in PjRT CUDA plugin v13.0, add nvshmem to sandbox for PjRT CUDA plugin --- pjrt/pjrt.zig | 1 - runtimes/cpu/cpu.bzl | 12 +- runtimes/cuda/cuda.bzl | 39 ++++- runtimes/cuda/libpjrt_cuda.BUILD.bazel | 2 +- runtimes/rocm/rocm.bzl | 4 +- runtimes/tpu/tpu.bzl | 4 +- .../0001-bazel-migration-to-bazel-8.1.1.patch | 41 ------ ...ler-registration-API-to-the-FFI-PjRt.patch | 135 ------------------ ...ove-unconventional-C-code-in-headers.patch | 124 ---------------- third_party/xla/repo.bzl | 6 +- third_party/xla/xla.bzl | 10 +- zml/module.zig | 1 - 12 files changed, 54 insertions(+), 325 deletions(-) delete mode 100644 third_party/xla/patches/0001-bazel-migration-to-bazel-8.1.1.patch delete mode 100644 third_party/xla/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch delete mode 100644 third_party/xla/patches/0003-Remove-unconventional-C-code-in-headers.patch diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 2458149..6e35856 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -1308,7 +1308,6 @@ pub const FFI = extern struct { options: RegisterFfiOptions, ) ApiError!void { var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{ - .api_version = 1, .target_name = target_name.ptr, .target_name_size = target_name.len, .handler = @ptrCast(@constCast(func)), diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index e4d5536..dbc69d7 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -23,22 +23,22 @@ def _cpu_pjrt_plugin_impl(mctx): http_archive( name = "libpjrt_cpu_linux_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX, - sha256 = "cf5ea44b14a6ddc320c5b1d7bb88328dda099571e106b17a214b6ec586e321b8", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_linux-amd64.tar.gz", + sha256 = "124dc500291a5930f910ca23533520e22c90797110b29fd2c0d8274475f4a220", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_linux-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "1b48c26b0d2709730df6921a86d9ae97ab71d3f2cfa17fd1a370e85235370914", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_darwin-amd64.tar.gz", + sha256 = "6e5b59874880f4db37c53fb1d52520d410b0078f9d2606a90762c6c622693c26", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_arm64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "66dc6c65933a6d7985b1c69837d8abe13750cd1f06704427f6989c3f952d3511", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_darwin-arm64.tar.gz", + sha256 = "e54bc3c3b71313c49e38fc342851bffe2cd3a55e3990839947d49496a1a71270", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-arm64.tar.gz", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index 9920103..a99d460 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -17,6 +17,10 @@ CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redis CUDNN_VERSION = "9.8.0" CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843" +NVSHMEM_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/nvshmem/redist/" +NVSHMEM_VERSION = "3.2.5" +NVSHMEM_REDIST_JSON_SHA256 = "6945425d3bfd24de23c045996f93ec720c010379bfd6f0860ac5f2716659442d" + _UBUNTU_PACKAGES = { "zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]), } @@ -119,6 +123,17 @@ CUDNN_PACKAGES = { ]), } +NVSHMEM_PACKAGES = { + "libnvshmem": packages.filegroup( + name = "libnvshmem", + srcs = [ + "lib/libnvshmem_host.so.3", + "lib/nvshmem_bootstrap_uid.so.3", + "lib/nvshmem_transport_ibrc.so.3", + ], + ), +} + def _read_redist_json(mctx, url, sha256): fname = ".{}.json".format(sha256) mctx.download( @@ -138,6 +153,12 @@ def _cuda_impl(mctx): sha256 = CUDA_REDIST_JSON_SHA256, ) + NVSHMEM_REDIST = _read_redist_json( + mctx, + url = NVSHMEM_REDIST_PREFIX + "redistrib_{}.json".format(NVSHMEM_VERSION), + sha256 = NVSHMEM_REDIST_JSON_SHA256, + ) + CUDNN_REDIST = _read_redist_json( mctx, url = CUDNN_REDIST_PREFIX + "redistrib_{}.json".format(CUDNN_VERSION), @@ -180,6 +201,20 @@ def _cuda_impl(mctx): strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""), ) + for pkg, build_file_content in NVSHMEM_PACKAGES.items(): + pkg_data = NVSHMEM_REDIST[pkg] + arch_data = pkg_data.get(ARCH) + if not arch_data: + continue + arch_data = arch_data.get("cuda12", arch_data) + http_archive( + name = pkg, + build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content, + url = NVSHMEM_REDIST_PREFIX + arch_data["relative_path"], + sha256 = arch_data["sha256"], + strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""), + ) + http_archive( name = "nccl", urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"], @@ -194,8 +229,8 @@ def _cuda_impl(mctx): http_archive( name = "libpjrt_cuda", build_file = "libpjrt_cuda.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cuda_linux-amd64.tar.gz", - sha256 = "08fa022a6067ddfb5c951bdf11ddc398e63de21fdcacc9ffd07f70b1463482c2", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cuda_linux-amd64.tar.gz", + sha256 = "6cdac9bac6db904e4423c9745c61000cf3acaf3c7da8016ab0016f076869048a", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel index 031ec50..f63f0ff 100644 --- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel +++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel @@ -1,5 +1,4 @@ load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") -load("@zml//bazel:cc_import.bzl", "cc_import") load("@zml//bazel:patchelf.bzl", "patchelf") cc_shared_library( @@ -30,6 +29,7 @@ copy_to_directory( "@cuda_nvrtc", "@cuda_nvtx", "@cudnn", + "@libnvshmem", "@libcublas", "@libcufft", "@libcusolver", diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index 1d11da3..e71f1a9 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -153,8 +153,8 @@ def _rocm_impl(mctx): http_archive( name = "libpjrt_rocm", build_file = "libpjrt_rocm.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-rocm_linux-amd64.tar.gz", - sha256 = "709982b959595750545a01d125adf4893c42f05c60ec290425276bba8aa49f64", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-rocm_linux-amd64.tar.gz", + sha256 = "945c43c68325c0e91cd41eaa594a9f9f6e78da7cc06892d83bf345b69f7bd714", ) return mctx.extension_metadata( diff --git a/runtimes/tpu/tpu.bzl b/runtimes/tpu/tpu.bzl index d838415..e637d8d 100644 --- a/runtimes/tpu/tpu.bzl +++ b/runtimes/tpu/tpu.bzl @@ -4,9 +4,9 @@ def _tpu_impl(mctx): # https://storage.googleapis.com/jax-releases/libtpu_releases.html http_archive( name = "libpjrt_tpu", - url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250102+nightly-py3-none-linux_x86_64.whl", + url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250807+nightly-py3-none-manylinux_2_31_x86_64.whl", type = "zip", - sha256 = "df8339b4f852bd19ad4ed380facc08f28c04e214e9dabb88863e70907b08817e", + sha256 = "41c19fa5ae4a32fbd05f0260527ba2d93afb6cf128e6c4de7773e9011c7b3df5", build_file = "libpjrt_tpu.BUILD.bazel", ) return mctx.extension_metadata( diff --git a/third_party/xla/patches/0001-bazel-migration-to-bazel-8.1.1.patch b/third_party/xla/patches/0001-bazel-migration-to-bazel-8.1.1.patch deleted file mode 100644 index 8924cf4..0000000 --- a/third_party/xla/patches/0001-bazel-migration-to-bazel-8.1.1.patch +++ /dev/null @@ -1,41 +0,0 @@ -From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001 -From: Jean-Baptiste Dalido -Date: Mon, 6 Jan 2025 13:33:13 +0100 -Subject: [PATCH] bazel: migration to bazel 8.0.1 - ---- - .bazelversion | 2 +- - third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++-- - 2 files changed, 3 insertions(+), 3 deletions(-) - -diff --git a/.bazelversion b/.bazelversion -index f22d756da3..fa5fce04b3 100644 ---- a/.bazelversion -+++ b/.bazelversion -@@ -1 +1 @@ --7.4.1 -+8.1.1 -\ No newline at end of file -diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl -index d62531152d..71d80a5a99 100644 ---- a/third_party/gpus/cuda_configure.bzl -+++ b/third_party/gpus/cuda_configure.bzl -@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. - load( - "@bazel_tools//tools/cpp:lib_cc_configure.bzl", - "escape_string", -- "get_env_var", - ) - load( - "@bazel_tools//tools/cpp:windows_cc_configure.bzl", -- "find_msvc_tool", - "find_vc_path", - "setup_vc_env_vars", - ) -+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool") -+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var") - load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") - load( - "//third_party/remote_config:common.bzl", --- -2.39.3 (Apple Git-146) diff --git a/third_party/xla/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch b/third_party/xla/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch deleted file mode 100644 index 938ef40..0000000 --- a/third_party/xla/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch +++ /dev/null @@ -1,135 +0,0 @@ -From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001 -From: Hugo Mano -Date: Tue, 27 May 2025 11:48:17 +0200 -Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt - -PR: https://github.com/openxla/xla/pull/13420 ---- - xla/pjrt/c/BUILD | 5 +++++ - xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++ - xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++- - 3 files changed, 57 insertions(+), 1 deletion(-) - -diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD -index 79f18fa0bc..0f33dd8a6e 100644 ---- a/xla/pjrt/c/BUILD -+++ b/xla/pjrt/c/BUILD -@@ -69,8 +69,13 @@ cc_library( - ":pjrt_c_api_wrapper_impl", - "//xla/ffi:execution_context", - "//xla/ffi:type_id_registry", -+ "//xla/ffi:ffi_api", -+ "//xla/ffi/api:c_api", -+ "//xla/ffi/api:ffi", -+ "//xla/service:custom_call_target_registry", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", -+ "@com_google_absl//absl/strings:str_format", - ], - ) - -diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h -index 995a2c7e50..b8f10bc2f7 100644 ---- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h -+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h -@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data); - // Adds a user data to the execute context. - typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args); - -+typedef enum PJRT_FFI_Handler_TraitsBits { -+ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0, -+} PJRT_FFI_Handler_TraitsBits; -+ -+struct PJRT_FFI_Register_Handler_Args { -+ size_t struct_size; -+ const char* target_name; -+ size_t target_name_size; -+ int api_version; // 0 for an untyped call, 1 -- for typed -+ void* handler; -+ const char* platform_name; -+ size_t platform_name_size; -+ PJRT_FFI_Handler_TraitsBits traits; -+}; -+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits); -+ -+// Registers an FFI call handler for a specific platform. -+typedef PJRT_Error* PJRT_FFI_Register_Handler( -+ PJRT_FFI_Register_Handler_Args* args); -+ - typedef struct PJRT_FFI_Extension { - PJRT_Extension_Base base; - PJRT_FFI_TypeID_Register* type_id_register; - PJRT_FFI_UserData_Add* user_data_add; -+ PJRT_FFI_Register_Handler* register_handler; - } PJRT_FFI; - PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add); - -diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc -index 5fa88eab33..763270331b 100644 ---- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc -+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc -@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - --#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h" -+#include - - #include "absl/status/status.h" -+#include "absl/strings/str_format.h" - #include "absl/strings/string_view.h" -+#include "xla/ffi/api/c_api.h" - #include "xla/ffi/execution_context.h" - #include "xla/ffi/type_id_registry.h" -+#include "xla/ffi/ffi_api.h" - #include "xla/pjrt/c/pjrt_c_api.h" - #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" - #include "xla/pjrt/c/pjrt_c_api_helpers.h" - #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" -+#include "xla/service/custom_call_target_registry.h" - - namespace pjrt { - -@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { - return nullptr; - } - -+static PJRT_Error* PJRT_FFI_Register_Handler( -+ PJRT_FFI_Register_Handler_Args* args) { -+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( -+ "PJRT_FFI_Register_Handler_Args", -+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size)); -+ std::string target_name(args->target_name, args->target_name_size); -+ std::string platform_name(args->platform_name, args->platform_name_size); -+ switch (args->api_version) { -+ case 0: -+ xla::CustomCallTargetRegistry::Global()->Register( -+ target_name, args->handler, platform_name); -+ return nullptr; -+ case 1: -+ xla::ffi::Ffi::RegisterStaticHandler( -+ xla::ffi::GetXlaFfiApi(), target_name, platform_name, -+ reinterpret_cast(args->handler)); -+ return nullptr; -+ default: -+ return new PJRT_Error{absl::UnimplementedError( -+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. " -+ "Supported versions are 0 and 1.", -+ args->api_version))}; -+ } -+} -+ - PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { - return { - PJRT_Extension_Base{ -@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { - }, - /*type_id_register=*/PJRT_FFI_TypeID_Register, - /*user_data_add=*/PJRT_FFI_UserData_Add, -+ /*register_handler=*/PJRT_FFI_Register_Handler, - }; - } - --- -2.39.5 (Apple Git-154) - diff --git a/third_party/xla/patches/0003-Remove-unconventional-C-code-in-headers.patch b/third_party/xla/patches/0003-Remove-unconventional-C-code-in-headers.patch deleted file mode 100644 index 6df5b1b..0000000 --- a/third_party/xla/patches/0003-Remove-unconventional-C-code-in-headers.patch +++ /dev/null @@ -1,124 +0,0 @@ -From 6078da86a46b6f0d983dccb9ae4f36fc90640247 Mon Sep 17 00:00:00 2001 -From: Hugo Mano -Date: Fri, 11 Jul 2025 14:05:16 +0200 -Subject: [PATCH] zml patch - ---- - third_party/stablehlo/workspace.bzl | 1 + - third_party/stablehlo/zml.patch | 93 +++++++++++++++++++++++++++++ - 2 files changed, 94 insertions(+) - create mode 100644 third_party/stablehlo/zml.patch - -diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl -index d9d5063744..44980948d0 100644 ---- a/third_party/stablehlo/workspace.bzl -+++ b/third_party/stablehlo/workspace.bzl -@@ -15,5 +15,6 @@ def repo(): - urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), - patch_file = [ - "//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove. -+ "//third_party/stablehlo:zml.patch", # Autogenerated, don't remove. - ], - ) -diff --git a/third_party/stablehlo/zml.patch b/third_party/stablehlo/zml.patch -new file mode 100644 -index 0000000000..2a09384582 ---- /dev/null -+++ b/third_party/stablehlo/zml.patch -@@ -0,0 +1,93 @@ -+From e38ab68376dd8a17ebf4469d2c8350f521310182 Mon Sep 17 00:00:00 2001 -+From: Hugo Mano -+Date: Fri, 11 Jul 2025 12:08:35 +0200 -+Subject: [PATCH] zml patch -+ -+--- -+ stablehlo/dialect/Serialization.cpp | 5 ++--- -+ stablehlo/dialect/Serialization.h | 3 +-- -+ stablehlo/integrations/c/StablehloDialectApi.cpp | 3 +-- -+ stablehlo/integrations/c/StablehloDialectApi.h | 2 +- -+ stablehlo/tools/StablehloTranslateMain.cpp | 2 +- -+ 5 files changed, 6 insertions(+), 9 deletions(-) -+ -+diff --git a/stablehlo/dialect/Serialization.cpp b/stablehlo/dialect/Serialization.cpp -+index cb89d673..4370d588 100644 -+--- a/stablehlo/dialect/Serialization.cpp -++++ b/stablehlo/dialect/Serialization.cpp -+@@ -39,8 +39,7 @@ namespace stablehlo { -+ -+ LogicalResult serializePortableArtifact(ModuleOp module, -+ StringRef targetVersion, -+- raw_ostream& os, -+- bool allowOtherDialects) { -++ raw_ostream& os) { -+ MLIRContext* context = module.getContext(); -+ -+ // Convert StableHLO --> VHLO. -+@@ -49,7 +48,7 @@ LogicalResult serializePortableArtifact(ModuleOp module, -+ { -+ PassManager pm(context); -+ StablehloLegalizeToVhloPassOptions options; -+- options.allowOtherDialects = allowOtherDialects; -++ options.allowOtherDialects = false; -+ pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options)); -+ if (!succeeded(pm.run(module))) { -+ return failure(); -+diff --git a/stablehlo/dialect/Serialization.h b/stablehlo/dialect/Serialization.h -+index 811ca97b..abe95e63 100644 -+--- a/stablehlo/dialect/Serialization.h -++++ b/stablehlo/dialect/Serialization.h -+@@ -34,8 +34,7 @@ namespace stablehlo { -+ // unsupported dialects. -+ LogicalResult serializePortableArtifact(ModuleOp module, -+ StringRef targetVersion, -+- raw_ostream& os, -+- bool allowOtherDialects = false); -++ raw_ostream& os); -+ -+ // Read StableHLO portable artifact -+ // -+diff --git a/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/integrations/c/StablehloDialectApi.cpp -+index 343f8d0b..8f52e4d5 100644 -+--- a/stablehlo/integrations/c/StablehloDialectApi.cpp -++++ b/stablehlo/integrations/c/StablehloDialectApi.cpp -+@@ -81,8 +81,7 @@ MlirLogicalResult stablehloSerializePortableArtifactFromModule( -+ MlirStringCallback callback, void *userData, bool allowOtherDialects) { -+ mlir::detail::CallbackOstream stream(callback, userData); -+ if (failed(mlir::stablehlo::serializePortableArtifact( -+- unwrap(moduleStr), unwrap(targetVersion), stream, -+- allowOtherDialects))) -++ unwrap(moduleStr), unwrap(targetVersion), stream))) -+ return mlirLogicalResultFailure(); -+ return mlirLogicalResultSuccess(); -+ } -+diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h -+index 385156bf..24d11c1d 100644 -+--- a/stablehlo/integrations/c/StablehloDialectApi.h -++++ b/stablehlo/integrations/c/StablehloDialectApi.h -+@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr, -+ MlirStringRef targetVersion, -+ MlirStringCallback callback, -+ void* userData, -+- bool allowOtherDialects = false); -++ bool allowOtherDialects); -+ -+ // Read a StableHLO program from a portable artifact, returning the module as -+ // MLIR bytecode. Note, this bytecode returned is not a portable artifact, -+diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp -+index fdf0d6a9..8d5c8752 100644 -+--- a/stablehlo/tools/StablehloTranslateMain.cpp -++++ b/stablehlo/tools/StablehloTranslateMain.cpp -+@@ -323,7 +323,7 @@ TranslateFromMLIRRegistration serializeRegistration( -+ } -+ -+ return stablehlo::serializePortableArtifact( -+- module, targetVersion, os, allowOtherDialectsOption.getValue()); -++ module, targetVersion, os); -+ }, -+ [](DialectRegistry ®istry) { -+ mlir::registerAllDialects(registry); -+-- -+2.39.5 (Apple Git-154) -+ --- -2.39.5 (Apple Git-154) - diff --git a/third_party/xla/repo.bzl b/third_party/xla/repo.bzl index 8c005dc..e9a8b40 100644 --- a/third_party/xla/repo.bzl +++ b/third_party/xla/repo.bzl @@ -4,11 +4,9 @@ def repo(): git_repository( name = "xla", remote = "https://github.com/openxla/xla.git", - commit = "ef07e787ea1303fa2f8d8a175d24d434bfb84107", + commit = "b3fbfeeb076f2b536897180f4a274680ed9d52eb", patch_args = ["-p1"], patches = [ - "//third_party/xla:patches/0001-bazel-migration-to-bazel-8.1.1.patch", - "//third_party/xla:patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch", - "//third_party/xla:patches/0003-Remove-unconventional-C-code-in-headers.patch", + # patches live in the patches directory ], ) diff --git a/third_party/xla/xla.bzl b/third_party/xla/xla.bzl index b0aff38..3780bc5 100644 --- a/third_party/xla/xla.bzl +++ b/third_party/xla/xla.bzl @@ -65,12 +65,10 @@ def _xla_impl(mctx): tf_http_archive( name = "com_github_grpc_grpc", - sha256 = "afbc5d78d6ba6d509cc6e264de0d49dcd7304db435cbf2d630385bacf49e066c", - strip_prefix = "grpc-1.68.2", - patch_file = [ - "//third_party/grpc:grpc.patch", - ], - urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.68.2.tar.gz"), + sha256 = "dd6a2fa311ba8441bbefd2764c55b99136ff10f7ea42954be96006a2723d33fc", + strip_prefix = "grpc-1.74.0", + patch_file = ["//third_party/grpc:grpc.patch"], + urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"), ) tf_vendored(name = "tsl", relpath = "third_party/tsl") diff --git a/zml/module.zig b/zml/module.zig index e57f942..33d22f7 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -905,7 +905,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena); try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_latency_hiding_scheduler", true, upb_arena); try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_llvm_module_compilation_parallelism", true, upb_arena); - try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_libnvptxcompiler", true, upb_arena); }, .rocm => { // Disable Triton GEMM on ROCM. For some reason it's much, much slower when