From 9488672d4b2800c196fa03e90128ddfed51199a8 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 4 Mar 2025 17:12:34 +0000 Subject: [PATCH] workspace: bump xla to version 20250710.0-22ea002 Also: - Bump XLA deps : `com_github_grpc_grpc` and `com_google_protobuf` - Inject `rules_ml_toolchain` - Fix `zig_proto_library` rule --- MODULE.bazel | 2 +- bazel/zig_proto_library.bzl | 8 +- runtimes/cpu/cpu.bzl | 12 +- runtimes/cuda/cuda.bzl | 4 +- runtimes/rocm/rocm.bzl | 4 +- .../xla/20250710.0-22ea002/MODULE.bazel | 58 ++++++++ .../20250710.0-22ea002/overlay/MODULE.bazel | 58 ++++++++ .../xla/20250710.0-22ea002/overlay/llvm.bzl | 30 ++++ .../overlay/toolchains_private.bzl | 21 +++ .../overlay/workspace_private.bzl | 73 ++++++++++ .../xla/20250710.0-22ea002/overlay/xla.bzl | 17 +++ .../0001-bazel-migration-to-bazel-8.1.1.patch | 41 ++++++ ...ler-registration-API-to-the-FFI-PjRt.patch | 135 ++++++++++++++++++ ...ove-unconventional-C-code-in-headers.patch | 124 ++++++++++++++++ .../xla/20250710.0-22ea002/source.json | 18 +++ third_party/modules/xla/metadata.json | 3 +- 16 files changed, 592 insertions(+), 16 deletions(-) create mode 100644 third_party/modules/xla/20250710.0-22ea002/MODULE.bazel create mode 100644 third_party/modules/xla/20250710.0-22ea002/overlay/MODULE.bazel create mode 100644 third_party/modules/xla/20250710.0-22ea002/overlay/llvm.bzl create mode 100644 third_party/modules/xla/20250710.0-22ea002/overlay/toolchains_private.bzl create mode 100644 third_party/modules/xla/20250710.0-22ea002/overlay/workspace_private.bzl create mode 100644 third_party/modules/xla/20250710.0-22ea002/overlay/xla.bzl create mode 100644 third_party/modules/xla/20250710.0-22ea002/patches/0001-bazel-migration-to-bazel-8.1.1.patch create mode 100644 third_party/modules/xla/20250710.0-22ea002/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch create mode 100644 third_party/modules/xla/20250710.0-22ea002/patches/0003-Remove-unconventional-C-code-in-headers.patch create mode 100644 third_party/modules/xla/20250710.0-22ea002/source.json diff --git a/MODULE.bazel b/MODULE.bazel index 269c12c..6df826e 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -22,7 +22,7 @@ bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a") bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.3") bazel_dep(name = "toolchains_protoc", version = "0.4.1") bazel_dep(name = "with_cfg.bzl", version = "0.9.1") -bazel_dep(name = "xla", version = "20250612.0-6e48cbb") +bazel_dep(name = "xla", version = "20250710.0-22ea002") bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e") bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf") diff --git a/bazel/zig_proto_library.bzl b/bazel/zig_proto_library.bzl index 54df675..02a903d 100644 --- a/bazel/zig_proto_library.bzl +++ b/bazel/zig_proto_library.bzl @@ -1,5 +1,6 @@ """Starlark implementation of zig_proto_library""" +load("@protobuf//bazel/common:proto_info.bzl", "ProtoInfo") load("@rules_proto//proto:defs.bzl", "proto_common") load( "@rules_zig//zig/private/providers:zig_module_info.bzl", @@ -70,10 +71,9 @@ def get_import_name(target, proto_src): name = str(target.label) # special handling of builtin types - if "com_google_protobuf//:" in name: - name = "google_protobuf_" + proto_src.basename - else: - name = name.rsplit("//")[-1] + if "com_google_protobuf" in name: + return "google_protobuf_" + proto_src.basename.replace(".", "_") + name = name.rsplit("//")[-1] return name.replace(".", "_").replace(":", "_").replace("/", "_") diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index 094483e..4dc62b7 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -25,22 +25,22 @@ def _cpu_pjrt_plugin_impl(mctx): http_archive( name = "libpjrt_cpu_linux_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX, - sha256 = "4106ca11ab41bc9ec000d536ae084442139b5639ca329bfb62c7e0742acdc47a", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_linux-amd64.tar.gz", + sha256 = "3369fa7a1a1bb5998b818e1fb5f2c28966a59f6096eab500ef2d8419548a1c91", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cpu_linux-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "7be4d98f0737601fba7b29563917054aac3d09365139e6d3f5f96023a8c71c87", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_darwin-amd64.tar.gz", + sha256 = "9947382613d30eb757dfb1bfcad0536ec9dad1e11b1189d1172abbce434b69bb", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cpu_darwin-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_arm64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "442cccd98d7adf4afe0f818ebba265baca6b68dea95b10ef2b4d4229b81d5412", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_darwin-arm64.tar.gz", + sha256 = "fe3818455b034c9ffbd65dec559c04c2211a200a9b4d7feec8a00d6a3ffd0acd", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cpu_darwin-arm64.tar.gz", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index d45963c..f8d96b5 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -214,8 +214,8 @@ def _cuda_impl(mctx): http_archive( name = "libpjrt_cuda", build_file = "libpjrt_cuda.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cuda_linux-amd64.tar.gz", - sha256 = "eddf4db325aaeb1692e9eff1b5021dbeda27c08e527cae87295a61d94e654395", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cuda_linux-amd64.tar.gz", + sha256 = "08fa022a6067ddfb5c951bdf11ddc398e63de21fdcacc9ffd07f70b1463482c2", ) return mctx.extension_metadata( diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index 5f118fd..4fd8b55 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -127,8 +127,8 @@ def _rocm_impl(mctx): http_archive( name = "libpjrt_rocm", build_file = "libpjrt_rocm.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-rocm_linux-amd64.tar.gz", - sha256 = "ce5badf1ba5d1073a7de1e4d1d2a97fd1b66876d1fa255f913ffd410f50e6bc5", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-rocm_linux-amd64.tar.gz", + sha256 = "a6d8ef38ae4deda244856549271a1b1a6f46499e9efb64fb71a12fd6ae792d3b", ) return mctx.extension_metadata( diff --git a/third_party/modules/xla/20250710.0-22ea002/MODULE.bazel b/third_party/modules/xla/20250710.0-22ea002/MODULE.bazel new file mode 100644 index 0000000..09bc230 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/MODULE.bazel @@ -0,0 +1,58 @@ +module( + name = "xla", + version = "20250710.0-22ea002", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.17") +bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.39.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") +bazel_dep(name = "rules_shell", version = "0.4.1") +bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features") + +toolchains_private = use_extension("//:toolchains_private.bzl", "toolchains_private") +use_repo( + toolchains_private, + "rules_ml_toolchain", +) + +workspace_private = use_extension("//:workspace_private.bzl", "workspace_private") +use_repo( + workspace_private, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", + "python_version_repo", + "tsl", +) + +workspace_public = use_extension("//:xla.bzl", "xla") +use_repo( + workspace_public, + "llvm-raw", + "stablehlo", + "triton", +) + +llvm = use_extension("//:llvm.bzl", "llvm") +llvm.configure( + targets = [ + "AArch64", + "AMDGPU", + "NVPTX", + "X86", + ], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/xla/20250710.0-22ea002/overlay/MODULE.bazel b/third_party/modules/xla/20250710.0-22ea002/overlay/MODULE.bazel new file mode 100644 index 0000000..09bc230 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/overlay/MODULE.bazel @@ -0,0 +1,58 @@ +module( + name = "xla", + version = "20250710.0-22ea002", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.17") +bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.39.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") +bazel_dep(name = "rules_shell", version = "0.4.1") +bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features") + +toolchains_private = use_extension("//:toolchains_private.bzl", "toolchains_private") +use_repo( + toolchains_private, + "rules_ml_toolchain", +) + +workspace_private = use_extension("//:workspace_private.bzl", "workspace_private") +use_repo( + workspace_private, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", + "python_version_repo", + "tsl", +) + +workspace_public = use_extension("//:xla.bzl", "xla") +use_repo( + workspace_public, + "llvm-raw", + "stablehlo", + "triton", +) + +llvm = use_extension("//:llvm.bzl", "llvm") +llvm.configure( + targets = [ + "AArch64", + "AMDGPU", + "NVPTX", + "X86", + ], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/xla/20250710.0-22ea002/overlay/llvm.bzl b/third_party/modules/xla/20250710.0-22ea002/overlay/llvm.bzl new file mode 100644 index 0000000..b4a2fe4 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/overlay/llvm.bzl @@ -0,0 +1,30 @@ +load("@llvm-raw//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure") + +def _llvm_impl(mctx): + _targets = {} + for mod in mctx.modules: + for conf in mod.tags.configure: + for target in conf.targets: + _targets[target] = True + _llvm_configure( + name = "llvm-project", + targets = _targets.keys(), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +llvm = module_extension( + implementation = _llvm_impl, + tag_classes = { + "configure": tag_class( + attrs = { + "targets": attr.string_list( + default = [], + ), + }, + ), + }, +) diff --git a/third_party/modules/xla/20250710.0-22ea002/overlay/toolchains_private.bzl b/third_party/modules/xla/20250710.0-22ea002/overlay/toolchains_private.bzl new file mode 100644 index 0000000..b8c7907 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/overlay/toolchains_private.bzl @@ -0,0 +1,21 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def _toolchains_private_impl(mctx): + http_archive( + name = "rules_ml_toolchain", + sha256 = "fb78d09234528aef2be856820b69b76486829f65e4eb3c7ffaa5803b667fa441", + strip_prefix = "rules_ml_toolchain-f4ad89fa906be2c1374785a79335c8a7dcd49df7", + urls = [ + "https://github.com/zml/rules_ml_toolchain/archive/f4ad89fa906be2c1374785a79335c8a7dcd49df7.tar.gz", + ], + ) + + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +toolchains_private = module_extension( + implementation = _toolchains_private_impl, +) diff --git a/third_party/modules/xla/20250710.0-22ea002/overlay/workspace_private.bzl b/third_party/modules/xla/20250710.0-22ea002/overlay/workspace_private.bzl new file mode 100644 index 0000000..9dfe0b1 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/overlay/workspace_private.bzl @@ -0,0 +1,73 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls", "tf_vendored") +load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") +load("//third_party/py:python_repo.bzl", "python_repository") +load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo") +load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") +load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure") + +def _workspace_private_impl(mctx): + http_archive( + name = "rules_ml_toolchain", + sha256 = "fb78d09234528aef2be856820b69b76486829f65e4eb3c7ffaa5803b667fa441", + strip_prefix = "rules_ml_toolchain-f4ad89fa906be2c1374785a79335c8a7dcd49df7", + urls = [ + "https://github.com/zml/rules_ml_toolchain/archive/f4ad89fa906be2c1374785a79335c8a7dcd49df7.tar.gz", + ], + ) + + # Use cuda_configure from XLA to make it work with bzlmod. + # A pure bzlmod solution for rules_ml_toolchain is impossible because of the legacy design. + # It relies on a "generate-then-load" pattern that creates a deadlock in Bazel's architecture: + # - Generate: First, it runs a rule to generate a .bzl file containing configuration data. + # - Load: Then, it requires a load() statement to load that same file to continue the setup. + # This fails in bzlmod because Bazel's Loading Phase (when load() statements are processed) happens before + # the Analysis Phase (when repository rules are run). + # This creates a fundamental chicken-and-egg problem: the build tries to load a file that has not been generated yet. + # Without using the official WORKSPACE.bzlmod escape hatch, + # this incompatibility cannot be resolved without modifying the upstream rules. + + cuda_configure(name = "local_config_cuda") + remote_execution_configure(name = "local_config_remote_execution") + rocm_configure(name = "local_config_rocm") + tensorrt_configure(name = "local_config_tensorrt") + tf_vendored(name = "tsl", relpath = "third_party/tsl") + pybind11_bazel() + tf_http_archive( + name = "com_github_grpc_grpc", + sha256 = "afbc5d78d6ba6d509cc6e264de0d49dcd7304db435cbf2d630385bacf49e066c", + strip_prefix = "grpc-1.68.2", + patch_file = [ + "//third_party/grpc:grpc.patch", + ], + urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.68.2.tar.gz"), + ) + tf_http_archive( + name = "com_google_protobuf", + patch_file = ["//third_party/protobuf:protobuf.patch"], + sha256 = "f645e6e42745ce922ca5388b1883ca583bafe4366cc74cf35c3c9299005136e2", + strip_prefix = "protobuf-5.28.3", + urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/refs/tags/v5.28.3.zip"), + ) + + python_repository( + name = "python_version_repo", + requirements_versions = ["3.11"], + requirements_locks = ["//:requirements_lock_3_11.txt"], + local_wheel_workspaces = [], + local_wheel_dist_folder = None, + default_python_version = None, + local_wheel_inclusion_list = ["*"], + local_wheel_exclusion_list = [], + ) + + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +workspace_private = module_extension( + implementation = _workspace_private_impl, +) diff --git a/third_party/modules/xla/20250710.0-22ea002/overlay/xla.bzl b/third_party/modules/xla/20250710.0-22ea002/overlay/xla.bzl new file mode 100644 index 0000000..f14bf2a --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/overlay/xla.bzl @@ -0,0 +1,17 @@ +load("//third_party/llvm:workspace.bzl", llvm = "repo") +load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") +load("//third_party/triton:workspace.bzl", triton = "repo") + +def _xla_impl(mctx): + triton() + llvm("llvm-raw") + stablehlo() + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +xla = module_extension( + implementation = _xla_impl, +) diff --git a/third_party/modules/xla/20250710.0-22ea002/patches/0001-bazel-migration-to-bazel-8.1.1.patch b/third_party/modules/xla/20250710.0-22ea002/patches/0001-bazel-migration-to-bazel-8.1.1.patch new file mode 100644 index 0000000..8924cf4 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/patches/0001-bazel-migration-to-bazel-8.1.1.patch @@ -0,0 +1,41 @@ +From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001 +From: Jean-Baptiste Dalido +Date: Mon, 6 Jan 2025 13:33:13 +0100 +Subject: [PATCH] bazel: migration to bazel 8.0.1 + +--- + .bazelversion | 2 +- + third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++-- + 2 files changed, 3 insertions(+), 3 deletions(-) + +diff --git a/.bazelversion b/.bazelversion +index f22d756da3..fa5fce04b3 100644 +--- a/.bazelversion ++++ b/.bazelversion +@@ -1 +1 @@ +-7.4.1 ++8.1.1 +\ No newline at end of file +diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl +index d62531152d..71d80a5a99 100644 +--- a/third_party/gpus/cuda_configure.bzl ++++ b/third_party/gpus/cuda_configure.bzl +@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", +- "get_env_var", + ) + load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", +- "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", + ) ++load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool") ++load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var") + load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") + load( + "//third_party/remote_config:common.bzl", +-- +2.39.3 (Apple Git-146) diff --git a/third_party/modules/xla/20250710.0-22ea002/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch b/third_party/modules/xla/20250710.0-22ea002/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch new file mode 100644 index 0000000..938ef40 --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch @@ -0,0 +1,135 @@ +From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001 +From: Hugo Mano +Date: Tue, 27 May 2025 11:48:17 +0200 +Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt + +PR: https://github.com/openxla/xla/pull/13420 +--- + xla/pjrt/c/BUILD | 5 +++++ + xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++ + xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++- + 3 files changed, 57 insertions(+), 1 deletion(-) + +diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD +index 79f18fa0bc..0f33dd8a6e 100644 +--- a/xla/pjrt/c/BUILD ++++ b/xla/pjrt/c/BUILD +@@ -69,8 +69,13 @@ cc_library( + ":pjrt_c_api_wrapper_impl", + "//xla/ffi:execution_context", + "//xla/ffi:type_id_registry", ++ "//xla/ffi:ffi_api", ++ "//xla/ffi/api:c_api", ++ "//xla/ffi/api:ffi", ++ "//xla/service:custom_call_target_registry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", ++ "@com_google_absl//absl/strings:str_format", + ], + ) + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +index 995a2c7e50..b8f10bc2f7 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h ++++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data); + // Adds a user data to the execute context. + typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args); + ++typedef enum PJRT_FFI_Handler_TraitsBits { ++ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0, ++} PJRT_FFI_Handler_TraitsBits; ++ ++struct PJRT_FFI_Register_Handler_Args { ++ size_t struct_size; ++ const char* target_name; ++ size_t target_name_size; ++ int api_version; // 0 for an untyped call, 1 -- for typed ++ void* handler; ++ const char* platform_name; ++ size_t platform_name_size; ++ PJRT_FFI_Handler_TraitsBits traits; ++}; ++PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits); ++ ++// Registers an FFI call handler for a specific platform. ++typedef PJRT_Error* PJRT_FFI_Register_Handler( ++ PJRT_FFI_Register_Handler_Args* args); ++ + typedef struct PJRT_FFI_Extension { + PJRT_Extension_Base base; + PJRT_FFI_TypeID_Register* type_id_register; + PJRT_FFI_UserData_Add* user_data_add; ++ PJRT_FFI_Register_Handler* register_handler; + } PJRT_FFI; + PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add); + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +index 5fa88eab33..763270331b 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc ++++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h" ++#include + + #include "absl/status/status.h" ++#include "absl/strings/str_format.h" + #include "absl/strings/string_view.h" ++#include "xla/ffi/api/c_api.h" + #include "xla/ffi/execution_context.h" + #include "xla/ffi/type_id_registry.h" ++#include "xla/ffi/ffi_api.h" + #include "xla/pjrt/c/pjrt_c_api.h" + #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" + #include "xla/pjrt/c/pjrt_c_api_helpers.h" + #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" ++#include "xla/service/custom_call_target_registry.h" + + namespace pjrt { + +@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { + return nullptr; + } + ++static PJRT_Error* PJRT_FFI_Register_Handler( ++ PJRT_FFI_Register_Handler_Args* args) { ++ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( ++ "PJRT_FFI_Register_Handler_Args", ++ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size)); ++ std::string target_name(args->target_name, args->target_name_size); ++ std::string platform_name(args->platform_name, args->platform_name_size); ++ switch (args->api_version) { ++ case 0: ++ xla::CustomCallTargetRegistry::Global()->Register( ++ target_name, args->handler, platform_name); ++ return nullptr; ++ case 1: ++ xla::ffi::Ffi::RegisterStaticHandler( ++ xla::ffi::GetXlaFfiApi(), target_name, platform_name, ++ reinterpret_cast(args->handler)); ++ return nullptr; ++ default: ++ return new PJRT_Error{absl::UnimplementedError( ++ absl::StrFormat("API version %d not supported for PJRT GPU plugin. " ++ "Supported versions are 0 and 1.", ++ args->api_version))}; ++ } ++} ++ + PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { + return { + PJRT_Extension_Base{ +@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { + }, + /*type_id_register=*/PJRT_FFI_TypeID_Register, + /*user_data_add=*/PJRT_FFI_UserData_Add, ++ /*register_handler=*/PJRT_FFI_Register_Handler, + }; + } + +-- +2.39.5 (Apple Git-154) + diff --git a/third_party/modules/xla/20250710.0-22ea002/patches/0003-Remove-unconventional-C-code-in-headers.patch b/third_party/modules/xla/20250710.0-22ea002/patches/0003-Remove-unconventional-C-code-in-headers.patch new file mode 100644 index 0000000..6df5b1b --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/patches/0003-Remove-unconventional-C-code-in-headers.patch @@ -0,0 +1,124 @@ +From 6078da86a46b6f0d983dccb9ae4f36fc90640247 Mon Sep 17 00:00:00 2001 +From: Hugo Mano +Date: Fri, 11 Jul 2025 14:05:16 +0200 +Subject: [PATCH] zml patch + +--- + third_party/stablehlo/workspace.bzl | 1 + + third_party/stablehlo/zml.patch | 93 +++++++++++++++++++++++++++++ + 2 files changed, 94 insertions(+) + create mode 100644 third_party/stablehlo/zml.patch + +diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl +index d9d5063744..44980948d0 100644 +--- a/third_party/stablehlo/workspace.bzl ++++ b/third_party/stablehlo/workspace.bzl +@@ -15,5 +15,6 @@ def repo(): + urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), + patch_file = [ + "//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove. ++ "//third_party/stablehlo:zml.patch", # Autogenerated, don't remove. + ], + ) +diff --git a/third_party/stablehlo/zml.patch b/third_party/stablehlo/zml.patch +new file mode 100644 +index 0000000000..2a09384582 +--- /dev/null ++++ b/third_party/stablehlo/zml.patch +@@ -0,0 +1,93 @@ ++From e38ab68376dd8a17ebf4469d2c8350f521310182 Mon Sep 17 00:00:00 2001 ++From: Hugo Mano ++Date: Fri, 11 Jul 2025 12:08:35 +0200 ++Subject: [PATCH] zml patch ++ ++--- ++ stablehlo/dialect/Serialization.cpp | 5 ++--- ++ stablehlo/dialect/Serialization.h | 3 +-- ++ stablehlo/integrations/c/StablehloDialectApi.cpp | 3 +-- ++ stablehlo/integrations/c/StablehloDialectApi.h | 2 +- ++ stablehlo/tools/StablehloTranslateMain.cpp | 2 +- ++ 5 files changed, 6 insertions(+), 9 deletions(-) ++ ++diff --git a/stablehlo/dialect/Serialization.cpp b/stablehlo/dialect/Serialization.cpp ++index cb89d673..4370d588 100644 ++--- a/stablehlo/dialect/Serialization.cpp +++++ b/stablehlo/dialect/Serialization.cpp ++@@ -39,8 +39,7 @@ namespace stablehlo { ++ ++ LogicalResult serializePortableArtifact(ModuleOp module, ++ StringRef targetVersion, ++- raw_ostream& os, ++- bool allowOtherDialects) { +++ raw_ostream& os) { ++ MLIRContext* context = module.getContext(); ++ ++ // Convert StableHLO --> VHLO. ++@@ -49,7 +48,7 @@ LogicalResult serializePortableArtifact(ModuleOp module, ++ { ++ PassManager pm(context); ++ StablehloLegalizeToVhloPassOptions options; ++- options.allowOtherDialects = allowOtherDialects; +++ options.allowOtherDialects = false; ++ pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options)); ++ if (!succeeded(pm.run(module))) { ++ return failure(); ++diff --git a/stablehlo/dialect/Serialization.h b/stablehlo/dialect/Serialization.h ++index 811ca97b..abe95e63 100644 ++--- a/stablehlo/dialect/Serialization.h +++++ b/stablehlo/dialect/Serialization.h ++@@ -34,8 +34,7 @@ namespace stablehlo { ++ // unsupported dialects. ++ LogicalResult serializePortableArtifact(ModuleOp module, ++ StringRef targetVersion, ++- raw_ostream& os, ++- bool allowOtherDialects = false); +++ raw_ostream& os); ++ ++ // Read StableHLO portable artifact ++ // ++diff --git a/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/integrations/c/StablehloDialectApi.cpp ++index 343f8d0b..8f52e4d5 100644 ++--- a/stablehlo/integrations/c/StablehloDialectApi.cpp +++++ b/stablehlo/integrations/c/StablehloDialectApi.cpp ++@@ -81,8 +81,7 @@ MlirLogicalResult stablehloSerializePortableArtifactFromModule( ++ MlirStringCallback callback, void *userData, bool allowOtherDialects) { ++ mlir::detail::CallbackOstream stream(callback, userData); ++ if (failed(mlir::stablehlo::serializePortableArtifact( ++- unwrap(moduleStr), unwrap(targetVersion), stream, ++- allowOtherDialects))) +++ unwrap(moduleStr), unwrap(targetVersion), stream))) ++ return mlirLogicalResultFailure(); ++ return mlirLogicalResultSuccess(); ++ } ++diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h ++index 385156bf..24d11c1d 100644 ++--- a/stablehlo/integrations/c/StablehloDialectApi.h +++++ b/stablehlo/integrations/c/StablehloDialectApi.h ++@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr, ++ MlirStringRef targetVersion, ++ MlirStringCallback callback, ++ void* userData, ++- bool allowOtherDialects = false); +++ bool allowOtherDialects); ++ ++ // Read a StableHLO program from a portable artifact, returning the module as ++ // MLIR bytecode. Note, this bytecode returned is not a portable artifact, ++diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp ++index fdf0d6a9..8d5c8752 100644 ++--- a/stablehlo/tools/StablehloTranslateMain.cpp +++++ b/stablehlo/tools/StablehloTranslateMain.cpp ++@@ -323,7 +323,7 @@ TranslateFromMLIRRegistration serializeRegistration( ++ } ++ ++ return stablehlo::serializePortableArtifact( ++- module, targetVersion, os, allowOtherDialectsOption.getValue()); +++ module, targetVersion, os); ++ }, ++ [](DialectRegistry ®istry) { ++ mlir::registerAllDialects(registry); ++-- ++2.39.5 (Apple Git-154) ++ +-- +2.39.5 (Apple Git-154) + diff --git a/third_party/modules/xla/20250710.0-22ea002/source.json b/third_party/modules/xla/20250710.0-22ea002/source.json new file mode 100644 index 0000000..d01c18a --- /dev/null +++ b/third_party/modules/xla/20250710.0-22ea002/source.json @@ -0,0 +1,18 @@ +{ + "strip_prefix": "xla-ef07e787ea1303fa2f8d8a175d24d434bfb84107", + "url": "https://github.com/openxla/xla/archive/ef07e787ea1303fa2f8d8a175d24d434bfb84107.tar.gz", + "integrity": "sha256-OL0e3Y+dgu/DM5MV+LsWvtOI9rUGLvXbwIU/9bnHGXc=", + "overlay": { + "toolchains_private.bzl": "", + "llvm.bzl": "", + "MODULE.bazel": "", + "workspace_private.bzl": "", + "xla.bzl": "" + }, + "patch_strip": 1, + "patches": { + "0001-bazel-migration-to-bazel-8.1.1.patch": "", + "0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": "", + "0003-Remove-unconventional-C-code-in-headers.patch": "" + } +} diff --git a/third_party/modules/xla/metadata.json b/third_party/modules/xla/metadata.json index be35763..41eb64b 100644 --- a/third_party/modules/xla/metadata.json +++ b/third_party/modules/xla/metadata.json @@ -22,7 +22,8 @@ "20250317.1-71c67e2", "20250317.2-71c67e2", "20250527.0-cb67f2f", - "20250612.0-6e48cbb" + "20250612.0-6e48cbb", + "20250710.0-22ea002" ], "yanked_versions": {} }