From 1cafcc3c6082c2dc124328ea8b496551dd7cbb22 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 5 Feb 2025 17:35:27 +0000 Subject: [PATCH] Workspace: bump XLA to newer version. --- MODULE.bazel | 2 +- runtimes/cpu/cpu.bzl | 12 +- runtimes/cuda/cuda.bzl | 4 +- runtimes/rocm/rocm.bzl | 4 +- .../xla/20250612.0-6e48cbb/MODULE.bazel | 52 + .../20250612.0-6e48cbb/overlay/MODULE.bazel | 52 + .../xla/20250612.0-6e48cbb/overlay/llvm.bzl | 30 + .../overlay/workspace_private.bzl | 71 + .../xla/20250612.0-6e48cbb/overlay/xla.bzl | 17 + .../0001-bazel-migration-to-bazel-8.1.1.patch | 41 + ...ler-registration-API-to-the-FFI-PjRt.patch | 135 ++ ...nal-allowOtherDialects-field-to-stab.patch | 1855 +++++++++++++++++ .../xla/20250612.0-6e48cbb/source.json | 17 + third_party/modules/xla/metadata.json | 3 +- 14 files changed, 2283 insertions(+), 12 deletions(-) create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/MODULE.bazel create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/overlay/MODULE.bazel create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/overlay/llvm.bzl create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/overlay/workspace_private.bzl create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/overlay/xla.bzl create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/patches/0001-bazel-migration-to-bazel-8.1.1.patch create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/patches/0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch create mode 100644 third_party/modules/xla/20250612.0-6e48cbb/source.json diff --git a/MODULE.bazel b/MODULE.bazel index 6211391..269c12c 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -22,7 +22,7 @@ bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a") bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.3") bazel_dep(name = "toolchains_protoc", version = "0.4.1") bazel_dep(name = "with_cfg.bzl", version = "0.9.1") -bazel_dep(name = "xla", version = "20250527.0-cb67f2f") +bazel_dep(name = "xla", version = "20250612.0-6e48cbb") bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e") bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf") diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index 3f6b332..094483e 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -25,22 +25,22 @@ def _cpu_pjrt_plugin_impl(mctx): http_archive( name = "libpjrt_cpu_linux_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX, - sha256 = "ca92bccefa168881f98d01354971d6f598381cc4c5f07b161a0908d327610b66", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_linux-amd64.tar.gz", + sha256 = "4106ca11ab41bc9ec000d536ae084442139b5639ca329bfb62c7e0742acdc47a", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_linux-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "b6d05b5cd0382a7bd8943b8df98dc229853e402488127895e47786395afb73a7", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_darwin-amd64.tar.gz", + sha256 = "7be4d98f0737601fba7b29563917054aac3d09365139e6d3f5f96023a8c71c87", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_darwin-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_arm64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "e1ac13cf80b0975eec1dc0643a6ec08001d6e07a6a0d500a38e1c4477f49a78c", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_darwin-arm64.tar.gz", + sha256 = "442cccd98d7adf4afe0f818ebba265baca6b68dea95b10ef2b4d4229b81d5412", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_darwin-arm64.tar.gz", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index 09f8d75..a3123d9 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -214,8 +214,8 @@ def _cuda_impl(mctx): http_archive( name = "libpjrt_cuda", build_file = "libpjrt_cuda.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cuda_linux-amd64.tar.gz", - sha256 = "2ae18dacd9762e0ae89f223764b1793f8a4d7bd7238bfcd84d2342d7fb37a106", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cuda_linux-amd64.tar.gz", + sha256 = "eddf4db325aaeb1692e9eff1b5021dbeda27c08e527cae87295a61d94e654395", ) return mctx.extension_metadata( diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index 8dfef40..5f118fd 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -127,8 +127,8 @@ def _rocm_impl(mctx): http_archive( name = "libpjrt_rocm", build_file = "libpjrt_rocm.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-rocm_linux-amd64.tar.gz", - sha256 = "31223c61645e6a3966841be6ebbc8c56609835a792c75ad1e1442fd5afed759b", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-rocm_linux-amd64.tar.gz", + sha256 = "ce5badf1ba5d1073a7de1e4d1d2a97fd1b66876d1fa255f913ffd410f50e6bc5", ) return mctx.extension_metadata( diff --git a/third_party/modules/xla/20250612.0-6e48cbb/MODULE.bazel b/third_party/modules/xla/20250612.0-6e48cbb/MODULE.bazel new file mode 100644 index 0000000..46eb7d1 --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/MODULE.bazel @@ -0,0 +1,52 @@ +module( + name = "xla", + version = "20250612.0-6e48cbb", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.17") +bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.39.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") +bazel_dep(name = "rules_shell", version = "0.4.1") +bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features") + +workspace_private = use_extension("//:workspace_private.bzl", "workspace_private") +use_repo( + workspace_private, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", + "python_version_repo", + "tsl", +) + +workspace_public = use_extension("//:xla.bzl", "xla") +use_repo( + workspace_public, + "llvm-raw", + "stablehlo", + "triton", +) + +llvm = use_extension("//:llvm.bzl", "llvm") +llvm.configure( + targets = [ + "AArch64", + "AMDGPU", + "NVPTX", + "X86", + ], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/xla/20250612.0-6e48cbb/overlay/MODULE.bazel b/third_party/modules/xla/20250612.0-6e48cbb/overlay/MODULE.bazel new file mode 100644 index 0000000..74c9eaa --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/overlay/MODULE.bazel @@ -0,0 +1,52 @@ +module( + name = "xla", + version = "20250527.0-cb67f2f", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.17") +bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.39.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") +bazel_dep(name = "rules_shell", version = "0.4.1") +bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features") + +workspace_private = use_extension("//:workspace_private.bzl", "workspace_private") +use_repo( + workspace_private, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", + "python_version_repo", + "tsl", +) + +workspace_public = use_extension("//:xla.bzl", "xla") +use_repo( + workspace_public, + "llvm-raw", + "stablehlo", + "triton", +) + +llvm = use_extension("//:llvm.bzl", "llvm") +llvm.configure( + targets = [ + "AArch64", + "AMDGPU", + "NVPTX", + "X86", + ], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/xla/20250612.0-6e48cbb/overlay/llvm.bzl b/third_party/modules/xla/20250612.0-6e48cbb/overlay/llvm.bzl new file mode 100644 index 0000000..b4a2fe4 --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/overlay/llvm.bzl @@ -0,0 +1,30 @@ +load("@llvm-raw//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure") + +def _llvm_impl(mctx): + _targets = {} + for mod in mctx.modules: + for conf in mod.tags.configure: + for target in conf.targets: + _targets[target] = True + _llvm_configure( + name = "llvm-project", + targets = _targets.keys(), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +llvm = module_extension( + implementation = _llvm_impl, + tag_classes = { + "configure": tag_class( + attrs = { + "targets": attr.string_list( + default = [], + ), + }, + ), + }, +) diff --git a/third_party/modules/xla/20250612.0-6e48cbb/overlay/workspace_private.bzl b/third_party/modules/xla/20250612.0-6e48cbb/overlay/workspace_private.bzl new file mode 100644 index 0000000..b234d12 --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/overlay/workspace_private.bzl @@ -0,0 +1,71 @@ +load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") +load("//third_party/llvm:workspace.bzl", llvm = "repo") +load("//third_party/py:python_repo.bzl", "python_repository") +load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo") +load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") +load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") +load("//third_party/triton:workspace.bzl", triton = "repo") +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") +load("//third_party:repo.bzl", "tf_vendored") +load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure") + +def _workspace_private_impl(mctx): + cuda_configure(name = "local_config_cuda") + remote_execution_configure(name = "local_config_remote_execution") + rocm_configure(name = "local_config_rocm") + tensorrt_configure(name = "local_config_tensorrt") + tf_vendored(name = "tsl", relpath = "third_party/tsl") + pybind11_bazel() + tf_http_archive( + name = "com_github_grpc_grpc", + sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f", + strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd", + system_build_file = "//third_party/systemlibs:grpc.BUILD", + patch_file = [ + "//third_party/grpc:generate_cc_env_fix.patch", + "//third_party/grpc:register_go_toolchain.patch", + ], + system_link_files = { + "//third_party/systemlibs:BUILD.bazel": "bazel/BUILD.bazel", + "//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD", + "//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl", + "//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl", + "//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl", + "//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl", + "//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl", + }, + urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"), + ) + tf_http_archive( + name = "com_google_protobuf", + patch_file = ["//third_party/protobuf:protobuf.patch"], + sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1", + strip_prefix = "protobuf-3.21.9", + system_build_file = "//third_party/systemlibs:protobuf.BUILD", + system_link_files = { + "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", + "//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl", + }, + urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"), + ) + python_repository( + name = "python_version_repo", + requirements_versions = ["3.11"], + requirements_locks = ["//:requirements_lock_3_11.txt"], + local_wheel_workspaces = [], + local_wheel_dist_folder = None, + default_python_version = None, + local_wheel_inclusion_list = ["*"], + local_wheel_exclusion_list = [], + ) + + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +workspace_private = module_extension( + implementation = _workspace_private_impl, +) diff --git a/third_party/modules/xla/20250612.0-6e48cbb/overlay/xla.bzl b/third_party/modules/xla/20250612.0-6e48cbb/overlay/xla.bzl new file mode 100644 index 0000000..f14bf2a --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/overlay/xla.bzl @@ -0,0 +1,17 @@ +load("//third_party/llvm:workspace.bzl", llvm = "repo") +load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") +load("//third_party/triton:workspace.bzl", triton = "repo") + +def _xla_impl(mctx): + triton() + llvm("llvm-raw") + stablehlo() + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +xla = module_extension( + implementation = _xla_impl, +) diff --git a/third_party/modules/xla/20250612.0-6e48cbb/patches/0001-bazel-migration-to-bazel-8.1.1.patch b/third_party/modules/xla/20250612.0-6e48cbb/patches/0001-bazel-migration-to-bazel-8.1.1.patch new file mode 100644 index 0000000..8924cf4 --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/patches/0001-bazel-migration-to-bazel-8.1.1.patch @@ -0,0 +1,41 @@ +From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001 +From: Jean-Baptiste Dalido +Date: Mon, 6 Jan 2025 13:33:13 +0100 +Subject: [PATCH] bazel: migration to bazel 8.0.1 + +--- + .bazelversion | 2 +- + third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++-- + 2 files changed, 3 insertions(+), 3 deletions(-) + +diff --git a/.bazelversion b/.bazelversion +index f22d756da3..fa5fce04b3 100644 +--- a/.bazelversion ++++ b/.bazelversion +@@ -1 +1 @@ +-7.4.1 ++8.1.1 +\ No newline at end of file +diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl +index d62531152d..71d80a5a99 100644 +--- a/third_party/gpus/cuda_configure.bzl ++++ b/third_party/gpus/cuda_configure.bzl +@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", +- "get_env_var", + ) + load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", +- "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", + ) ++load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool") ++load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var") + load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") + load( + "//third_party/remote_config:common.bzl", +-- +2.39.3 (Apple Git-146) diff --git a/third_party/modules/xla/20250612.0-6e48cbb/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch b/third_party/modules/xla/20250612.0-6e48cbb/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch new file mode 100644 index 0000000..938ef40 --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch @@ -0,0 +1,135 @@ +From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001 +From: Hugo Mano +Date: Tue, 27 May 2025 11:48:17 +0200 +Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt + +PR: https://github.com/openxla/xla/pull/13420 +--- + xla/pjrt/c/BUILD | 5 +++++ + xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++ + xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++- + 3 files changed, 57 insertions(+), 1 deletion(-) + +diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD +index 79f18fa0bc..0f33dd8a6e 100644 +--- a/xla/pjrt/c/BUILD ++++ b/xla/pjrt/c/BUILD +@@ -69,8 +69,13 @@ cc_library( + ":pjrt_c_api_wrapper_impl", + "//xla/ffi:execution_context", + "//xla/ffi:type_id_registry", ++ "//xla/ffi:ffi_api", ++ "//xla/ffi/api:c_api", ++ "//xla/ffi/api:ffi", ++ "//xla/service:custom_call_target_registry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", ++ "@com_google_absl//absl/strings:str_format", + ], + ) + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +index 995a2c7e50..b8f10bc2f7 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h ++++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data); + // Adds a user data to the execute context. + typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args); + ++typedef enum PJRT_FFI_Handler_TraitsBits { ++ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0, ++} PJRT_FFI_Handler_TraitsBits; ++ ++struct PJRT_FFI_Register_Handler_Args { ++ size_t struct_size; ++ const char* target_name; ++ size_t target_name_size; ++ int api_version; // 0 for an untyped call, 1 -- for typed ++ void* handler; ++ const char* platform_name; ++ size_t platform_name_size; ++ PJRT_FFI_Handler_TraitsBits traits; ++}; ++PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits); ++ ++// Registers an FFI call handler for a specific platform. ++typedef PJRT_Error* PJRT_FFI_Register_Handler( ++ PJRT_FFI_Register_Handler_Args* args); ++ + typedef struct PJRT_FFI_Extension { + PJRT_Extension_Base base; + PJRT_FFI_TypeID_Register* type_id_register; + PJRT_FFI_UserData_Add* user_data_add; ++ PJRT_FFI_Register_Handler* register_handler; + } PJRT_FFI; + PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add); + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +index 5fa88eab33..763270331b 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc ++++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h" ++#include + + #include "absl/status/status.h" ++#include "absl/strings/str_format.h" + #include "absl/strings/string_view.h" ++#include "xla/ffi/api/c_api.h" + #include "xla/ffi/execution_context.h" + #include "xla/ffi/type_id_registry.h" ++#include "xla/ffi/ffi_api.h" + #include "xla/pjrt/c/pjrt_c_api.h" + #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" + #include "xla/pjrt/c/pjrt_c_api_helpers.h" + #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" ++#include "xla/service/custom_call_target_registry.h" + + namespace pjrt { + +@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { + return nullptr; + } + ++static PJRT_Error* PJRT_FFI_Register_Handler( ++ PJRT_FFI_Register_Handler_Args* args) { ++ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( ++ "PJRT_FFI_Register_Handler_Args", ++ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size)); ++ std::string target_name(args->target_name, args->target_name_size); ++ std::string platform_name(args->platform_name, args->platform_name_size); ++ switch (args->api_version) { ++ case 0: ++ xla::CustomCallTargetRegistry::Global()->Register( ++ target_name, args->handler, platform_name); ++ return nullptr; ++ case 1: ++ xla::ffi::Ffi::RegisterStaticHandler( ++ xla::ffi::GetXlaFfiApi(), target_name, platform_name, ++ reinterpret_cast(args->handler)); ++ return nullptr; ++ default: ++ return new PJRT_Error{absl::UnimplementedError( ++ absl::StrFormat("API version %d not supported for PJRT GPU plugin. " ++ "Supported versions are 0 and 1.", ++ args->api_version))}; ++ } ++} ++ + PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { + return { + PJRT_Extension_Base{ +@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { + }, + /*type_id_register=*/PJRT_FFI_TypeID_Register, + /*user_data_add=*/PJRT_FFI_UserData_Add, ++ /*register_handler=*/PJRT_FFI_Register_Handler, + }; + } + +-- +2.39.5 (Apple Git-154) + diff --git a/third_party/modules/xla/20250612.0-6e48cbb/patches/0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch b/third_party/modules/xla/20250612.0-6e48cbb/patches/0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch new file mode 100644 index 0000000..27d52df --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/patches/0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch @@ -0,0 +1,1855 @@ +From 4c0819ac9fb9dfc6156ae4de83fb29e987ade780 Mon Sep 17 00:00:00 2001 +From: Corentin Godeau +Date: Mon, 16 Jun 2025 14:27:47 +0000 +Subject: [PATCH] Update Stablehlo patch + +--- + third_party/stablehlo/temporary.patch | 841 +++++++++++++------------- + 1 file changed, 432 insertions(+), 409 deletions(-) + mode change 100755 => 100644 third_party/stablehlo/temporary.patch + +diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch +old mode 100755 +new mode 100644 +index f54e8341ba..6b25e78467 +--- a/third_party/stablehlo/temporary.patch ++++ b/third_party/stablehlo/temporary.patch +@@ -1,7 +1,47 @@ +-diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td +---- stablehlo/stablehlo/dialect/VhloAttrs.td +-+++ stablehlo/stablehlo/dialect/VhloAttrs.td +-@@ -45,14 +45,6 @@ ++From c48aa0650bf27b81c42799f9218dd90779e36e3b Mon Sep 17 00:00:00 2001 ++From: Corentin Godeau ++Date: Mon, 16 Jun 2025 14:08:48 +0000 ++Subject: [PATCH] Remove default argument that is not valid C ++ ++--- ++ stablehlo/dialect/Version.h | 2 +- ++ stablehlo/dialect/VhloAttrs.td | 14 +- ++ stablehlo/dialect/VhloDialect.td | 1 + ++ stablehlo/dialect/VhloTypes.h | 24 +- ++ .../integrations/c/StablehloDialectApi.h | 2 +- ++ .../stablehlo_aggressive_folder.mlir | 37 +- ++ .../stablehlo_aggressive_simplification.mlir | 2 +- ++ .../transforms/stablehlo_refine_shapes.mlir | 10 +- ++ .../stablehlo_legalize_to_vhlo_mixed.mlir | 38 +- ++ .../tests/vhlo/vhlo_attributes_invalid.mlir | 13 +- ++ stablehlo/tools/StablehloTranslateMain.cpp | 8 +- ++ stablehlo/transforms/Passes.h | 7 +- ++ .../transforms/StablehloLegalizeToVhlo.cpp | 48 +- ++ .../transforms/StablehloRefineShapes.cpp | 4 + ++ .../transforms/VhloLegalizeToStablehlo.cpp | 100 ++- ++ stablehlo/transforms/VhloToVersion.cpp | 38 +- ++ .../StablehloAggressiveFolder.cpp | 816 ++++++++++-------- ++ ...ablehloAggressiveSimplificationPatterns.td | 6 +- ++ 18 files changed, 696 insertions(+), 474 deletions(-) ++ ++diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h ++index eb35ad52..6aaef985 100644 ++--- a/stablehlo/dialect/Version.h +++++ b/stablehlo/dialect/Version.h ++@@ -38,7 +38,7 @@ class Version { ++ static FailureOr fromString(llvm::StringRef versionRef); ++ ++ /// Return a Version representing the current VHLO dialect version. ++- static Version getCurrentVersion() { return Version(1, 10, 10); } +++ static Version getCurrentVersion() { return Version(1, 11, 0); } ++ ++ /// Return a Version representing the minimum supported VHLO dialect version. ++ static Version getMinimumVersion() { return Version(0, 9, 0); } ++diff --git a/stablehlo/dialect/VhloAttrs.td b/stablehlo/dialect/VhloAttrs.td ++index bf75ad27..ab48630a 100644 ++--- a/stablehlo/dialect/VhloAttrs.td +++++ b/stablehlo/dialect/VhloAttrs.td ++@@ -45,14 +45,6 @@ class VHLO_AttrDef + def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1", "0.9.0", "current"> { + let mnemonic = "array_v1"; + let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value); +@@ -16,7 +56,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial + let assemblyFormat = "`<` custom($value) `>`"; + } + +-@@ -75,9 +67,9 @@ ++@@ -75,9 +67,9 @@ def VHLO_DictionaryAttrV1 : VHLO_AttrDef<"DictionaryV1", "0.9.0", "current"> { + LogicalResult DictionaryV1Attr::verify( + llvm::function_ref errFn, + ArrayRef> value) { +@@ -29,36 +69,48 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial + return success(); + } + }]; +-diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/dialect/VhloTypes.h +---- stablehlo/stablehlo/dialect/VhloTypes.h +-+++ stablehlo/stablehlo/dialect/VhloTypes.h +-@@ -27,20 +27,23 @@ ++diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td ++index 4e50cb32..edd220e3 100644 ++--- a/stablehlo/dialect/VhloDialect.td +++++ b/stablehlo/dialect/VhloDialect.td ++@@ -49,6 +49,7 @@ def VHLO_Dialect : Dialect { ++ 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types. ++ 1.9.0: Add `ResultAccuracy` attribute to `exp` op. ++ 1.10.0: Add `ResultAccuracy` attribute to `cbrt`, `cosine`, `exponential`, `exponential_minus_one`, `log`, `log_plus_one`, `logistic`, `rsqrt`, `sine`, `sqrt`, `tan` and `tanh` ops. +++ 1.11.0: Allow (de)serializing VHLO programs mixed with potentially unstable dialects. ++ }]; ++ ++ let useDefaultAttributePrinterParser = 0; ++diff --git a/stablehlo/dialect/VhloTypes.h b/stablehlo/dialect/VhloTypes.h ++index e5aee254..f3a8d68d 100644 ++--- a/stablehlo/dialect/VhloTypes.h +++++ b/stablehlo/dialect/VhloTypes.h ++@@ -27,20 +27,23 @@ limitations under the License. + namespace mlir { + namespace vhlo { + + -class VhloTypeConverterBase : public TypeConverter { +-- public: +-- VhloTypeConverterBase() : TypeConverter(){}; +-- +-- virtual ~VhloTypeConverterBase() = default; +-- +-- virtual Attribute convertEncoding(Attribute attr) const = 0; +--}; +- +- // This class is used to manage conversions between VHLO and Builtin +- // dialects. +--class VhloTypeConverter : public VhloTypeConverterBase { +++ +++// This class is used to manage conversions between VHLO and Builtin +++// dialects. + +class VhloTypeConverter : public TypeConverter { + public: +-- VhloTypeConverter() : VhloTypeConverterBase() {} ++- VhloTypeConverterBase() : TypeConverter(){}; + + VhloTypeConverter() : TypeConverter(), allowOtherDialects(false) {} + + VhloTypeConverter(bool allowOtherDialects) + + : TypeConverter(), allowOtherDialects(allowOtherDialects) {} +-+ ++ ++- virtual ~VhloTypeConverterBase() = default; + + virtual ~VhloTypeConverter() = default; +-+ +-+ virtual Attribute convertEncoding(Attribute attr) const = 0; +-+ ++ ++ virtual Attribute convertEncoding(Attribute attr) const = 0; ++-}; ++ ++-// This class is used to manage conversions between VHLO and Builtin ++-// dialects. ++-class VhloTypeConverter : public VhloTypeConverterBase { ++- public: ++- VhloTypeConverter() : VhloTypeConverterBase() {} + + Attribute convertUnknownAttribute(Attribute attr) const { + + if (allowOtherDialects) return attr; + + return {}; +@@ -66,7 +118,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale + + // A subclass can call this method to add conversions from VHLO -> Builtin + // types. Note that conversions are applied in reverse order, with the most +-@@ -58,6 +61,9 @@ ++@@ -58,6 +61,9 @@ class VhloTypeConverter : public VhloTypeConverterBase { + + // Mark unrealized casts as legal. Useful for dialect mixing. + void addUnrealizedMaterializations(); +@@ -76,19 +128,34 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale + }; + + // Autogenerated VHLO type printers and parsers. +-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +-+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir ++diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h ++index 385156bf..24d11c1d 100644 ++--- a/stablehlo/integrations/c/StablehloDialectApi.h +++++ b/stablehlo/integrations/c/StablehloDialectApi.h ++@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr, ++ MlirStringRef targetVersion, ++ MlirStringCallback callback, ++ void* userData, ++- bool allowOtherDialects = false); +++ bool allowOtherDialects); ++ ++ // Read a StableHLO program from a portable artifact, returning the module as ++ // MLIR bytecode. Note, this bytecode returned is not a portable artifact, ++diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir ++index 758965d0..c5239671 100644 ++--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +++++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir + @@ -1,4 +1,4 @@ + -// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s + +// RUN: stablehlo-opt --stablehlo-aggressive-folder=fold-op-element-limit=100 --split-input-file --verify-diagnostics %s | FileCheck %s + + //////// + // AddOp +-@@ -42,6 +42,21 @@ ++@@ -41,6 +41,21 @@ func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>) ++ + // ----- + +- //////// +++//////// + +// ClampOp + + + +// CHECK-LABEL: func.func @clamp_fold +@@ -103,18 +170,13 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml + + + +// ----- + + +-+//////// ++ //////// + // CompareOp + +- // CHECK-LABEL: func.func @compare_folds +-@@ -98,6 +113,26 @@ +- // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32> +- // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]] +- return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32> +-+} +-+ +-+// ----- +-+ ++@@ -102,6 +117,26 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, ++ ++ // ----- ++ + +//////// + +// DivOp + + +@@ -131,32 +193,37 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml + + %1 = stablehlo.divide %cst_1, %cst_1 : tensor + + %2 = stablehlo.divide %cst_2, %cst_2 : tensor + + return %0, %1, %2 : tensor, tensor, tensor +- } +++} +++ +++// ----- +++ ++ //////// ++ // MulOp + +- // ----- +-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +-+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ++diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ++index 4921a224..a1eed944 100644 ++--- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir + @@ -1,4 +1,4 @@ + -// RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s + +// RUN: stablehlo-opt --stablehlo-aggressive-simplification=fold-op-element-limit=100 --allow-unregistered-dialect --split-input-file %s | FileCheck %s + + ///////// + // AddOp +-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +---- stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +-+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +-@@ -521,16 +521,16 @@ ++diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir ++index b262bf09..01b5dd8f 100644 ++--- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +++++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir ++@@ -521,16 +521,16 @@ func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> { + // ----- + + // CHECK-LABEL: func @eval_slice_zerorank + -func.func @eval_slice_zerorank() -> tensor { + - // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor +-- // CHECK: return [[RESULT]] +-- %0 = stablehlo.constant dense<33.0> : tensor + +func.func @eval_slice_zerorank() -> tensor { + + // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<33> : tensor +-+ // CHECK: return [[RESULT]] ++ // CHECK: return [[RESULT]] ++- %0 = stablehlo.constant dense<33.0> : tensor + + %0 = stablehlo.constant dense<33> : tensor + %1 = "stablehlo.slice"(%0) { + start_indices = array, +@@ -169,9 +236,10 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b + } + + // ----- +-diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir +---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir +-+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir ++diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir ++index 6340898e..a5d344ac 100644 ++--- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir +++++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir + @@ -5,15 +5,9 @@ + // about what constitutes a good test! The CHECK should be + // minimized and named to reflect the test intent. +@@ -190,14 +258,10 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli + // RUN: diff %t.0 %t.1 + + // CHECK-LABEL: vhlo.func_v1 @op_other( +-@@ -26,6 +20,34 @@ +- func.func @op_other(%arg0: tensor) -> tensor { +- %0 = arith.addf %arg0, %arg0 : tensor +- return %0 : tensor +-+} +-+ +-+// ----- +-+ ++@@ -30,6 +24,34 @@ func.func @op_other(%arg0: tensor) -> tensor { ++ ++ // ----- ++ + +// CHECK-LABEL: vhlo.func_v1 @func_attributes( + +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { + +// CHECK: "vhlo.return_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1) -> () +@@ -222,12 +286,17 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli + + vhlo.mixed_dict = {affine_map = affine_map<(d0) -> (d0)>, str_attr = "STR_ATTR"} + +} { + + return %arg0 : tensor +- } +- +- // ----- +-diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +---- stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +-+++ stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +++} +++ +++// ----- +++ ++ // CHECK-LABEL: vhlo.func_v1 @op_shlo( ++ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { ++ // CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 ++diff --git a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir ++index b73d1005..f4330e67 100644 ++--- a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +++++ b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir + @@ -1,17 +1,8 @@ + // RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s + +@@ -248,10 +317,11 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stabl + } { + return + } +-diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp +---- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp +-+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp +-@@ -76,6 +76,11 @@ ++diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp ++index e3171b2d..fdf0d6a9 100644 ++--- a/stablehlo/tools/StablehloTranslateMain.cpp +++++ b/stablehlo/tools/StablehloTranslateMain.cpp ++@@ -76,6 +76,11 @@ llvm::cl::opt targetOption( + "target", llvm::cl::desc("Target version for serialization"), + llvm::cl::init("")); + +@@ -263,7 +333,7 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st + llvm::cl::opt argsOption( + "args", llvm::cl::desc("Arguments to pass to the interpreter"), + llvm::cl::init("")); +-@@ -317,7 +322,8 @@ ++@@ -317,7 +322,8 @@ TranslateFromMLIRRegistration serializeRegistration( + return module.emitError("failed to strip debuginfo"); + } + +@@ -273,10 +343,11 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st + }, + [](DialectRegistry ®istry) { + mlir::registerAllDialects(registry); +-diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h +---- stablehlo/stablehlo/transforms/Passes.h +-+++ stablehlo/stablehlo/transforms/Passes.h +-@@ -35,6 +35,7 @@ ++diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h ++index 0e4406d2..a5a4aa27 100644 ++--- a/stablehlo/transforms/Passes.h +++++ b/stablehlo/transforms/Passes.h ++@@ -35,6 +35,7 @@ limitations under the License. + #include "mlir/Transforms/DialectConversion.h" + #include "stablehlo/dialect/StablehloOps.h" + #include "stablehlo/dialect/Version.h" +@@ -284,7 +355,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans + + namespace mlir { + namespace stablehlo { +-@@ -54,17 +55,17 @@ ++@@ -54,17 +55,17 @@ void populateStablehloRefineShapesPatterns(MLIRContext *context, + // Populates StableHLO ops to VHLO ops rewriting patterns. + void populateStablehloToVhloPatterns(MLIRContext *context, + RewritePatternSet *patterns, +@@ -305,10 +376,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans + + /// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and + /// Shape ops. +-diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +-+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +-@@ -57,7 +57,7 @@ ++diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ++index e768ded6..25903399 100644 ++--- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +++++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ++@@ -57,7 +57,7 @@ namespace { + class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { + public: + StablehloToVhloTypeConverter(bool allowOtherDialects) +@@ -317,7 +389,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + LLVM_DEBUG( + llvm::dbgs() + << "[StablehloToVhloTypeConverter] Creating with allowOtherDialects: " +-@@ -82,7 +82,15 @@ ++@@ -82,7 +82,15 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { + return vhlo::TokenV1Type::get(token.getContext()); + }); + addBuiltinToVhloConversions(); +@@ -334,7 +406,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + } + + Attribute convertEncoding(Attribute attr) const final { +-@@ -114,7 +122,7 @@ ++@@ -114,7 +122,7 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { + return vhlo::Name##Version##Attr::get(attr.getContext(), vhloValue.value()) + + Attribute convertGeneric(Attribute stablehloAttr, +@@ -343,7 +415,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + LLVM_DEBUG(llvm::dbgs() << "Convert generic: " << stablehloAttr << '\n'); + + // Handle StableHLO attributes. +-@@ -241,6 +249,10 @@ ++@@ -241,6 +249,10 @@ Attribute convertGeneric(Attribute stablehloAttr, + return vhlo::TypeV1Attr::get(attr.getContext(), vhloType); + } + +@@ -354,7 +426,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + LLVM_DEBUG(llvm::dbgs() << "Failed to convert: " << stablehloAttr << '\n'); + return {}; // Failed to convert attribute. + } +-@@ -268,13 +280,15 @@ ++@@ -268,13 +280,15 @@ SpecialResult notSpecial() { return SpecialResult::NOT_SPECIAL; } + Attribute convertBool(const ConversionPattern& pattern, int64_t stablehloDim) { + auto stablehloType = IntegerType::get(pattern.getContext(), 1); + auto stablehloAttr = IntegerAttr::get(stablehloType, stablehloDim); +@@ -372,7 +444,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + } + + Attribute convertInts(const ConversionPattern& pattern, +-@@ -282,7 +296,8 @@ ++@@ -282,7 +296,8 @@ Attribute convertInts(const ConversionPattern& pattern, + auto stablehloType = RankedTensorType::get( + stablehloDims.size(), IntegerType::get(pattern.getContext(), 64)); + auto stablehloAttr = DenseIntElementsAttr::get(stablehloType, stablehloDims); +@@ -382,7 +454,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + } + + Attribute convertSymbol(const ConversionPattern& pattern, +-@@ -290,7 +305,7 @@ ++@@ -290,7 +305,7 @@ Attribute convertSymbol(const ConversionPattern& pattern, + auto stablehloSymbolAttr = dyn_cast(stablehloAttr); + if (!stablehloSymbolAttr) return {}; + return convertGeneric(stablehloSymbolAttr.getAttr(), +@@ -391,7 +463,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + } + + SpecialResult convertChannelHandle(const ConversionPattern& pattern, +-@@ -445,15 +460,15 @@ ++@@ -445,15 +460,15 @@ SpecialResult convertDotAlgorithm(const ConversionPattern& pattern, + vhloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), "lhs_precision_type"), + convertGeneric(TypeAttr::get(attr.getLhsPrecisionType()), +@@ -410,7 +482,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + + // Components + auto vhloLhsComponentCount = convertInt(pattern, attr.getLhsComponentCount()); +-@@ -712,7 +727,9 @@ ++@@ -712,7 +727,9 @@ LogicalResult addDefaults(const OpConversionPattern& pattern, + auto addDefaultAttr = [&](StringRef vhloName, Attribute stablehloAttr) { + vhloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), vhloName), +@@ -421,7 +493,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + }; + if constexpr (std::is_same::value) { + if (!stablehloOp.getSymVisibilityAttr()) +-@@ -987,8 +1004,9 @@ ++@@ -987,8 +1004,9 @@ class StablehloToVhloOpConverter : public OpConversionPattern { + case SpecialResult::SPECIAL_FAILURE: + return failure(); + case SpecialResult::NOT_SPECIAL: +@@ -433,7 +505,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + if (!vhloAttr) return failure(); + vhloAttrs.push_back({stablehloAttr.getName(), vhloAttr}); + break; +-@@ -1075,14 +1093,14 @@ ++@@ -1075,14 +1093,14 @@ struct StablehloLegalizeToVhloPass + } + + private: +@@ -450,10 +522,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable + populateStablehloToVhloPatterns< + #define GET_OP_LIST + #include "stablehlo/dialect/StablehloOps.cpp.inc" +-diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +-+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +-@@ -16,6 +16,7 @@ ++diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp ++index 64eaa3c0..4b585b46 100644 ++--- a/stablehlo/transforms/StablehloRefineShapes.cpp +++++ b/stablehlo/transforms/StablehloRefineShapes.cpp ++@@ -16,6 +16,7 @@ limitations under the License. + + #include + #include +@@ -461,7 +534,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + #include + #include + +-@@ -1038,8 +1039,11 @@ ++@@ -1038,8 +1039,11 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func, + // Populate additional patterns for StableHLO extensions. + state.addAdditionalPatterns(patterns); + +@@ -473,10 +546,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + + // The folding patterns implement partial evaluation of shape computations + // which is a critical part of implementing type refinement for ops like +-diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +-+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +-@@ -57,7 +57,10 @@ ++diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ++index 5662ad9d..41973a0e 100644 ++--- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +++++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ++@@ -57,7 +57,10 @@ namespace { + + class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { + public: +@@ -488,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + addConversion([](Type type) -> Type { return type; }); + addConversion([](vhlo::TokenV1Type token) -> Type { + LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n"); +-@@ -90,7 +93,7 @@ ++@@ -90,7 +93,7 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { + return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) + + Attribute convertGeneric(Attribute vhloAttr, +@@ -497,18 +571,18 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + LLVM_DEBUG(llvm::dbgs() << "Converting attr " << vhloAttr); + if (auto vhloAttrs = dyn_cast(vhloAttr)) { + SmallVector stablehloAttrs; +-@@ -189,6 +192,10 @@ +- // All VHLO attributes must have counterparts in StableHLO. ++@@ -190,6 +193,10 @@ Attribute convertGeneric(Attribute vhloAttr, + return {}; + } +-+ ++ + + // Fall back to type converter for unknown attributes. + + auto unknownAttr = typeConverter->convertUnknownAttribute(vhloAttr); + + if (unknownAttr) return unknownAttr; +- +++ + // This should be unreachable unless program is a mix of VHLO and other + // dialects, e.g. due to user edits to textual assembly format. +-@@ -229,7 +236,7 @@ ++ return {}; ++@@ -229,7 +236,7 @@ bool isNoneType(Attribute vhloAttr) { + } + + LogicalResult convertTypeAttr(Attribute vhloAttr, Type& result, +@@ -517,7 +591,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + auto stablehloAttr = convertGeneric(vhloAttr, typeConverter); + if (!stablehloAttr || !isa(stablehloAttr)) return failure(); + result = cast(stablehloAttr).getValue(); +-@@ -244,7 +251,7 @@ ++@@ -244,7 +251,7 @@ LogicalResult convertInt(Attribute vhloAttr, int64_t& result) { + } + + LogicalResult convertInts(Attribute vhloAttr, +@@ -526,7 +600,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + SmallVector& result) { + auto vhloTensorAttr = dyn_cast(vhloAttr); + if (!vhloTensorAttr) return failure(); +-@@ -256,7 +263,7 @@ ++@@ -256,7 +263,7 @@ LogicalResult convertInts(Attribute vhloAttr, + } + + Attribute convertSymbol(Attribute vhloAttr, +@@ -535,7 +609,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + auto vhloStringAttr = dyn_cast(vhloAttr); + if (!vhloStringAttr) return {}; + auto stablehloStringAttr = dyn_cast_or_null( +-@@ -267,7 +274,7 @@ ++@@ -267,7 +274,7 @@ Attribute convertSymbol(Attribute vhloAttr, + + template + Attribute convertChannelHandle(OpType vhloOp, +@@ -544,7 +618,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + int64_t channelId, channelType; + if (failed(convertInt(vhloOp.getChannelId(), channelId)) || + failed(convertInt(vhloOp.getChannelType(), channelType))) +-@@ -277,7 +284,7 @@ ++@@ -277,7 +284,7 @@ Attribute convertChannelHandle(OpType vhloOp, + } + + Attribute convertChannelId(Attribute vhloAttr, +@@ -553,7 +627,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + int64_t channelId; + if (failed(convertInt(vhloAttr, channelId))) return {}; + return stablehlo::ChannelHandleAttr::get(vhloAttr.getContext(), channelId, +-@@ -285,8 +292,8 @@ ++@@ -285,8 +292,8 @@ Attribute convertChannelId(Attribute vhloAttr, + } + + template +@@ -564,7 +638,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + int64_t stablehloInputBatchDimension, stablehloInputFeatureDimension; + SmallVector stablehloInputSpatialDimensions; + int64_t stablehloKernelInputFeatureDimension, +-@@ -323,7 +330,7 @@ ++@@ -323,7 +330,7 @@ Attribute convertConvDimensionNumbers(OpType vhloOp, + } + + Attribute convertCustomCallCalledComputations( +@@ -573,7 +647,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (auto vhloArrayAttr = dyn_cast(vhloAttr)) { + SmallVector stablehloAttrs; + for (auto vhloAttr : vhloArrayAttr.getValue()) { +-@@ -336,8 +343,8 @@ ++@@ -336,8 +343,8 @@ Attribute convertCustomCallCalledComputations( + return {}; + } + +@@ -584,7 +658,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + Type lhsPrecisionType, rhsPrecisionType, accumulationType; + if (isNoneType(vhloOp.getLhsComponentCount())) { + // All must be nonetype +-@@ -373,8 +380,8 @@ ++@@ -373,8 +380,8 @@ FailureOr convertDotAlgorithm(vhlo::DotGeneralOpV2 vhloOp, + numPrimitiveOperations, allowImpreciseAccumulation); + } + +@@ -595,7 +669,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + SmallVector stablehloLhsBatchingDimensions, + stablehloRhsBatchingDimensions, stablehloLhsContractingDimensions, + stablehloRhsContractingDimensions; +-@@ -394,13 +401,13 @@ ++@@ -394,13 +401,13 @@ Attribute convertDotDimensionNumbers(vhlo::DotGeneralOpV2 vhloOp, + } + + Attribute convertFuncCallee(Attribute vhloAttr, +@@ -612,7 +686,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + SmallVector stablehloOffsetDims, stablehloCollapsedSliceDims, + stablehloOperandBatchingDims, stablehloStartIndicesBatchingDims, + stablehloStartIndexMap; +-@@ -423,8 +430,8 @@ ++@@ -423,8 +430,8 @@ Attribute convertGatherDimensionNumbers(OpType vhloOp, + stablehloStartIndexMap, stablehloIndexVectorDim); + } + +@@ -623,7 +697,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + SmallVector stablehloUpdateWindowDims, stablehloInsertedWindowDims, + stablehloInputBatchingDims, stablehloScatterIndicesBatchingDims, + stablehloScatterDimsToOperandDims; +-@@ -463,10 +470,11 @@ ++@@ -463,10 +470,11 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + VhloOpTy vhloOp, + SmallVector& vhloAttrs, + SmallVector& stablehloAttrs) { +@@ -637,7 +711,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return failure(); + stablehloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), "dimension_numbers"), +-@@ -480,7 +488,7 @@ ++@@ -480,7 +488,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + if constexpr (std::is_same::value) { + // Dot Dimension Numbers + auto stablehloDotDimAttr = +@@ -646,7 +720,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloDotDimAttr) return failure(); + stablehloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), "dot_dimension_numbers"), +-@@ -489,8 +497,7 @@ ++@@ -489,8 +497,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + "lhs_contracting_dimensions", "rhs_contracting_dimensions"); + + // Dot Algorithm +@@ -656,7 +730,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (failed(stablehloDotAlgorithmAttr)) return failure(); + if (stablehloDotAlgorithmAttr.value()) + stablehloAttrs.emplace_back( +-@@ -503,8 +510,7 @@ ++@@ -503,8 +510,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + } + if constexpr (std::is_same::value || + std::is_same::value) { +@@ -666,7 +740,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return failure(); + stablehloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), "dimension_numbers"), +-@@ -514,8 +520,7 @@ ++@@ -514,8 +520,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + "start_index_map", "index_vector_dim"); + } + if constexpr (std::is_same::value) { +@@ -676,7 +750,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return failure(); + stablehloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), "scatter_dimension_numbers"), +-@@ -526,8 +531,7 @@ ++@@ -526,8 +531,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + } + if constexpr (std::is_same::value || + std::is_same::value) { +@@ -686,7 +760,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return failure(); + stablehloAttrs.emplace_back( + StringAttr::get(pattern.getContext(), "channel_handle"), stablehloAttr); +-@@ -537,7 +541,7 @@ ++@@ -537,7 +541,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, + } + + template +@@ -695,7 +769,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + StringAttr vhloName, Attribute vhloAttr, + SmallVector& stablehloAttrs) { + auto tensorAttr = dyn_cast(vhloAttr); +-@@ -556,15 +560,15 @@ ++@@ -556,15 +560,15 @@ SpecialResult convertDenseArray(const TypeConverter* typeConverter, + } + + SpecialResult convertDenseBoolArray( +@@ -715,7 +789,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + return convertDenseArray( + typeConverter, vhloName, vhloAttr, stablehloAttrs); + } +-@@ -575,7 +579,8 @@ ++@@ -575,7 +579,8 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, + SmallVector& stablehloAttrs) { + StringAttr stablehloName = vhloName; + Attribute stablehloAttr; +@@ -725,7 +799,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + + if constexpr (std::is_same::value || + std::is_same::value || +-@@ -585,7 +590,7 @@ ++@@ -585,7 +590,7 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, + std::is_same::value) { + if (vhloName == "channel_id") { + stablehloName = StringAttr::get(pattern.getContext(), "channel_handle"); +@@ -734,7 +808,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return specialFailure(); + } + if (vhloName == "use_global_device_ids") { +-@@ -597,20 +602,20 @@ ++@@ -597,20 +602,20 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, + } + if constexpr (std::is_same::value) { + if (vhloName == "called_computations") { +@@ -759,7 +833,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return specialFailure(); + } + } +-@@ -760,8 +765,8 @@ ++@@ -760,8 +765,8 @@ bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) { + template + bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr, + T splatValue) { +@@ -770,7 +844,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + return attr && attr.isSplat() && + attr.template getSplatValue() == splatValue; + } +-@@ -977,8 +982,9 @@ ++@@ -977,8 +982,9 @@ class VhloToStablehloOpConverter : public OpConversionPattern { + case SpecialResult::SPECIAL_FAILURE: + return failure(); + case SpecialResult::NOT_SPECIAL: +@@ -782,7 +856,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + if (!stablehloAttr) return failure(); + stablehloAttrs.push_back({vhloAttr.getName(), stablehloAttr}); + break; +-@@ -1056,7 +1062,7 @@ ++@@ -1056,7 +1062,7 @@ struct ReconcileUnrealizedConversionCasts + template + void populateVhloToStablehloPatterns(MLIRContext* context, + RewritePatternSet* patterns, +@@ -791,7 +865,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + patterns + ->add>...>( + *converter, context); +-@@ -1104,7 +1110,7 @@ ++@@ -1104,7 +1110,7 @@ struct VhloLegalizeToStablehloPass + + void populateVhloToStablehloPatterns(MLIRContext* context, + RewritePatternSet* patterns, +@@ -800,10 +874,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable + populateVhloToStablehloPatterns< + #define GET_OP_LIST + #include "stablehlo/dialect/StablehloOps.cpp.inc" +-diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp +---- stablehlo/stablehlo/transforms/VhloToVersion.cpp +-+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp +-@@ -56,15 +56,23 @@ ++diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp ++index dfbf2877..7818e601 100644 ++--- a/stablehlo/transforms/VhloToVersion.cpp +++++ b/stablehlo/transforms/VhloToVersion.cpp ++@@ -56,15 +56,23 @@ namespace { + + // Currently there are no type-to-version conversions so this class + // simply validates that all types are from the VHLO dialect. +@@ -833,7 +908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + }; + + // Check user-specified target version. Emit error if invalid. +-@@ -111,6 +119,10 @@ ++@@ -111,6 +119,10 @@ bool isLegalVersion(VersionedInterface& interface, const Version& target) { + LogicalResult isLegalType(Type type, const Version& targetVersion); + + LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { +@@ -844,7 +919,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + auto attrInterface = dyn_cast(attr); + if (!attrInterface || !isLegalVersion(attrInterface, targetVersion)) { + LLVM_DEBUG(llvm::dbgs() << "failed to legalize attribute " << attr +-@@ -119,10 +131,11 @@ ++@@ -119,10 +131,11 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { + } + + // Recursively check attrs if VHLO attr is a container +@@ -857,7 +932,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + if (auto arrAttr = dyn_cast(attr)) { + return success(llvm::all_of( + arrAttr.getValue(), [&](std::pair entry) { +-@@ -146,6 +159,10 @@ ++@@ -146,6 +159,10 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { + } + + LogicalResult isLegalType(Type type, const Version& targetVersion) { +@@ -868,7 +943,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + // All valid VHLO types must have versioned type interface. + auto typeInterface = dyn_cast(type); + if (!typeInterface || !isLegalVersion(typeInterface, targetVersion)) { +-@@ -170,10 +187,11 @@ ++@@ -170,10 +187,11 @@ LogicalResult isLegalType(Type type, const Version& targetVersion) { + return failure(); + return isLegalType(ranked.getElementType(), targetVersion); + } +@@ -881,7 +956,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + if (auto quant = dyn_cast(type)) + return success( + succeeded(isLegalType(quant.getStorageType(), targetVersion)) && +-@@ -213,7 +231,7 @@ ++@@ -213,7 +231,7 @@ bool isLegalLocation(Location loc, const Version& targetVersion) { + bool isLegalOperation(Operation* op, const Version& targetVersion) { + // Validate op + auto opInterface = dyn_cast(op); +@@ -890,7 +965,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + if (!isLegalVersion(opInterface, targetVersion)) return false; + LLVM_DEBUG(llvm::dbgs() << "Legal op version for target. " << op << '\n'); + +-@@ -454,7 +472,7 @@ ++@@ -454,7 +472,7 @@ struct AllReduceOpV2ToV1 : public OpRewritePattern { + namespace stablehlo { + void populateVhloToVersionPatterns(MLIRContext* context, + RewritePatternSet* patterns, +@@ -899,10 +974,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable + vhlo::populateWithGenerated(*patterns); + patterns->add(context); + patterns->add(context); +-diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +-+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +-@@ -19,6 +19,7 @@ ++diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp ++index 9cfac0ba..09bfb783 100644 ++--- a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +++++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp ++@@ -19,6 +19,7 @@ limitations under the License. + #include + #include + #include +@@ -910,7 +986,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + #include + + #include "llvm/ADT/APInt.h" +-@@ -47,6 +48,7 @@ ++@@ -47,6 +48,7 @@ limitations under the License. + #include "mlir/IR/TypeUtilities.h" + #include "mlir/IR/Types.h" + #include "mlir/IR/Value.h" +@@ -918,7 +994,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + #include "mlir/Interfaces/SideEffectInterfaces.h" + #include "mlir/Pass/Pass.h" + #include "mlir/Rewrite/FrozenRewritePatternSet.h" +-@@ -98,29 +100,88 @@ ++@@ -98,29 +100,88 @@ LogicalResult validateStaticShapeResult(PatternRewriter& rewriter, + return success(); + } + +@@ -982,8 +1058,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (res) return cast(res); + + return nullptr; +-+} +-+ ++ } ++ + + + +/// Binary constant folder that used a generic folder function to handle both + +/// ints and floats. +@@ -1005,8 +1081,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (!res) return rewriter.notifyMatchFailure(op, "folding failed"); + + + + return res; +- } +- +++} +++ + template + -LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op, +@@ -1014,7 +1090,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + DenseIntOrFPElementsAttr elements, Type resType, + CalculationT&& calculate) { + auto result = constFoldCastOp( + elements, resType, calculate); + +@@ -1036,7 +1112,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + DenseIntOrFPElementsAttr elements, + RankedTensorType resultType) { + auto oldType = getElementTypeOrSelf(elements); +-@@ -153,7 +215,7 @@ ++@@ -153,7 +215,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, + if (auto newFloatType = dyn_cast(newType)) { + // Float -> Float + const auto& targetSemantics = newFloatType.getFloatSemantics(); +@@ -1045,7 +1121,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + rewriter, op, elements, resultType, + [&targetSemantics](const APFloat& operand, bool& castStatus) { + bool losesInfo; +-@@ -167,7 +229,7 @@ ++@@ -167,7 +229,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, + } + + // Float -> Int +@@ -1054,7 +1130,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + rewriter, op, elements, resultType, + [&newBitWidth, &isNewTypeUnsigned](const APFloat& operand, + bool& castStatus) { +-@@ -186,7 +248,7 @@ ++@@ -186,7 +248,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, + + if (auto newFloatType = dyn_cast(newType)) { + // Int -> Float +@@ -1063,7 +1139,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + rewriter, op, elements, resultType, + [&newFloatType, &isOldTypeUnsigned](const APInt& operand, + bool& /*castStatus*/) { +-@@ -199,64 +261,12 @@ ++@@ -199,7 +261,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, + } + + // Int -> Int +@@ -1072,10 +1148,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + rewriter, op, elements, resultType, + [&newBitWidth, &isOldTypeUnsigned](const APInt& operand, + bool& /*castStatus*/) { +- return APSInt(operand, isOldTypeUnsigned).extOrTrunc(newBitWidth); ++@@ -207,58 +269,6 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, + }); +--} +-- ++ } ++ + -// The patterns below implement partial evaluation of shape computations which + -// is a critical part of implementing type refinement for ops like + -// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape +@@ -1126,10 +1202,12 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - rewriter.replaceOpWithNewOp(op, + - getTensorAttr(resultType, result)); + - return success(); +- } +- ++-} ++- + template +-@@ -275,29 +285,18 @@ ++ struct FoldOpRewritePattern : OpRewritePattern { ++ FoldOpRewritePattern(MLIRContext* context, ++@@ -275,29 +285,18 @@ struct FoldOpRewritePattern : OpRewritePattern { + ArrayRef generatedNames = {}) = delete; + + const StablehloAggressiveFolderPassOptions& options; +@@ -1154,9 +1232,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - rewriter.replaceOpWithNewOp(op, res); + - return success(); + - } +-- ++ + - return failure(); +-+ + + LogicalResult validateElementCountForFold(PatternRewriter& rewriter, + + Operation* op, + + ShapedType resultType) const { +@@ -1171,31 +1248,21 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + } + }; + +-@@ -318,100 +317,102 @@ ++@@ -318,100 +317,102 @@ struct ShapeOpRewritePattern : public FoldOpRewritePattern { + } + }; + + -struct EvalAddOpShapePattern : public FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +-- +-- LogicalResult matchAndRewrite(AddOp op, +-- PatternRewriter& rewriter) const override { +-- return evalElementwise(rewriter, op, +-- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); +-- } +--}; +-- +--struct EvalAndOpPattern : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; +-- +-- LogicalResult matchAndRewrite(AndOp op, +-- PatternRewriter& rewriter) const override { + +struct FoldAddOpPattern final + + : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; +-+ ++ ++- LogicalResult matchAndRewrite(AddOp op, + + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, +-+ PatternRewriter& rewriter) const override { ++ PatternRewriter& rewriter) const override { ++- return evalElementwise(rewriter, op, ++- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); + + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + + return failure(); + + +@@ -1203,14 +1270,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } +-+}; +-+ ++ } ++ }; ++ ++-struct EvalAndOpPattern : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +struct FoldAndOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; +-+ ++ ++- LogicalResult matchAndRewrite(AndOp op, + + LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op, +-+ PatternRewriter& rewriter) const override { ++ PatternRewriter& rewriter) const override { + + // TODO: Support more int types + auto resultType = op.getType(); + if (!resultType.getElementType().isInteger(1)) +@@ -1219,24 +1289,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { + - return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); + - }); +-- } + + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldAnd{}); + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } +-+ +-+ struct FoldAnd { +-+ APInt operator()(APInt lhs, APInt rhs) const { +-+ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero()); +-+ } +-+ std::optional operator()(APFloat lhs, APFloat rhs) const { +-+ return std::nullopt; +-+ } +-+ }; +- }; ++ } ++-}; + +- // Pattern: broadcast_in_dim(splat, _) -> constant(splat) ++-// Pattern: broadcast_in_dim(splat, _) -> constant(splat) + -struct FoldBroadcastInDimSplatPattern final + - : FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +@@ -1251,14 +1311,22 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - op, SplatElementsAttr::get(op.getType(), + - cstAttr.getSplatValue())); + - return success(); +-- } +++ struct FoldAnd { +++ APInt operator()(APInt lhs, APInt rhs) const { +++ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero()); ++ } + - return failure(); + - } +--}; +-- +++ std::optional operator()(APFloat lhs, APFloat rhs) const { +++ return std::nullopt; +++ } +++ }; ++ }; ++ + -struct EvalBroadcastInDimOpPattern + - : public FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +++// Pattern: broadcast_in_dim(splat, _) -> constant(splat) + +struct FoldBroadcastInDimOpSplatPattern + + : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; +@@ -1267,8 +1335,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + - if (failed(validateStaticShapeResult(rewriter, op, resultType))) +-- return failure(); +-- +++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || +++ failed(validateShapeFoldDtype(rewriter, op, resultType))) ++ return failure(); ++ + - auto operandType = op.getOperand().getType(); + - if (operandType.getRank() != 0) + - return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); +@@ -1277,52 +1347,34 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - if (failed(hlo::matchInts(op.getOperand(), operand))) + - return rewriter.notifyMatchFailure(op, "expected constant operands"); + - auto scalar = operand[0]; +-- +++ SplatElementsAttr cstAttr; +++ matchPattern(op.getOperand(), m_Constant(&cstAttr)); +++ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat"); ++ + - rewriter.replaceOpWithNewOp( + - op, getTensorAttr(op.getType(), scalar)); +-- return success(); +-- } +--}; +-- +++ rewriter.replaceOpWithNewOp( +++ op, SplatElementsAttr::get(op.getType(), +++ cstAttr.getSplatValue())); ++ return success(); ++ } ++ }; ++ + -struct EvalClampOpPattern : public FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +-- +++struct FoldCompareOpPattern : public ShapeOpRewritePattern { +++ using ShapeOpRewritePattern::ShapeOpRewritePattern; ++ + - LogicalResult matchAndRewrite(ClampOp op, +-- PatternRewriter& rewriter) const override { +++ LogicalResult matchAndRewrite(CompareOp op, ++ PatternRewriter& rewriter) const override { + - return evalElementwise(rewriter, op, + - [&](APSInt min, APSInt operand, APSInt max) { + - if (operand < min) return min; + - if (max < operand) return max; + - return operand; + - }); +-- } +--}; +-- +--struct EvalCompareOpPattern : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; +-+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || +-+ failed(validateShapeFoldDtype(rewriter, op, resultType))) +-+ return failure(); +-+ +-+ SplatElementsAttr cstAttr; +-+ matchPattern(op.getOperand(), m_Constant(&cstAttr)); +-+ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat"); +-+ +-+ rewriter.replaceOpWithNewOp( +-+ op, SplatElementsAttr::get(op.getType(), +-+ cstAttr.getSplatValue())); +-+ return success(); +-+ } +-+}; +-+ +-+struct FoldCompareOpPattern : public ShapeOpRewritePattern { +-+ using ShapeOpRewritePattern::ShapeOpRewritePattern; +- +- LogicalResult matchAndRewrite(CompareOp op, +- PatternRewriter& rewriter) const override { +- auto resultType = op.getType(); +-- auto kind = op.getCompareType(); +-- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { +++ auto resultType = op.getType(); + + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + + return failure(); + + +@@ -1332,15 +1384,23 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } +-+ ++ } ++-}; ++ ++-struct EvalCompareOpPattern : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + + struct FoldCompare { + + FoldCompare(ComparisonDirection direction, + + std::optional kind) + + : direction(direction), kind(kind) {} + + ComparisonDirection direction; + + std::optional kind; +-+ ++ ++- LogicalResult matchAndRewrite(CompareOp op, ++- PatternRewriter& rewriter) const override { ++- auto resultType = op.getType(); ++- auto kind = op.getCompareType(); ++- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { + + // TODO: Enable float folding. + + std::optional operator()(APFloat lhs, APFloat rhs) { + + return std::nullopt; +@@ -1352,7 +1412,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + case ComparisonDirection::EQ: + result = lhs == rhs; + break; +-@@ -431,9 +432,9 @@ ++@@ -431,9 +432,9 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern { + result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs); + break; + } +@@ -1365,7 +1425,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + }; + + ////////////////////////////////// +-@@ -441,16 +442,15 @@ ++@@ -441,16 +442,15 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern { + ///////////////////////////////// + + struct FoldConcatenateOpPattern final +@@ -1387,7 +1447,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return failure(); + + // Fold concatenate when all inputs are constants. +-@@ -466,6 +466,7 @@ ++@@ -466,6 +466,7 @@ struct FoldConcatenateOpPattern final + int64_t{1}, std::multiplies<>{}); + + SmallVector newElems; +@@ -1395,7 +1455,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + newElems.reserve(numElems); + + for (int64_t i = 0; i != topSize; ++i) { +-@@ -485,31 +486,7 @@ ++@@ -485,31 +486,7 @@ struct FoldConcatenateOpPattern final + int64_t foldOpElementLimit; + }; + +@@ -1428,20 +1488,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + using ShapeOpRewritePattern::ShapeOpRewritePattern; + + LogicalResult matchAndRewrite(ConvertOp op, +-@@ -532,28 +509,50 @@ ++@@ -532,28 +509,50 @@ struct EvalConvertOpPattern : public ShapeOpRewritePattern { + return rewriter.notifyMatchFailure( + op, "expected constant integer or float operand"); + + - return evalConvert(rewriter, op, elements, resultType); +-- } +--}; +-- +++ return foldConvert(rewriter, op, elements, resultType); ++ } ++ }; ++ + -struct EvalDivOpPattern : public FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +-+ return foldConvert(rewriter, op, elements, resultType); +-+ } +-+}; +-+ + +struct FoldDivOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; + +@@ -1449,12 +1506,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + PatternRewriter& rewriter) const override { + - return evalElementwise(rewriter, op, + - [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); +-- } +--}; +-- +--struct EvalGetDimensionSizeOpPattern +-- : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; + + auto resultType = op.getType(); + + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + + return failure(); +@@ -1464,7 +1515,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } ++ } + + + + struct FoldDivide { + + FoldDivide(bool isUnsignedInt) +@@ -1479,8 +1530,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + static APInt foldUint(APInt lhs, APInt rhs) { return lhs.udiv(rhs); } + + static APInt foldSint(APInt lhs, APInt rhs) { return lhs.sdiv(rhs); } + + }; +-+}; +-+ ++ }; ++ ++-struct EvalGetDimensionSizeOpPattern ++- : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +struct FoldGetDimensionSizeOpPattern + + : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; +@@ -1494,7 +1548,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return failure(); + + auto operandType = op.getOperand().getType(); +-@@ -567,86 +566,187 @@ ++@@ -567,86 +566,187 @@ struct EvalGetDimensionSizeOpPattern + } + }; + +@@ -1549,11 +1603,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { + - return lhs >= rhs ? lhs : rhs; + - }); +-- } +--}; +-- +--struct EvalMinOpPattern : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; + + auto resultType = op.getType(); + + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + + return failure(); +@@ -1563,9 +1612,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } +-+}; +-+ ++ } ++ }; ++ ++-struct EvalMinOpPattern : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +struct FoldMinOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; + +@@ -1574,11 +1625,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { + - return lhs <= rhs ? lhs : rhs; + - }); +-- } +--}; +-- +--struct FoldMulOpPattern final : FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; + + auto resultType = op.getType(); + + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + + return failure(); +@@ -1588,19 +1634,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } +-+}; +-+ ++ } ++ }; ++ ++-struct FoldMulOpPattern final : FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +// Clamp is folded using Min and Max folders. + +struct FoldClampOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; +-+ ++ ++- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + + LogicalResult matchAndRewrite(ClampOp op, +-+ PatternRewriter& rewriter) const override { ++ PatternRewriter& rewriter) const override { ++- TypedAttr lhsAttr; ++- matchPattern(op.getLhs(), m_Constant(&lhsAttr)); ++- ++- TypedAttr rhsAttr; ++- matchPattern(op.getRhs(), m_Constant(&rhsAttr)); ++- ++- if (TypedAttr res; ++- lhsAttr && rhsAttr && ++- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { ++- rewriter.replaceOpWithNewOp(op, res); ++- return success(); ++- } + + auto resultType = op.getType(); + + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + + return failure(); +-+ ++ ++- return failure(); + + TypedAttr minAttr, operandAttr, maxAttr; + + matchPattern(op.getMin(), m_Constant(&minAttr)); + + matchPattern(op.getOperand(), m_Constant(&operandAttr)); +@@ -1620,43 +1682,19 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (!res) return rewriter.notifyMatchFailure(op, "failed to fold clamp"); + + rewriter.replaceOpWithNewOp(op, res); + + return success(); +-+ } +-+}; +-+ +-+struct FoldMulOpPattern final : ShapeOpRewritePattern { +-+ using ShapeOpRewritePattern::ShapeOpRewritePattern; ++ } ++ }; + +- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, +- PatternRewriter& rewriter) const override { +-- TypedAttr lhsAttr; +-- matchPattern(op.getLhs(), m_Constant(&lhsAttr)); +-- +-- TypedAttr rhsAttr; +-- matchPattern(op.getRhs(), m_Constant(&rhsAttr)); +-- +-- if (TypedAttr res; +-- lhsAttr && rhsAttr && +-- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { +-- rewriter.replaceOpWithNewOp(op, res); +-- return success(); +-- } +-- +-- return failure(); +-- } +--}; +-- + -struct EvalMulOpPattern : public FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +-- +++struct FoldMulOpPattern final : ShapeOpRewritePattern { +++ using ShapeOpRewritePattern::ShapeOpRewritePattern; ++ + - LogicalResult matchAndRewrite(MulOp op, +-- PatternRewriter& rewriter) const override { +++ LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, ++ PatternRewriter& rewriter) const override { + - return evalElementwise(rewriter, op, + - [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); +-- } +--}; +-- +--struct EvalOrOpPattern : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; + + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + + return failure(); + + +@@ -1664,9 +1702,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } +-+}; +-+ ++ } ++ }; ++ ++-struct EvalOrOpPattern : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +struct FoldOrOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; + +@@ -1680,16 +1720,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { + - return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); + - }); +-- } +--}; +-- +--struct EvalRemOpPattern : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; + + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldOr{}); + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } ++ } + + + + struct FoldOr { + + APInt operator()(APInt lhs, APInt rhs) const { +@@ -1699,8 +1734,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + return std::nullopt; + + } + + }; +-+}; +-+ ++ }; ++ ++-struct EvalRemOpPattern : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +struct FoldRemOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; + +@@ -1708,10 +1745,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + PatternRewriter& rewriter) const override { + - return evalElementwise(rewriter, op, + - [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); +-- } +--}; +-- +--struct EvalReshapeOpPattern : public ShapeOpRewritePattern { + + auto resultType = op.getType(); + + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + + return failure(); +@@ -1721,7 +1754,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); + + return success(); +-+ } ++ } + + + + struct FoldRem { + + FoldRem(bool isUnsignedInt) +@@ -1736,14 +1769,15 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + static APInt foldUint(APInt lhs, APInt rhs) { return lhs.urem(rhs); } + + static APInt foldSint(APInt lhs, APInt rhs) { return lhs.srem(rhs); } + + }; +-+}; +-+ ++ }; ++ ++-struct EvalReshapeOpPattern : public ShapeOpRewritePattern { + +// Pattern: reshape(cst, shape) -> cst + +struct FoldReshapeOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; + + LogicalResult matchAndRewrite(ReshapeOp op, +-@@ -656,7 +756,6 @@ ++@@ -656,7 +756,6 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern { + failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + +@@ -1751,7 +1785,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + DenseIntOrFPElementsAttr attr; + if (!matchPattern(op.getOperand(), m_Constant(&attr))) + return rewriter.notifyMatchFailure(op, "expected constant operand"); +-@@ -665,53 +764,98 @@ ++@@ -665,53 +764,98 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern { + } + }; + +@@ -1764,16 +1798,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + - if (failed(validateStaticShapeResult(rewriter, op, resultType))) +-- return failure(); +-- +++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || +++ failed(validateShapeFoldDtype(rewriter, op, resultType))) ++ return failure(); ++ + - SmallVector pred, onTrue, onFalse; + - if (failed(hlo::matchInts(op.getPred(), pred)) || + - failed(hlo::matchInts(op.getOnTrue(), onTrue)) || + - failed(hlo::matchInts(op.getOnFalse(), onFalse))) +-+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || +-+ failed(validateShapeFoldDtype(rewriter, op, resultType))) +-+ return failure(); +-+ + + DenseIntElementsAttr predAttr; + + DenseElementsAttr onTrueAttr, onFalseAttr; + + matchPattern(op.getPred(), m_Constant(&predAttr)); +@@ -1783,14 +1815,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return rewriter.notifyMatchFailure(op, "expected constant operands"); + + - SmallVector result; ++- for (auto [predEl, onTrueEl, onFalseEl] : ++- llvm::zip(pred, onTrue, onFalse)) { ++- result.push_back(predEl != 0 ? onTrueEl : onFalseEl); + + // Optimization, handle splat predicate + + if (isa(predAttr)) { + + auto pred = predAttr.getSplatValue(); + + rewriter.replaceOpWithNewOp( + + op, pred.isZero() ? onFalseAttr : onTrueAttr); + + return success(); +-+ } +-+ ++ } ++ + + // TODO: Enable float folding. + + if (op.getType().getElementType().isFloat()) + + return rewriter.notifyMatchFailure(op, "float select not supported yet"); +@@ -1800,27 +1835,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + return failure(); + + + + SmallVector result; +- for (auto [predEl, onTrueEl, onFalseEl] : +-- llvm::zip(pred, onTrue, onFalse)) { +-- result.push_back(predEl != 0 ? onTrueEl : onFalseEl); +-- } +-- +++ for (auto [predEl, onTrueEl, onFalseEl] : + + llvm::zip(predAttr.getValues(), onTrueAttr.getValues(), + + onFalseAttr.getValues())) { + + result.push_back(!predEl.isZero() ? onTrueEl : onFalseEl); + + } + rewriter.replaceOpWithNewOp( + - op, getTensorAttr(op.getType(), result)); +-- return success(); +-- } +--}; +-- +--struct EvalSignOpPattern : public FoldOpRewritePattern { +-- using FoldOpRewritePattern::FoldOpRewritePattern; + + op, DenseIntElementsAttr::get(resultType, result)); + + +-+ return success(); +-+ } ++ return success(); ++ } + + + + struct FoldSelect { + + std::optional operator()(APFloat pred, APFloat onTrue, +@@ -1832,8 +1857,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + return pred != 0 ? onTrue : onFalse; + + } + + }; +-+}; +-+ ++ }; ++ ++-struct EvalSignOpPattern : public FoldOpRewritePattern { ++- using FoldOpRewritePattern::FoldOpRewritePattern; + +struct FoldSignOpPattern : public ShapeOpRewritePattern { + + using ShapeOpRewritePattern::ShapeOpRewritePattern; + +@@ -1881,7 +1908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + }; + + template +-@@ -749,13 +893,14 @@ ++@@ -749,13 +893,14 @@ DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) { + ArrayRef(result)); + } + +@@ -1899,7 +1926,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return failure(); + + auto operand = op.getOperand(); +-@@ -784,36 +929,18 @@ ++@@ -784,36 +929,18 @@ struct EvalSliceOpPattern : public FoldOpRewritePattern { + }; + + struct FoldSubtractOpPattern final +@@ -1930,14 +1957,13 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - + -struct EvalSubtractOpPattern : public FoldOpRewritePattern { + - using FoldOpRewritePattern::FoldOpRewritePattern; +-- +++ if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) +++ return failure(); ++ + - LogicalResult matchAndRewrite(SubtractOp op, + - PatternRewriter& rewriter) const override { + - return evalElementwise(rewriter, op, + - [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); +-+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) +-+ return failure(); +-+ + + auto res = foldBinaryOpIntOrFloat(rewriter, op, std::minus<>{}); + + if (failed(res)) return failure(); + + rewriter.replaceOpWithNewOp(op, res.value()); +@@ -1945,23 +1971,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + } + }; + +-@@ -823,42 +950,38 @@ ++@@ -823,42 +950,38 @@ struct FoldSqrtOpPattern + + LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op, + PatternRewriter& rewriter) const final { + - TypedAttr lhsAttr; + - matchPattern(op.getOperand(), m_Constant(&lhsAttr)); +-- +++ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt()); +++ if (failed(res)) return failure(); +++ rewriter.replaceOpWithNewOp(op, res.value()); +++ return success(); +++ } ++ + - if (!lhsAttr) + - return rewriter.notifyMatchFailure(op, "operand not constant"); +-- +++ struct FoldSqrt { +++ std::optional operator()(APFloat operand) { +++ if (operand.getSizeInBits(operand.getSemantics()) == 64) +++ return APFloat(std::sqrt(operand.convertToDouble())); ++ + - if (auto res = constFoldUnaryOp( + - lhsAttr, foldSqrt)) { + - rewriter.replaceOpWithNewOp( + - op, op.getType(), llvm::cast(res)); + - return success(); +-- } +-- +++ if (operand.getSizeInBits(operand.getSemantics()) == 32) +++ return APFloat(sqrtf(operand.convertToFloat())); +++ return std::nullopt; ++ } ++ + - return rewriter.notifyMatchFailure(op, "unable to fold sqrt"); + - } + - +@@ -1973,32 +2011,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + - return APFloat(sqrtf(a.convertToFloat())); + - return {}; + - } +--}; +-- +--struct EvalIotaOpPattern : public FoldOpRewritePattern { +-+ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt()); +-+ if (failed(res)) return failure(); +-+ rewriter.replaceOpWithNewOp(op, res.value()); +-+ return success(); +-+ } +-+ +-+ struct FoldSqrt { +-+ std::optional operator()(APFloat operand) { +-+ if (operand.getSizeInBits(operand.getSemantics()) == 64) +-+ return APFloat(std::sqrt(operand.convertToDouble())); +-+ +-+ if (operand.getSizeInBits(operand.getSemantics()) == 32) +-+ return APFloat(sqrtf(operand.convertToFloat())); +-+ return std::nullopt; +-+ } +-+ + + // TODO: Enable int folding. + + std::optional operator()(APInt operand) { + + return std::nullopt; + + } + + }; +-+}; +-+ ++ }; ++ ++-struct EvalIotaOpPattern : public FoldOpRewritePattern { + +struct FoldIotaOpPattern : public FoldOpRewritePattern { + using FoldOpRewritePattern::FoldOpRewritePattern; + +@@ -2015,7 +2035,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + + auto elementType = resultType.getElementType(); + +-@@ -929,7 +1052,7 @@ ++@@ -929,7 +1052,7 @@ DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) { + // transpose(constant) => constant with permuted dimensions + // This covers ranked tensor types with 0 dimensions(zero elements) and 0 + // rank(scalar), as well as splat values. +@@ -2024,7 +2044,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + using FoldOpRewritePattern::FoldOpRewritePattern; + + LogicalResult matchAndRewrite(TransposeOp op, +-@@ -943,6 +1066,7 @@ ++@@ -943,6 +1066,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern { + return rewriter.notifyMatchFailure( + op, "expected constant integer or float operand"); + +@@ -2032,7 +2052,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + DenseElementsAttr resAttr; + if (auto data = els.tryGetValues()) + resAttr = transposeType(op, *data); +-@@ -957,6 +1081,7 @@ ++@@ -957,6 +1081,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern { + } + }; + +@@ -2040,7 +2060,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + struct LowerBoolSplatConstantsIntoReduceOpRegion + : public FoldOpRewritePattern { + using FoldOpRewritePattern::FoldOpRewritePattern; +-@@ -1160,7 +1285,7 @@ ++@@ -1160,7 +1285,7 @@ bool hasNoDeclaredSideEffects(Operation* op) { + return true; + } + +@@ -2049,7 +2069,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + : public FoldOpRewritePattern { + using FoldOpRewritePattern::FoldOpRewritePattern; + +-@@ -1233,23 +1358,16 @@ ++@@ -1233,23 +1358,16 @@ void populateStablehloAggressiveFolderPatterns( + PatternBenefit benefit) { + populateStablehloShapeFolderPatterns(context, patterns, options, benefit); + +@@ -2083,7 +2103,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + } + + class StablehloTargetIndependentOptimizationPass { +-@@ -1266,25 +1384,25 @@ ++@@ -1266,25 +1384,25 @@ void populateStablehloShapeFolderPatterns( + MLIRContext* context, RewritePatternSet* patterns, + const StablehloAggressiveFolderPassOptions& options, + PatternBenefit benefit) { +@@ -2128,28 +2148,29 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + } + + void populateStablehloShapeFolderPatterns(MLIRContext* context, +-diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +-+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +-@@ -48,6 +48,8 @@ +- def RankEqual : Constraint< ++diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ++index 8adb6dbc..3a8edc61 100644 ++--- a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +++++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ++@@ -49,6 +49,8 @@ def RankEqual : Constraint< + CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, + "same rank">; +-+ +-+def TensorDimsAllOne : Constraint, "all tensor dims are 1">; + +++def TensorDimsAllOne : Constraint, "all tensor dims are 1">; +++ + def TypesEqual : Constraint, "operands are equal">; + +-@@ -100,6 +102,8 @@ +- def ZeroExtent : AttrConstraint< ++ /////////// ++@@ -101,6 +103,8 @@ def ZeroExtent : AttrConstraint< + CPred<"cast($_self).getNumElements() == 0">, + "is zero extent">; +-+ +-+def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>; + +++def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>; +++ + /////////// + //// Native Code Call Utilities +-@@ -503,7 +507,7 @@ ++ ++@@ -503,7 +507,7 @@ def SelectOp_InvertBroadcastPredicateAndSwap + // Must be static shape, otherwise would require broadcasting via + // CHLO_ConstantLike. + def SubtractOp_FoldToZero +@@ -2158,4 +2179,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp + (StableHLO_ConstantLike<"0"> $operand)>; + + // Pattern: subtract(X, 0) -> X ++-- ++2.48.1 + +-- +2.48.1 + diff --git a/third_party/modules/xla/20250612.0-6e48cbb/source.json b/third_party/modules/xla/20250612.0-6e48cbb/source.json new file mode 100644 index 0000000..6a4e208 --- /dev/null +++ b/third_party/modules/xla/20250612.0-6e48cbb/source.json @@ -0,0 +1,17 @@ +{ + "strip_prefix": "xla-6e48cbb8d33d771c964697e39bfaf678bcc6de31", + "url": "https://github.com/openxla/xla/archive/6e48cbb8d33d771c964697e39bfaf678bcc6de31.tar.gz", + "integrity": "sha256-i9lYvZ2MkzfyVW2Iu3qIucXIgGEhkbwsYXCrUZ6Yze8=", + "overlay": { + "llvm.bzl": "", + "MODULE.bazel": "", + "workspace_private.bzl": "", + "xla.bzl": "" + }, + "patch_strip": 1, + "patches": { + "0001-bazel-migration-to-bazel-8.1.1.patch": "", + "0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": "", + "0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch": "" + } +} diff --git a/third_party/modules/xla/metadata.json b/third_party/modules/xla/metadata.json index eca3b43..be35763 100644 --- a/third_party/modules/xla/metadata.json +++ b/third_party/modules/xla/metadata.json @@ -21,7 +21,8 @@ "20250317.0-71c67e2", "20250317.1-71c67e2", "20250317.2-71c67e2", - "20250527.0-cb67f2f" + "20250527.0-cb67f2f", + "20250612.0-6e48cbb" ], "yanked_versions": {} }