Radix/third_party/modules/xla/20250612.0-6e48cbb/patches/0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch

1856 lines
81 KiB
Diff
Raw Normal View History

2025-02-05 17:35:27 +00:00
From 4c0819ac9fb9dfc6156ae4de83fb29e987ade780 Mon Sep 17 00:00:00 2001
From: Corentin Godeau <corentin@zml.ai>
Date: Mon, 16 Jun 2025 14:27:47 +0000
Subject: [PATCH] Update Stablehlo patch
---
third_party/stablehlo/temporary.patch | 841 +++++++++++++-------------
1 file changed, 432 insertions(+), 409 deletions(-)
mode change 100755 => 100644 third_party/stablehlo/temporary.patch
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
old mode 100755
new mode 100644
index f54e8341ba..6b25e78467
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,7 +1,47 @@
-diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td
---- stablehlo/stablehlo/dialect/VhloAttrs.td
-+++ stablehlo/stablehlo/dialect/VhloAttrs.td
-@@ -45,14 +45,6 @@
+From c48aa0650bf27b81c42799f9218dd90779e36e3b Mon Sep 17 00:00:00 2001
+From: Corentin Godeau <corentin@zml.ai>
+Date: Mon, 16 Jun 2025 14:08:48 +0000
+Subject: [PATCH] Remove default argument that is not valid C
+
+---
+ stablehlo/dialect/Version.h | 2 +-
+ stablehlo/dialect/VhloAttrs.td | 14 +-
+ stablehlo/dialect/VhloDialect.td | 1 +
+ stablehlo/dialect/VhloTypes.h | 24 +-
+ .../integrations/c/StablehloDialectApi.h | 2 +-
+ .../stablehlo_aggressive_folder.mlir | 37 +-
+ .../stablehlo_aggressive_simplification.mlir | 2 +-
+ .../transforms/stablehlo_refine_shapes.mlir | 10 +-
+ .../stablehlo_legalize_to_vhlo_mixed.mlir | 38 +-
+ .../tests/vhlo/vhlo_attributes_invalid.mlir | 13 +-
+ stablehlo/tools/StablehloTranslateMain.cpp | 8 +-
+ stablehlo/transforms/Passes.h | 7 +-
+ .../transforms/StablehloLegalizeToVhlo.cpp | 48 +-
+ .../transforms/StablehloRefineShapes.cpp | 4 +
+ .../transforms/VhloLegalizeToStablehlo.cpp | 100 ++-
+ stablehlo/transforms/VhloToVersion.cpp | 38 +-
+ .../StablehloAggressiveFolder.cpp | 816 ++++++++++--------
+ ...ablehloAggressiveSimplificationPatterns.td | 6 +-
+ 18 files changed, 696 insertions(+), 474 deletions(-)
+
+diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h
+index eb35ad52..6aaef985 100644
+--- a/stablehlo/dialect/Version.h
++++ b/stablehlo/dialect/Version.h
+@@ -38,7 +38,7 @@ class Version {
+ static FailureOr<Version> fromString(llvm::StringRef versionRef);
+
+ /// Return a Version representing the current VHLO dialect version.
+- static Version getCurrentVersion() { return Version(1, 10, 10); }
++ static Version getCurrentVersion() { return Version(1, 11, 0); }
+
+ /// Return a Version representing the minimum supported VHLO dialect version.
+ static Version getMinimumVersion() { return Version(0, 9, 0); }
+diff --git a/stablehlo/dialect/VhloAttrs.td b/stablehlo/dialect/VhloAttrs.td
+index bf75ad27..ab48630a 100644
+--- a/stablehlo/dialect/VhloAttrs.td
++++ b/stablehlo/dialect/VhloAttrs.td
+@@ -45,14 +45,6 @@ class VHLO_AttrDef<string name, string minVersion, string maxVersion>
def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1", "0.9.0", "current"> {
let mnemonic = "array_v1";
let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value);
@@ -16,7 +56,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial
let assemblyFormat = "`<` custom<AttributeArray>($value) `>`";
}
-@@ -75,9 +67,9 @@
+@@ -75,9 +67,9 @@ def VHLO_DictionaryAttrV1 : VHLO_AttrDef<"DictionaryV1", "0.9.0", "current"> {
LogicalResult DictionaryV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn,
ArrayRef<std::pair<mlir::Attribute, mlir::Attribute>> value) {
@@ -29,36 +69,48 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial
return success();
}
}];
-diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/dialect/VhloTypes.h
---- stablehlo/stablehlo/dialect/VhloTypes.h
-+++ stablehlo/stablehlo/dialect/VhloTypes.h
-@@ -27,20 +27,23 @@
+diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td
+index 4e50cb32..edd220e3 100644
+--- a/stablehlo/dialect/VhloDialect.td
++++ b/stablehlo/dialect/VhloDialect.td
+@@ -49,6 +49,7 @@ def VHLO_Dialect : Dialect {
+ 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types.
+ 1.9.0: Add `ResultAccuracy` attribute to `exp` op.
+ 1.10.0: Add `ResultAccuracy` attribute to `cbrt`, `cosine`, `exponential`, `exponential_minus_one`, `log`, `log_plus_one`, `logistic`, `rsqrt`, `sine`, `sqrt`, `tan` and `tanh` ops.
++ 1.11.0: Allow (de)serializing VHLO programs mixed with potentially unstable dialects.
+ }];
+
+ let useDefaultAttributePrinterParser = 0;
+diff --git a/stablehlo/dialect/VhloTypes.h b/stablehlo/dialect/VhloTypes.h
+index e5aee254..f3a8d68d 100644
+--- a/stablehlo/dialect/VhloTypes.h
++++ b/stablehlo/dialect/VhloTypes.h
+@@ -27,20 +27,23 @@ limitations under the License.
namespace mlir {
namespace vhlo {
-class VhloTypeConverterBase : public TypeConverter {
-- public:
-- VhloTypeConverterBase() : TypeConverter(){};
--
-- virtual ~VhloTypeConverterBase() = default;
--
-- virtual Attribute convertEncoding(Attribute attr) const = 0;
--};
-
- // This class is used to manage conversions between VHLO and Builtin
- // dialects.
--class VhloTypeConverter : public VhloTypeConverterBase {
++
++// This class is used to manage conversions between VHLO and Builtin
++// dialects.
+class VhloTypeConverter : public TypeConverter {
public:
-- VhloTypeConverter() : VhloTypeConverterBase() {}
+- VhloTypeConverterBase() : TypeConverter(){};
+ VhloTypeConverter() : TypeConverter(), allowOtherDialects(false) {}
+ VhloTypeConverter(bool allowOtherDialects)
+ : TypeConverter(), allowOtherDialects(allowOtherDialects) {}
-+
+
+- virtual ~VhloTypeConverterBase() = default;
+ virtual ~VhloTypeConverter() = default;
-+
-+ virtual Attribute convertEncoding(Attribute attr) const = 0;
-+
+
+ virtual Attribute convertEncoding(Attribute attr) const = 0;
+-};
+
+-// This class is used to manage conversions between VHLO and Builtin
+-// dialects.
+-class VhloTypeConverter : public VhloTypeConverterBase {
+- public:
+- VhloTypeConverter() : VhloTypeConverterBase() {}
+ Attribute convertUnknownAttribute(Attribute attr) const {
+ if (allowOtherDialects) return attr;
+ return {};
@@ -66,7 +118,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale
// A subclass can call this method to add conversions from VHLO -> Builtin
// types. Note that conversions are applied in reverse order, with the most
-@@ -58,6 +61,9 @@
+@@ -58,6 +61,9 @@ class VhloTypeConverter : public VhloTypeConverterBase {
// Mark unrealized casts as legal. Useful for dialect mixing.
void addUnrealizedMaterializations();
@@ -76,19 +128,34 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale
};
// Autogenerated VHLO type printers and parsers.
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h
+index 385156bf..24d11c1d 100644
+--- a/stablehlo/integrations/c/StablehloDialectApi.h
++++ b/stablehlo/integrations/c/StablehloDialectApi.h
+@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
+ MlirStringRef targetVersion,
+ MlirStringCallback callback,
+ void* userData,
+- bool allowOtherDialects = false);
++ bool allowOtherDialects);
+
+ // Read a StableHLO program from a portable artifact, returning the module as
+ // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
+diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+index 758965d0..c5239671 100644
+--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
++++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
@@ -1,4 +1,4 @@
-// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: stablehlo-opt --stablehlo-aggressive-folder=fold-op-element-limit=100 --split-input-file --verify-diagnostics %s | FileCheck %s
////////
// AddOp
-@@ -42,6 +42,21 @@
+@@ -41,6 +41,21 @@ func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>)
+
// -----
- ////////
++////////
+// ClampOp
+
+// CHECK-LABEL: func.func @clamp_fold
@@ -103,18 +170,13 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
+
+// -----
+
-+////////
+ ////////
// CompareOp
- // CHECK-LABEL: func.func @compare_folds
-@@ -98,6 +113,26 @@
- // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32>
- // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]]
- return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>
-+}
-+
-+// -----
-+
+@@ -102,6 +117,26 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>,
+
+ // -----
+
+////////
+// DivOp
+
@@ -131,32 +193,37 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
+ %1 = stablehlo.divide %cst_1, %cst_1 : tensor<ui32>
+ %2 = stablehlo.divide %cst_2, %cst_2 : tensor<f32>
+ return %0, %1, %2 : tensor<i32>, tensor<ui32>, tensor<f32>
- }
++}
++
++// -----
++
+ ////////
+ // MulOp
- // -----
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+index 4921a224..a1eed944 100644
+--- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
++++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
@@ -1,4 +1,4 @@
-// RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s
+// RUN: stablehlo-opt --stablehlo-aggressive-simplification=fold-op-element-limit=100 --allow-unregistered-dialect --split-input-file %s | FileCheck %s
/////////
// AddOp
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
---- stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
-@@ -521,16 +521,16 @@
+diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
+index b262bf09..01b5dd8f 100644
+--- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
++++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
+@@ -521,16 +521,16 @@ func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> {
// -----
// CHECK-LABEL: func @eval_slice_zerorank
-func.func @eval_slice_zerorank() -> tensor<f32> {
- // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor<f32>
-- // CHECK: return [[RESULT]]
-- %0 = stablehlo.constant dense<33.0> : tensor<f32>
+func.func @eval_slice_zerorank() -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<33> : tensor<i32>
-+ // CHECK: return [[RESULT]]
+ // CHECK: return [[RESULT]]
+- %0 = stablehlo.constant dense<33.0> : tensor<f32>
+ %0 = stablehlo.constant dense<33> : tensor<i32>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64>,
@@ -169,9 +236,10 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b
}
// -----
-diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
-+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
+diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
+index 6340898e..a5d344ac 100644
+--- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
++++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
@@ -5,15 +5,9 @@
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
@@ -190,14 +258,10 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli
// RUN: diff %t.0 %t.1
// CHECK-LABEL: vhlo.func_v1 @op_other(
-@@ -26,6 +20,34 @@
- func.func @op_other(%arg0: tensor<f32>) -> tensor<f32> {
- %0 = arith.addf %arg0, %arg0 : tensor<f32>
- return %0 : tensor<f32>
-+}
-+
-+// -----
-+
+@@ -30,6 +24,34 @@ func.func @op_other(%arg0: tensor<f32>) -> tensor<f32> {
+
+ // -----
+
+// CHECK-LABEL: vhlo.func_v1 @func_attributes(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
+// CHECK: "vhlo.return_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
@@ -222,12 +286,17 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli
+ vhlo.mixed_dict = {affine_map = affine_map<(d0) -> (d0)>, str_attr = "STR_ATTR"}
+} {
+ return %arg0 : tensor<f32>
- }
-
- // -----
-diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
---- stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
-+++ stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
++}
++
++// -----
++
+ // CHECK-LABEL: vhlo.func_v1 @op_shlo(
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
+ // CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+diff --git a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
+index b73d1005..f4330e67 100644
+--- a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
++++ b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
@@ -1,17 +1,8 @@
// RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s
@@ -248,10 +317,11 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stabl
} {
return
}
-diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
---- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
-+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
-@@ -76,6 +76,11 @@
+diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp
+index e3171b2d..fdf0d6a9 100644
+--- a/stablehlo/tools/StablehloTranslateMain.cpp
++++ b/stablehlo/tools/StablehloTranslateMain.cpp
+@@ -76,6 +76,11 @@ llvm::cl::opt<std::string> targetOption(
"target", llvm::cl::desc("Target version for serialization"),
llvm::cl::init(""));
@@ -263,7 +333,7 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st
llvm::cl::opt<std::string> argsOption(
"args", llvm::cl::desc("Arguments to pass to the interpreter"),
llvm::cl::init(""));
-@@ -317,7 +322,8 @@
+@@ -317,7 +322,8 @@ TranslateFromMLIRRegistration serializeRegistration(
return module.emitError("failed to strip debuginfo");
}
@@ -273,10 +343,11 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st
},
[](DialectRegistry &registry) {
mlir::registerAllDialects(registry);
-diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
---- stablehlo/stablehlo/transforms/Passes.h
-+++ stablehlo/stablehlo/transforms/Passes.h
-@@ -35,6 +35,7 @@
+diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h
+index 0e4406d2..a5a4aa27 100644
+--- a/stablehlo/transforms/Passes.h
++++ b/stablehlo/transforms/Passes.h
+@@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/Version.h"
@@ -284,7 +355,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans
namespace mlir {
namespace stablehlo {
-@@ -54,17 +55,17 @@
+@@ -54,17 +55,17 @@ void populateStablehloRefineShapesPatterns(MLIRContext *context,
// Populates StableHLO ops to VHLO ops rewriting patterns.
void populateStablehloToVhloPatterns(MLIRContext *context,
RewritePatternSet *patterns,
@@ -305,10 +376,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans
/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and
/// Shape ops.
-diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
-+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
-@@ -57,7 +57,7 @@
+diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
+index e768ded6..25903399 100644
+--- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
++++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
+@@ -57,7 +57,7 @@ namespace {
class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
public:
StablehloToVhloTypeConverter(bool allowOtherDialects)
@@ -317,7 +389,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
LLVM_DEBUG(
llvm::dbgs()
<< "[StablehloToVhloTypeConverter] Creating with allowOtherDialects: "
-@@ -82,7 +82,15 @@
+@@ -82,7 +82,15 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
return vhlo::TokenV1Type::get(token.getContext());
});
addBuiltinToVhloConversions();
@@ -334,7 +406,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
}
Attribute convertEncoding(Attribute attr) const final {
-@@ -114,7 +122,7 @@
+@@ -114,7 +122,7 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
return vhlo::Name##Version##Attr::get(attr.getContext(), vhloValue.value())
Attribute convertGeneric(Attribute stablehloAttr,
@@ -343,7 +415,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
LLVM_DEBUG(llvm::dbgs() << "Convert generic: " << stablehloAttr << '\n');
// Handle StableHLO attributes.
-@@ -241,6 +249,10 @@
+@@ -241,6 +249,10 @@ Attribute convertGeneric(Attribute stablehloAttr,
return vhlo::TypeV1Attr::get(attr.getContext(), vhloType);
}
@@ -354,7 +426,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
LLVM_DEBUG(llvm::dbgs() << "Failed to convert: " << stablehloAttr << '\n');
return {}; // Failed to convert attribute.
}
-@@ -268,13 +280,15 @@
+@@ -268,13 +280,15 @@ SpecialResult notSpecial() { return SpecialResult::NOT_SPECIAL; }
Attribute convertBool(const ConversionPattern& pattern, int64_t stablehloDim) {
auto stablehloType = IntegerType::get(pattern.getContext(), 1);
auto stablehloAttr = IntegerAttr::get(stablehloType, stablehloDim);
@@ -372,7 +444,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
}
Attribute convertInts(const ConversionPattern& pattern,
-@@ -282,7 +296,8 @@
+@@ -282,7 +296,8 @@ Attribute convertInts(const ConversionPattern& pattern,
auto stablehloType = RankedTensorType::get(
stablehloDims.size(), IntegerType::get(pattern.getContext(), 64));
auto stablehloAttr = DenseIntElementsAttr::get(stablehloType, stablehloDims);
@@ -382,7 +454,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
}
Attribute convertSymbol(const ConversionPattern& pattern,
-@@ -290,7 +305,7 @@
+@@ -290,7 +305,7 @@ Attribute convertSymbol(const ConversionPattern& pattern,
auto stablehloSymbolAttr = dyn_cast<FlatSymbolRefAttr>(stablehloAttr);
if (!stablehloSymbolAttr) return {};
return convertGeneric(stablehloSymbolAttr.getAttr(),
@@ -391,7 +463,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
}
SpecialResult convertChannelHandle(const ConversionPattern& pattern,
-@@ -445,15 +460,15 @@
+@@ -445,15 +460,15 @@ SpecialResult convertDotAlgorithm(const ConversionPattern& pattern,
vhloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), "lhs_precision_type"),
convertGeneric(TypeAttr::get(attr.getLhsPrecisionType()),
@@ -410,7 +482,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
// Components
auto vhloLhsComponentCount = convertInt(pattern, attr.getLhsComponentCount());
-@@ -712,7 +727,9 @@
+@@ -712,7 +727,9 @@ LogicalResult addDefaults(const OpConversionPattern<StablehloOpTy>& pattern,
auto addDefaultAttr = [&](StringRef vhloName, Attribute stablehloAttr) {
vhloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), vhloName),
@@ -421,7 +493,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
};
if constexpr (std::is_same<StablehloOpTy, func::FuncOp>::value) {
if (!stablehloOp.getSymVisibilityAttr())
-@@ -987,8 +1004,9 @@
+@@ -987,8 +1004,9 @@ class StablehloToVhloOpConverter : public OpConversionPattern<StablehloOpTy> {
case SpecialResult::SPECIAL_FAILURE:
return failure();
case SpecialResult::NOT_SPECIAL:
@@ -433,7 +505,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
if (!vhloAttr) return failure();
vhloAttrs.push_back({stablehloAttr.getName(), vhloAttr});
break;
-@@ -1075,14 +1093,14 @@
+@@ -1075,14 +1093,14 @@ struct StablehloLegalizeToVhloPass
}
private:
@@ -450,10 +522,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
populateStablehloToVhloPatterns<
#define GET_OP_LIST
#include "stablehlo/dialect/StablehloOps.cpp.inc"
-diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
-+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
-@@ -16,6 +16,7 @@
+diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp
+index 64eaa3c0..4b585b46 100644
+--- a/stablehlo/transforms/StablehloRefineShapes.cpp
++++ b/stablehlo/transforms/StablehloRefineShapes.cpp
+@@ -16,6 +16,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
@@ -461,7 +534,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl
#include <tuple>
#include <utility>
-@@ -1038,8 +1039,11 @@
+@@ -1038,8 +1039,11 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
// Populate additional patterns for StableHLO extensions.
state.addAdditionalPatterns(patterns);
@@ -473,10 +546,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl
// The folding patterns implement partial evaluation of shape computations
// which is a critical part of implementing type refinement for ops like
-diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
-+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
-@@ -57,7 +57,10 @@
+diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+index 5662ad9d..41973a0e 100644
+--- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
++++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+@@ -57,7 +57,10 @@ namespace {
class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter {
public:
@@ -488,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
addConversion([](Type type) -> Type { return type; });
addConversion([](vhlo::TokenV1Type token) -> Type {
LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n");
-@@ -90,7 +93,7 @@
+@@ -90,7 +93,7 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter {
return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value())
Attribute convertGeneric(Attribute vhloAttr,
@@ -497,18 +571,18 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
LLVM_DEBUG(llvm::dbgs() << "Converting attr " << vhloAttr);
if (auto vhloAttrs = dyn_cast<vhlo::ArrayV1Attr>(vhloAttr)) {
SmallVector<Attribute> stablehloAttrs;
-@@ -189,6 +192,10 @@
- // All VHLO attributes must have counterparts in StableHLO.
+@@ -190,6 +193,10 @@ Attribute convertGeneric(Attribute vhloAttr,
return {};
}
-+
+
+ // Fall back to type converter for unknown attributes.
+ auto unknownAttr = typeConverter->convertUnknownAttribute(vhloAttr);
+ if (unknownAttr) return unknownAttr;
-
++
// This should be unreachable unless program is a mix of VHLO and other
// dialects, e.g. due to user edits to textual assembly format.
-@@ -229,7 +236,7 @@
+ return {};
+@@ -229,7 +236,7 @@ bool isNoneType(Attribute vhloAttr) {
}
LogicalResult convertTypeAttr(Attribute vhloAttr, Type& result,
@@ -517,7 +591,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
auto stablehloAttr = convertGeneric(vhloAttr, typeConverter);
if (!stablehloAttr || !isa<TypeAttr>(stablehloAttr)) return failure();
result = cast<TypeAttr>(stablehloAttr).getValue();
-@@ -244,7 +251,7 @@
+@@ -244,7 +251,7 @@ LogicalResult convertInt(Attribute vhloAttr, int64_t& result) {
}
LogicalResult convertInts(Attribute vhloAttr,
@@ -526,7 +600,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
SmallVector<int64_t>& result) {
auto vhloTensorAttr = dyn_cast<vhlo::TensorV1Attr>(vhloAttr);
if (!vhloTensorAttr) return failure();
-@@ -256,7 +263,7 @@
+@@ -256,7 +263,7 @@ LogicalResult convertInts(Attribute vhloAttr,
}
Attribute convertSymbol(Attribute vhloAttr,
@@ -535,7 +609,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
auto vhloStringAttr = dyn_cast<vhlo::StringV1Attr>(vhloAttr);
if (!vhloStringAttr) return {};
auto stablehloStringAttr = dyn_cast_or_null<StringAttr>(
-@@ -267,7 +274,7 @@
+@@ -267,7 +274,7 @@ Attribute convertSymbol(Attribute vhloAttr,
template <typename OpType>
Attribute convertChannelHandle(OpType vhloOp,
@@ -544,7 +618,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
int64_t channelId, channelType;
if (failed(convertInt(vhloOp.getChannelId(), channelId)) ||
failed(convertInt(vhloOp.getChannelType(), channelType)))
-@@ -277,7 +284,7 @@
+@@ -277,7 +284,7 @@ Attribute convertChannelHandle(OpType vhloOp,
}
Attribute convertChannelId(Attribute vhloAttr,
@@ -553,7 +627,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
int64_t channelId;
if (failed(convertInt(vhloAttr, channelId))) return {};
return stablehlo::ChannelHandleAttr::get(vhloAttr.getContext(), channelId,
-@@ -285,8 +292,8 @@
+@@ -285,8 +292,8 @@ Attribute convertChannelId(Attribute vhloAttr,
}
template <typename OpType>
@@ -564,7 +638,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
int64_t stablehloInputBatchDimension, stablehloInputFeatureDimension;
SmallVector<int64_t> stablehloInputSpatialDimensions;
int64_t stablehloKernelInputFeatureDimension,
-@@ -323,7 +330,7 @@
+@@ -323,7 +330,7 @@ Attribute convertConvDimensionNumbers(OpType vhloOp,
}
Attribute convertCustomCallCalledComputations(
@@ -573,7 +647,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (auto vhloArrayAttr = dyn_cast<vhlo::ArrayV1Attr>(vhloAttr)) {
SmallVector<Attribute> stablehloAttrs;
for (auto vhloAttr : vhloArrayAttr.getValue()) {
-@@ -336,8 +343,8 @@
+@@ -336,8 +343,8 @@ Attribute convertCustomCallCalledComputations(
return {};
}
@@ -584,7 +658,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
Type lhsPrecisionType, rhsPrecisionType, accumulationType;
if (isNoneType(vhloOp.getLhsComponentCount())) {
// All must be nonetype
-@@ -373,8 +380,8 @@
+@@ -373,8 +380,8 @@ FailureOr<Attribute> convertDotAlgorithm(vhlo::DotGeneralOpV2 vhloOp,
numPrimitiveOperations, allowImpreciseAccumulation);
}
@@ -595,7 +669,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
SmallVector<int64_t> stablehloLhsBatchingDimensions,
stablehloRhsBatchingDimensions, stablehloLhsContractingDimensions,
stablehloRhsContractingDimensions;
-@@ -394,13 +401,13 @@
+@@ -394,13 +401,13 @@ Attribute convertDotDimensionNumbers(vhlo::DotGeneralOpV2 vhloOp,
}
Attribute convertFuncCallee(Attribute vhloAttr,
@@ -612,7 +686,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
SmallVector<int64_t> stablehloOffsetDims, stablehloCollapsedSliceDims,
stablehloOperandBatchingDims, stablehloStartIndicesBatchingDims,
stablehloStartIndexMap;
-@@ -423,8 +430,8 @@
+@@ -423,8 +430,8 @@ Attribute convertGatherDimensionNumbers(OpType vhloOp,
stablehloStartIndexMap, stablehloIndexVectorDim);
}
@@ -623,7 +697,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
SmallVector<int64_t> stablehloUpdateWindowDims, stablehloInsertedWindowDims,
stablehloInputBatchingDims, stablehloScatterIndicesBatchingDims,
stablehloScatterDimsToOperandDims;
-@@ -463,10 +470,11 @@
+@@ -463,10 +470,11 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
VhloOpTy vhloOp,
SmallVector<NamedAttribute>& vhloAttrs,
SmallVector<NamedAttribute>& stablehloAttrs) {
@@ -637,7 +711,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return failure();
stablehloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), "dimension_numbers"),
-@@ -480,7 +488,7 @@
+@@ -480,7 +488,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
if constexpr (std::is_same<VhloOpTy, vhlo::DotGeneralOpV2>::value) {
// Dot Dimension Numbers
auto stablehloDotDimAttr =
@@ -646,7 +720,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloDotDimAttr) return failure();
stablehloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), "dot_dimension_numbers"),
-@@ -489,8 +497,7 @@
+@@ -489,8 +497,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
"lhs_contracting_dimensions", "rhs_contracting_dimensions");
// Dot Algorithm
@@ -656,7 +730,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (failed(stablehloDotAlgorithmAttr)) return failure();
if (stablehloDotAlgorithmAttr.value())
stablehloAttrs.emplace_back(
-@@ -503,8 +510,7 @@
+@@ -503,8 +510,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
}
if constexpr (std::is_same<VhloOpTy, vhlo::DynamicGatherOpV2>::value ||
std::is_same<VhloOpTy, vhlo::GatherOpV2>::value) {
@@ -666,7 +740,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return failure();
stablehloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), "dimension_numbers"),
-@@ -514,8 +520,7 @@
+@@ -514,8 +520,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
"start_index_map", "index_vector_dim");
}
if constexpr (std::is_same<VhloOpTy, vhlo::ScatterOpV2>::value) {
@@ -676,7 +750,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return failure();
stablehloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), "scatter_dimension_numbers"),
-@@ -526,8 +531,7 @@
+@@ -526,8 +531,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
}
if constexpr (std::is_same<VhloOpTy, vhlo::RecvOpV1>::value ||
std::is_same<VhloOpTy, vhlo::SendOpV1>::value) {
@@ -686,7 +760,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return failure();
stablehloAttrs.emplace_back(
StringAttr::get(pattern.getContext(), "channel_handle"), stablehloAttr);
-@@ -537,7 +541,7 @@
+@@ -537,7 +541,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
}
template <typename T, typename DenseArrayAttr>
@@ -695,7 +769,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
StringAttr vhloName, Attribute vhloAttr,
SmallVector<NamedAttribute>& stablehloAttrs) {
auto tensorAttr = dyn_cast<vhlo::TensorV1Attr>(vhloAttr);
-@@ -556,15 +560,15 @@
+@@ -556,15 +560,15 @@ SpecialResult convertDenseArray(const TypeConverter* typeConverter,
}
SpecialResult convertDenseBoolArray(
@@ -715,7 +789,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
return convertDenseArray<int64_t, DenseI64ArrayAttr>(
typeConverter, vhloName, vhloAttr, stablehloAttrs);
}
-@@ -575,7 +579,8 @@
+@@ -575,7 +579,8 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
SmallVector<NamedAttribute>& stablehloAttrs) {
StringAttr stablehloName = vhloName;
Attribute stablehloAttr;
@@ -725,7 +799,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if constexpr (std::is_same<VhloOpTy, vhlo::AllGatherOpV2>::value ||
std::is_same<VhloOpTy, vhlo::AllReduceOpV2>::value ||
-@@ -585,7 +590,7 @@
+@@ -585,7 +590,7 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
std::is_same<VhloOpTy, vhlo::CollectiveBroadcastOpV1>::value) {
if (vhloName == "channel_id") {
stablehloName = StringAttr::get(pattern.getContext(), "channel_handle");
@@ -734,7 +808,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return specialFailure();
}
if (vhloName == "use_global_device_ids") {
-@@ -597,20 +602,20 @@
+@@ -597,20 +602,20 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
}
if constexpr (std::is_same<VhloOpTy, vhlo::CustomCallOpV1>::value) {
if (vhloName == "called_computations") {
@@ -759,7 +833,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return specialFailure();
}
}
-@@ -760,8 +765,8 @@
+@@ -760,8 +765,8 @@ bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) {
template <typename T>
bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr,
T splatValue) {
@@ -770,7 +844,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
return attr && attr.isSplat() &&
attr.template getSplatValue<T>() == splatValue;
}
-@@ -977,8 +982,9 @@
+@@ -977,8 +982,9 @@ class VhloToStablehloOpConverter : public OpConversionPattern<VhloOpTy> {
case SpecialResult::SPECIAL_FAILURE:
return failure();
case SpecialResult::NOT_SPECIAL:
@@ -782,7 +856,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
if (!stablehloAttr) return failure();
stablehloAttrs.push_back({vhloAttr.getName(), stablehloAttr});
break;
-@@ -1056,7 +1062,7 @@
+@@ -1056,7 +1062,7 @@ struct ReconcileUnrealizedConversionCasts
template <typename... StablehloOpTypes>
void populateVhloToStablehloPatterns(MLIRContext* context,
RewritePatternSet* patterns,
@@ -791,7 +865,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
patterns
->add<VhloToStablehloOpConverter<StablehloToVhloOp<StablehloOpTypes>>...>(
*converter, context);
-@@ -1104,7 +1110,7 @@
+@@ -1104,7 +1110,7 @@ struct VhloLegalizeToStablehloPass
void populateVhloToStablehloPatterns(MLIRContext* context,
RewritePatternSet* patterns,
@@ -800,10 +874,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
populateVhloToStablehloPatterns<
#define GET_OP_LIST
#include "stablehlo/dialect/StablehloOps.cpp.inc"
-diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp
---- stablehlo/stablehlo/transforms/VhloToVersion.cpp
-+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp
-@@ -56,15 +56,23 @@
+diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp
+index dfbf2877..7818e601 100644
+--- a/stablehlo/transforms/VhloToVersion.cpp
++++ b/stablehlo/transforms/VhloToVersion.cpp
+@@ -56,15 +56,23 @@ namespace {
// Currently there are no type-to-version conversions so this class
// simply validates that all types are from the VHLO dialect.
@@ -833,7 +908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
};
// Check user-specified target version. Emit error if invalid.
-@@ -111,6 +119,10 @@
+@@ -111,6 +119,10 @@ bool isLegalVersion(VersionedInterface& interface, const Version& target) {
LogicalResult isLegalType(Type type, const Version& targetVersion);
LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) {
@@ -844,7 +919,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
auto attrInterface = dyn_cast<VersionedAttrInterface>(attr);
if (!attrInterface || !isLegalVersion(attrInterface, targetVersion)) {
LLVM_DEBUG(llvm::dbgs() << "failed to legalize attribute " << attr
-@@ -119,10 +131,11 @@
+@@ -119,10 +131,11 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) {
}
// Recursively check attrs if VHLO attr is a container
@@ -857,7 +932,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
if (auto arrAttr = dyn_cast<DictionaryV1Attr>(attr)) {
return success(llvm::all_of(
arrAttr.getValue(), [&](std::pair<Attribute, Attribute> entry) {
-@@ -146,6 +159,10 @@
+@@ -146,6 +159,10 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) {
}
LogicalResult isLegalType(Type type, const Version& targetVersion) {
@@ -868,7 +943,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
// All valid VHLO types must have versioned type interface.
auto typeInterface = dyn_cast<VersionedTypeInterface>(type);
if (!typeInterface || !isLegalVersion(typeInterface, targetVersion)) {
-@@ -170,10 +187,11 @@
+@@ -170,10 +187,11 @@ LogicalResult isLegalType(Type type, const Version& targetVersion) {
return failure();
return isLegalType(ranked.getElementType(), targetVersion);
}
@@ -881,7 +956,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
if (auto quant = dyn_cast<UniformQuantizedV1Type>(type))
return success(
succeeded(isLegalType(quant.getStorageType(), targetVersion)) &&
-@@ -213,7 +231,7 @@
+@@ -213,7 +231,7 @@ bool isLegalLocation(Location loc, const Version& targetVersion) {
bool isLegalOperation(Operation* op, const Version& targetVersion) {
// Validate op
auto opInterface = dyn_cast<VersionedOpInterface>(op);
@@ -890,7 +965,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
if (!isLegalVersion(opInterface, targetVersion)) return false;
LLVM_DEBUG(llvm::dbgs() << "Legal op version for target. " << op << '\n');
-@@ -454,7 +472,7 @@
+@@ -454,7 +472,7 @@ struct AllReduceOpV2ToV1 : public OpRewritePattern<AllReduceOpV2> {
namespace stablehlo {
void populateVhloToVersionPatterns(MLIRContext* context,
RewritePatternSet* patterns,
@@ -899,10 +974,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
vhlo::populateWithGenerated(*patterns);
patterns->add<vhlo::ScatterOpV1ToV2, vhlo::ScatterOpV2ToV1>(context);
patterns->add<vhlo::AllReduceOpV1ToV2, vhlo::AllReduceOpV2ToV1>(context);
-diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
-+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
-@@ -19,6 +19,7 @@
+diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
+index 9cfac0ba..09bfb783 100644
+--- a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
++++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
+@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <numeric>
#include <optional>
@@ -910,7 +986,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
#include <utility>
#include "llvm/ADT/APInt.h"
-@@ -47,6 +48,7 @@
+@@ -47,6 +48,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
@@ -918,7 +994,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-@@ -98,29 +100,88 @@
+@@ -98,29 +100,88 @@ LogicalResult validateStaticShapeResult(PatternRewriter& rewriter,
return success();
}
@@ -982,8 +1058,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
if (res) return cast<TypedAttr>(res);
return nullptr;
-+}
-+
+ }
+
+
+/// Binary constant folder that used a generic folder function to handle both
+/// ints and floats.
@@ -1005,8 +1081,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (!res) return rewriter.notifyMatchFailure(op, "folding failed");
+
+ return res;
- }
-
++}
++
template <class AttrElementT, class TargetAttrElementT, class CalculationT,
typename OpType>
-LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op,
@@ -1014,7 +1090,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
DenseIntOrFPElementsAttr elements, Type resType,
CalculationT&& calculate) {
auto result = constFoldCastOp<AttrElementT, TargetAttrElementT,
-@@ -128,18 +189,19 @@
+@@ -128,18 +189,19 @@ LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op,
typename TargetAttrElementT::ValueType, void>(
elements, resType, calculate);
@@ -1036,7 +1112,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
DenseIntOrFPElementsAttr elements,
RankedTensorType resultType) {
auto oldType = getElementTypeOrSelf(elements);
-@@ -153,7 +215,7 @@
+@@ -153,7 +215,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
if (auto newFloatType = dyn_cast<FloatType>(newType)) {
// Float -> Float
const auto& targetSemantics = newFloatType.getFloatSemantics();
@@ -1045,7 +1121,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
rewriter, op, elements, resultType,
[&targetSemantics](const APFloat& operand, bool& castStatus) {
bool losesInfo;
-@@ -167,7 +229,7 @@
+@@ -167,7 +229,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
}
// Float -> Int
@@ -1054,7 +1130,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
rewriter, op, elements, resultType,
[&newBitWidth, &isNewTypeUnsigned](const APFloat& operand,
bool& castStatus) {
-@@ -186,7 +248,7 @@
+@@ -186,7 +248,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
if (auto newFloatType = dyn_cast<FloatType>(newType)) {
// Int -> Float
@@ -1063,7 +1139,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
rewriter, op, elements, resultType,
[&newFloatType, &isOldTypeUnsigned](const APInt& operand,
bool& /*castStatus*/) {
-@@ -199,64 +261,12 @@
+@@ -199,7 +261,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
}
// Int -> Int
@@ -1072,10 +1148,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
rewriter, op, elements, resultType,
[&newBitWidth, &isOldTypeUnsigned](const APInt& operand,
bool& /*castStatus*/) {
- return APSInt(operand, isOldTypeUnsigned).extOrTrunc(newBitWidth);
+@@ -207,58 +269,6 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
});
--}
--
+ }
+
-// The patterns below implement partial evaluation of shape computations which
-// is a critical part of implementing type refinement for ops like
-// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
@@ -1126,10 +1202,12 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- rewriter.replaceOpWithNewOp<ConstantOp>(op,
- getTensorAttr(resultType, result));
- return success();
- }
-
+-}
+-
template <typename OpType>
-@@ -275,29 +285,18 @@
+ struct FoldOpRewritePattern : OpRewritePattern<OpType> {
+ FoldOpRewritePattern(MLIRContext* context,
+@@ -275,29 +285,18 @@ struct FoldOpRewritePattern : OpRewritePattern<OpType> {
ArrayRef<StringRef> generatedNames = {}) = delete;
const StablehloAggressiveFolderPassOptions& options;
@@ -1154,9 +1232,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
- return success();
- }
--
+
- return failure();
-+
+ LogicalResult validateElementCountForFold(PatternRewriter& rewriter,
+ Operation* op,
+ ShapedType resultType) const {
@@ -1171,31 +1248,21 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
}
};
-@@ -318,100 +317,102 @@
+@@ -318,100 +317,102 @@ struct ShapeOpRewritePattern : public FoldOpRewritePattern<OpType> {
}
};
-struct EvalAddOpShapePattern : public FoldOpRewritePattern<AddOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
--
-- LogicalResult matchAndRewrite(AddOp op,
-- PatternRewriter& rewriter) const override {
-- return evalElementwise(rewriter, op,
-- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; });
-- }
--};
--
--struct EvalAndOpPattern : public FoldOpRewritePattern<AndOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
--
-- LogicalResult matchAndRewrite(AndOp op,
-- PatternRewriter& rewriter) const override {
+struct FoldAddOpPattern final
+ : public ShapeOpRewritePattern<mlir::stablehlo::AddOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
-+
+
+- LogicalResult matchAndRewrite(AddOp op,
+ LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op,
-+ PatternRewriter& rewriter) const override {
+ PatternRewriter& rewriter) const override {
+- return evalElementwise(rewriter, op,
+- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; });
+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
+ return failure();
+
@@ -1203,14 +1270,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
-+};
-+
+ }
+ };
+
+-struct EvalAndOpPattern : public FoldOpRewritePattern<AndOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+struct FoldAndOpPattern : public ShapeOpRewritePattern<AndOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
-+
+
+- LogicalResult matchAndRewrite(AndOp op,
+ LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op,
-+ PatternRewriter& rewriter) const override {
+ PatternRewriter& rewriter) const override {
+ // TODO: Support more int types
auto resultType = op.getType();
if (!resultType.getElementType().isInteger(1))
@@ -1219,24 +1289,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) {
- return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0);
- });
-- }
+ auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldAnd{});
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
-+
-+ struct FoldAnd {
-+ APInt operator()(APInt lhs, APInt rhs) const {
-+ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero());
-+ }
-+ std::optional<APFloat> operator()(APFloat lhs, APFloat rhs) const {
-+ return std::nullopt;
-+ }
-+ };
- };
+ }
+-};
- // Pattern: broadcast_in_dim(splat, _) -> constant(splat)
+-// Pattern: broadcast_in_dim(splat, _) -> constant(splat)
-struct FoldBroadcastInDimSplatPattern final
- : FoldOpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
@@ -1251,14 +1311,22 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- op, SplatElementsAttr::get(op.getType(),
- cstAttr.getSplatValue<Attribute>()));
- return success();
-- }
++ struct FoldAnd {
++ APInt operator()(APInt lhs, APInt rhs) const {
++ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero());
+ }
- return failure();
- }
--};
--
++ std::optional<APFloat> operator()(APFloat lhs, APFloat rhs) const {
++ return std::nullopt;
++ }
++ };
+ };
+
-struct EvalBroadcastInDimOpPattern
- : public FoldOpRewritePattern<BroadcastInDimOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
++// Pattern: broadcast_in_dim(splat, _) -> constant(splat)
+struct FoldBroadcastInDimOpSplatPattern
+ : public ShapeOpRewritePattern<BroadcastInDimOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1267,8 +1335,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (failed(validateStaticShapeResult(rewriter, op, resultType)))
-- return failure();
--
++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
++ failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
+
- auto operandType = op.getOperand().getType();
- if (operandType.getRank() != 0)
- return rewriter.notifyMatchFailure(op, "expected 0-dimensional type");
@@ -1277,52 +1347,34 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- if (failed(hlo::matchInts(op.getOperand(), operand)))
- return rewriter.notifyMatchFailure(op, "expected constant operands");
- auto scalar = operand[0];
--
++ SplatElementsAttr cstAttr;
++ matchPattern(op.getOperand(), m_Constant(&cstAttr));
++ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat");
+
- rewriter.replaceOpWithNewOp<ConstantOp>(
- op, getTensorAttr(op.getType(), scalar));
-- return success();
-- }
--};
--
++ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
++ op, SplatElementsAttr::get(op.getType(),
++ cstAttr.getSplatValue<Attribute>()));
+ return success();
+ }
+ };
+
-struct EvalClampOpPattern : public FoldOpRewritePattern<ClampOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
--
++struct FoldCompareOpPattern : public ShapeOpRewritePattern<CompareOp> {
++ using ShapeOpRewritePattern::ShapeOpRewritePattern;
+
- LogicalResult matchAndRewrite(ClampOp op,
-- PatternRewriter& rewriter) const override {
++ LogicalResult matchAndRewrite(CompareOp op,
+ PatternRewriter& rewriter) const override {
- return evalElementwise(rewriter, op,
- [&](APSInt min, APSInt operand, APSInt max) {
- if (operand < min) return min;
- if (max < operand) return max;
- return operand;
- });
-- }
--};
--
--struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
-+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
-+ failed(validateShapeFoldDtype(rewriter, op, resultType)))
-+ return failure();
-+
-+ SplatElementsAttr cstAttr;
-+ matchPattern(op.getOperand(), m_Constant(&cstAttr));
-+ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat");
-+
-+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
-+ op, SplatElementsAttr::get(op.getType(),
-+ cstAttr.getSplatValue<Attribute>()));
-+ return success();
-+ }
-+};
-+
-+struct FoldCompareOpPattern : public ShapeOpRewritePattern<CompareOp> {
-+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
-
- LogicalResult matchAndRewrite(CompareOp op,
- PatternRewriter& rewriter) const override {
- auto resultType = op.getType();
-- auto kind = op.getCompareType();
-- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) {
++ auto resultType = op.getType();
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
+
@@ -1332,15 +1384,23 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
-+
+ }
+-};
+
+-struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+ struct FoldCompare {
+ FoldCompare(ComparisonDirection direction,
+ std::optional<ComparisonType> kind)
+ : direction(direction), kind(kind) {}
+ ComparisonDirection direction;
+ std::optional<ComparisonType> kind;
-+
+
+- LogicalResult matchAndRewrite(CompareOp op,
+- PatternRewriter& rewriter) const override {
+- auto resultType = op.getType();
+- auto kind = op.getCompareType();
+- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) {
+ // TODO: Enable float folding.
+ std::optional<APFloat> operator()(APFloat lhs, APFloat rhs) {
+ return std::nullopt;
@@ -1352,7 +1412,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
case ComparisonDirection::EQ:
result = lhs == rhs;
break;
-@@ -431,9 +432,9 @@
+@@ -431,9 +432,9 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs);
break;
}
@@ -1365,7 +1425,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
};
//////////////////////////////////
-@@ -441,16 +442,15 @@
+@@ -441,16 +442,15 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
/////////////////////////////////
struct FoldConcatenateOpPattern final
@@ -1387,7 +1447,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
return failure();
// Fold concatenate when all inputs are constants.
-@@ -466,6 +466,7 @@
+@@ -466,6 +466,7 @@ struct FoldConcatenateOpPattern final
int64_t{1}, std::multiplies<>{});
SmallVector<Attribute> newElems;
@@ -1395,7 +1455,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
newElems.reserve(numElems);
for (int64_t i = 0; i != topSize; ++i) {
-@@ -485,31 +486,7 @@
+@@ -485,31 +486,7 @@ struct FoldConcatenateOpPattern final
int64_t foldOpElementLimit;
};
@@ -1428,20 +1488,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
using ShapeOpRewritePattern::ShapeOpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
-@@ -532,28 +509,50 @@
+@@ -532,28 +509,50 @@ struct EvalConvertOpPattern : public ShapeOpRewritePattern<ConvertOp> {
return rewriter.notifyMatchFailure(
op, "expected constant integer or float operand");
- return evalConvert(rewriter, op, elements, resultType);
-- }
--};
--
++ return foldConvert(rewriter, op, elements, resultType);
+ }
+ };
+
-struct EvalDivOpPattern : public FoldOpRewritePattern<DivOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
-+ return foldConvert(rewriter, op, elements, resultType);
-+ }
-+};
-+
+struct FoldDivOpPattern : public ShapeOpRewritePattern<DivOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1449,12 +1506,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
PatternRewriter& rewriter) const override {
- return evalElementwise(rewriter, op,
- [&](APSInt lhs, APSInt rhs) { return lhs / rhs; });
-- }
--};
--
--struct EvalGetDimensionSizeOpPattern
-- : public FoldOpRewritePattern<GetDimensionSizeOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
+ auto resultType = op.getType();
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
@@ -1464,7 +1515,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
+ }
+
+ struct FoldDivide {
+ FoldDivide(bool isUnsignedInt)
@@ -1479,8 +1530,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ static APInt foldUint(APInt lhs, APInt rhs) { return lhs.udiv(rhs); }
+ static APInt foldSint(APInt lhs, APInt rhs) { return lhs.sdiv(rhs); }
+ };
-+};
-+
+ };
+
+-struct EvalGetDimensionSizeOpPattern
+- : public FoldOpRewritePattern<GetDimensionSizeOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+struct FoldGetDimensionSizeOpPattern
+ : public ShapeOpRewritePattern<GetDimensionSizeOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1494,7 +1548,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
return failure();
auto operandType = op.getOperand().getType();
-@@ -567,86 +566,187 @@
+@@ -567,86 +566,187 @@ struct EvalGetDimensionSizeOpPattern
}
};
@@ -1549,11 +1603,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) {
- return lhs >= rhs ? lhs : rhs;
- });
-- }
--};
--
--struct EvalMinOpPattern : public FoldOpRewritePattern<MinOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
+ auto resultType = op.getType();
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
@@ -1563,9 +1612,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
-+};
-+
+ }
+ };
+
+-struct EvalMinOpPattern : public FoldOpRewritePattern<MinOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+struct FoldMinOpPattern : public ShapeOpRewritePattern<MinOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1574,11 +1625,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) {
- return lhs <= rhs ? lhs : rhs;
- });
-- }
--};
--
--struct FoldMulOpPattern final : FoldOpRewritePattern<mlir::stablehlo::MulOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
+ auto resultType = op.getType();
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
@@ -1588,19 +1634,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
-+};
-+
+ }
+ };
+
+-struct FoldMulOpPattern final : FoldOpRewritePattern<mlir::stablehlo::MulOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+// Clamp is folded using Min and Max folders.
+struct FoldClampOpPattern : public ShapeOpRewritePattern<ClampOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
-+
+
+- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
+ LogicalResult matchAndRewrite(ClampOp op,
-+ PatternRewriter& rewriter) const override {
+ PatternRewriter& rewriter) const override {
+- TypedAttr lhsAttr;
+- matchPattern(op.getLhs(), m_Constant(&lhsAttr));
+-
+- TypedAttr rhsAttr;
+- matchPattern(op.getRhs(), m_Constant(&rhsAttr));
+-
+- if (TypedAttr res;
+- lhsAttr && rhsAttr &&
+- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) {
+- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
+- return success();
+- }
+ auto resultType = op.getType();
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
-+
+
+- return failure();
+ TypedAttr minAttr, operandAttr, maxAttr;
+ matchPattern(op.getMin(), m_Constant(&minAttr));
+ matchPattern(op.getOperand(), m_Constant(&operandAttr));
@@ -1620,43 +1682,19 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (!res) return rewriter.notifyMatchFailure(op, "failed to fold clamp");
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
+ return success();
-+ }
-+};
-+
-+struct FoldMulOpPattern final : ShapeOpRewritePattern<mlir::stablehlo::MulOp> {
-+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
+ }
+ };
- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
- PatternRewriter& rewriter) const override {
-- TypedAttr lhsAttr;
-- matchPattern(op.getLhs(), m_Constant(&lhsAttr));
--
-- TypedAttr rhsAttr;
-- matchPattern(op.getRhs(), m_Constant(&rhsAttr));
--
-- if (TypedAttr res;
-- lhsAttr && rhsAttr &&
-- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) {
-- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
-- return success();
-- }
--
-- return failure();
-- }
--};
--
-struct EvalMulOpPattern : public FoldOpRewritePattern<MulOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
--
++struct FoldMulOpPattern final : ShapeOpRewritePattern<mlir::stablehlo::MulOp> {
++ using ShapeOpRewritePattern::ShapeOpRewritePattern;
+
- LogicalResult matchAndRewrite(MulOp op,
-- PatternRewriter& rewriter) const override {
++ LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
+ PatternRewriter& rewriter) const override {
- return evalElementwise(rewriter, op,
- [&](APSInt lhs, APSInt rhs) { return lhs * rhs; });
-- }
--};
--
--struct EvalOrOpPattern : public FoldOpRewritePattern<OrOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
+ return failure();
+
@@ -1664,9 +1702,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
-+};
-+
+ }
+ };
+
+-struct EvalOrOpPattern : public FoldOpRewritePattern<OrOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+struct FoldOrOpPattern : public ShapeOpRewritePattern<OrOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1680,16 +1720,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) {
- return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0);
- });
-- }
--};
--
--struct EvalRemOpPattern : public FoldOpRewritePattern<RemOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
+ auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldOr{});
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
+ }
+
+ struct FoldOr {
+ APInt operator()(APInt lhs, APInt rhs) const {
@@ -1699,8 +1734,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ return std::nullopt;
+ }
+ };
-+};
-+
+ };
+
+-struct EvalRemOpPattern : public FoldOpRewritePattern<RemOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+struct FoldRemOpPattern : public ShapeOpRewritePattern<RemOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1708,10 +1745,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
PatternRewriter& rewriter) const override {
- return evalElementwise(rewriter, op,
- [&](APSInt lhs, APSInt rhs) { return lhs % rhs; });
-- }
--};
--
--struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
+ auto resultType = op.getType();
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
@@ -1721,7 +1754,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
+ return success();
-+ }
+ }
+
+ struct FoldRem {
+ FoldRem(bool isUnsignedInt)
@@ -1736,14 +1769,15 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ static APInt foldUint(APInt lhs, APInt rhs) { return lhs.urem(rhs); }
+ static APInt foldSint(APInt lhs, APInt rhs) { return lhs.srem(rhs); }
+ };
-+};
-+
+ };
+
+-struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
+// Pattern: reshape(cst, shape) -> cst
+struct FoldReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
using ShapeOpRewritePattern::ShapeOpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
-@@ -656,7 +756,6 @@
+@@ -656,7 +756,6 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
failed(validateShapeFoldDtype(rewriter, op, resultType)))
return failure();
@@ -1751,7 +1785,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
DenseIntOrFPElementsAttr attr;
if (!matchPattern(op.getOperand(), m_Constant(&attr)))
return rewriter.notifyMatchFailure(op, "expected constant operand");
-@@ -665,53 +764,98 @@
+@@ -665,53 +764,98 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
}
};
@@ -1764,16 +1798,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (failed(validateStaticShapeResult(rewriter, op, resultType)))
-- return failure();
--
++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
++ failed(validateShapeFoldDtype(rewriter, op, resultType)))
+ return failure();
+
- SmallVector<APSInt> pred, onTrue, onFalse;
- if (failed(hlo::matchInts(op.getPred(), pred)) ||
- failed(hlo::matchInts(op.getOnTrue(), onTrue)) ||
- failed(hlo::matchInts(op.getOnFalse(), onFalse)))
-+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
-+ failed(validateShapeFoldDtype(rewriter, op, resultType)))
-+ return failure();
-+
+ DenseIntElementsAttr predAttr;
+ DenseElementsAttr onTrueAttr, onFalseAttr;
+ matchPattern(op.getPred(), m_Constant(&predAttr));
@@ -1783,14 +1815,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
return rewriter.notifyMatchFailure(op, "expected constant operands");
- SmallVector<APSInt> result;
+- for (auto [predEl, onTrueEl, onFalseEl] :
+- llvm::zip(pred, onTrue, onFalse)) {
+- result.push_back(predEl != 0 ? onTrueEl : onFalseEl);
+ // Optimization, handle splat predicate
+ if (isa<SplatElementsAttr>(predAttr)) {
+ auto pred = predAttr.getSplatValue<APInt>();
+ rewriter.replaceOpWithNewOp<ConstantOp>(
+ op, pred.isZero() ? onFalseAttr : onTrueAttr);
+ return success();
-+ }
-+
+ }
+
+ // TODO: Enable float folding.
+ if (op.getType().getElementType().isFloat())
+ return rewriter.notifyMatchFailure(op, "float select not supported yet");
@@ -1800,27 +1835,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ return failure();
+
+ SmallVector<APInt> result;
- for (auto [predEl, onTrueEl, onFalseEl] :
-- llvm::zip(pred, onTrue, onFalse)) {
-- result.push_back(predEl != 0 ? onTrueEl : onFalseEl);
-- }
--
++ for (auto [predEl, onTrueEl, onFalseEl] :
+ llvm::zip(predAttr.getValues<APInt>(), onTrueAttr.getValues<APInt>(),
+ onFalseAttr.getValues<APInt>())) {
+ result.push_back(!predEl.isZero() ? onTrueEl : onFalseEl);
+ }
rewriter.replaceOpWithNewOp<ConstantOp>(
- op, getTensorAttr(op.getType(), result));
-- return success();
-- }
--};
--
--struct EvalSignOpPattern : public FoldOpRewritePattern<SignOp> {
-- using FoldOpRewritePattern::FoldOpRewritePattern;
+ op, DenseIntElementsAttr::get(resultType, result));
+
-+ return success();
-+ }
+ return success();
+ }
+
+ struct FoldSelect {
+ std::optional<APFloat> operator()(APFloat pred, APFloat onTrue,
@@ -1832,8 +1857,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
+ return pred != 0 ? onTrue : onFalse;
+ }
+ };
-+};
-+
+ };
+
+-struct EvalSignOpPattern : public FoldOpRewritePattern<SignOp> {
+- using FoldOpRewritePattern::FoldOpRewritePattern;
+struct FoldSignOpPattern : public ShapeOpRewritePattern<SignOp> {
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
@@ -1881,7 +1908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
};
template <typename RangeType>
-@@ -749,13 +893,14 @@
+@@ -749,13 +893,14 @@ DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) {
ArrayRef<ElementType>(result));
}
@@ -1899,7 +1926,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
return failure();
auto operand = op.getOperand();
-@@ -784,36 +929,18 @@
+@@ -784,36 +929,18 @@ struct EvalSliceOpPattern : public FoldOpRewritePattern<SliceOp> {
};
struct FoldSubtractOpPattern final
@@ -1930,14 +1957,13 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
-
-struct EvalSubtractOpPattern : public FoldOpRewritePattern<SubtractOp> {
- using FoldOpRewritePattern::FoldOpRewritePattern;
--
++ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
++ return failure();
+
- LogicalResult matchAndRewrite(SubtractOp op,
- PatternRewriter& rewriter) const override {
- return evalElementwise(rewriter, op,
- [&](APSInt lhs, APSInt rhs) { return lhs - rhs; });
-+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
-+ return failure();
-+
+ auto res = foldBinaryOpIntOrFloat(rewriter, op, std::minus<>{});
+ if (failed(res)) return failure();
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
@@ -1945,23 +1971,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
}
};
-@@ -823,42 +950,38 @@
+@@ -823,42 +950,38 @@ struct FoldSqrtOpPattern
LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op,
PatternRewriter& rewriter) const final {
- TypedAttr lhsAttr;
- matchPattern(op.getOperand(), m_Constant(&lhsAttr));
--
++ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt());
++ if (failed(res)) return failure();
++ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
++ return success();
++ }
+
- if (!lhsAttr)
- return rewriter.notifyMatchFailure(op, "operand not constant");
--
++ struct FoldSqrt {
++ std::optional<APFloat> operator()(APFloat operand) {
++ if (operand.getSizeInBits(operand.getSemantics()) == 64)
++ return APFloat(std::sqrt(operand.convertToDouble()));
+
- if (auto res = constFoldUnaryOp<FloatAttr, FloatAttr::ValueType, void>(
- lhsAttr, foldSqrt)) {
- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
- op, op.getType(), llvm::cast<ElementsAttr>(res));
- return success();
-- }
--
++ if (operand.getSizeInBits(operand.getSemantics()) == 32)
++ return APFloat(sqrtf(operand.convertToFloat()));
++ return std::nullopt;
+ }
+
- return rewriter.notifyMatchFailure(op, "unable to fold sqrt");
- }
-
@@ -1973,32 +2011,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
- return APFloat(sqrtf(a.convertToFloat()));
- return {};
- }
--};
--
--struct EvalIotaOpPattern : public FoldOpRewritePattern<IotaOp> {
-+ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt());
-+ if (failed(res)) return failure();
-+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
-+ return success();
-+ }
-+
-+ struct FoldSqrt {
-+ std::optional<APFloat> operator()(APFloat operand) {
-+ if (operand.getSizeInBits(operand.getSemantics()) == 64)
-+ return APFloat(std::sqrt(operand.convertToDouble()));
-+
-+ if (operand.getSizeInBits(operand.getSemantics()) == 32)
-+ return APFloat(sqrtf(operand.convertToFloat()));
-+ return std::nullopt;
-+ }
-+
+ // TODO: Enable int folding.
+ std::optional<APInt> operator()(APInt operand) {
+ return std::nullopt;
+ }
+ };
-+};
-+
+ };
+
+-struct EvalIotaOpPattern : public FoldOpRewritePattern<IotaOp> {
+struct FoldIotaOpPattern : public FoldOpRewritePattern<IotaOp> {
using FoldOpRewritePattern::FoldOpRewritePattern;
@@ -2015,7 +2035,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
auto elementType = resultType.getElementType();
-@@ -929,7 +1052,7 @@
+@@ -929,7 +1052,7 @@ DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) {
// transpose(constant) => constant with permuted dimensions
// This covers ranked tensor types with 0 dimensions(zero elements) and 0
// rank(scalar), as well as splat values.
@@ -2024,7 +2044,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
using FoldOpRewritePattern::FoldOpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp op,
-@@ -943,6 +1066,7 @@
+@@ -943,6 +1066,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern<TransposeOp> {
return rewriter.notifyMatchFailure(
op, "expected constant integer or float operand");
@@ -2032,7 +2052,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
DenseElementsAttr resAttr;
if (auto data = els.tryGetValues<APInt>())
resAttr = transposeType(op, *data);
-@@ -957,6 +1081,7 @@
+@@ -957,6 +1081,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern<TransposeOp> {
}
};
@@ -2040,7 +2060,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
struct LowerBoolSplatConstantsIntoReduceOpRegion
: public FoldOpRewritePattern<ReduceOp> {
using FoldOpRewritePattern::FoldOpRewritePattern;
-@@ -1160,7 +1285,7 @@
+@@ -1160,7 +1285,7 @@ bool hasNoDeclaredSideEffects(Operation* op) {
return true;
}
@@ -2049,7 +2069,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
: public FoldOpRewritePattern<WhileOp> {
using FoldOpRewritePattern::FoldOpRewritePattern;
-@@ -1233,23 +1358,16 @@
+@@ -1233,23 +1358,16 @@ void populateStablehloAggressiveFolderPatterns(
PatternBenefit benefit) {
populateStablehloShapeFolderPatterns(context, patterns, options, benefit);
@@ -2083,7 +2103,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
}
class StablehloTargetIndependentOptimizationPass {
-@@ -1266,25 +1384,25 @@
+@@ -1266,25 +1384,25 @@ void populateStablehloShapeFolderPatterns(
MLIRContext* context, RewritePatternSet* patterns,
const StablehloAggressiveFolderPassOptions& options,
PatternBenefit benefit) {
@@ -2128,28 +2148,29 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
}
void populateStablehloShapeFolderPatterns(MLIRContext* context,
-diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
-+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
-@@ -48,6 +48,8 @@
- def RankEqual : Constraint<
+diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
+index 8adb6dbc..3a8edc61 100644
+--- a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
++++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
+@@ -49,6 +49,8 @@ def RankEqual : Constraint<
CPred<"llvm::cast<ShapedType>($0.getType()).getRank() == llvm::cast<ShapedType>($1.getType()).getRank()">,
"same rank">;
-+
-+def TensorDimsAllOne : Constraint<CPred<"tensorDimsAllOne($0, $1)">, "all tensor dims are 1">;
++def TensorDimsAllOne : Constraint<CPred<"tensorDimsAllOne($0, $1)">, "all tensor dims are 1">;
++
def TypesEqual : Constraint<CPred<"$0.getType() == $1.getType()">, "operands are equal">;
-@@ -100,6 +102,8 @@
- def ZeroExtent : AttrConstraint<
+ ///////////
+@@ -101,6 +103,8 @@ def ZeroExtent : AttrConstraint<
CPred<"cast<DenseElementsAttr>($_self).getNumElements() == 0">,
"is zero extent">;
-+
-+def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>;
++def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>;
++
///////////
//// Native Code Call Utilities
-@@ -503,7 +507,7 @@
+
+@@ -503,7 +507,7 @@ def SelectOp_InvertBroadcastPredicateAndSwap
// Must be static shape, otherwise would require broadcasting via
// CHLO_ConstantLike.
def SubtractOp_FoldToZero
@@ -2158,4 +2179,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
(StableHLO_ConstantLike<"0"> $operand)>;
// Pattern: subtract(X, 0) -> X
+--
+2.48.1
--
2.48.1