From 4c0819ac9fb9dfc6156ae4de83fb29e987ade780 Mon Sep 17 00:00:00 2001 From: Corentin Godeau Date: Mon, 16 Jun 2025 14:27:47 +0000 Subject: [PATCH] Update Stablehlo patch --- third_party/stablehlo/temporary.patch | 841 +++++++++++++------------- 1 file changed, 432 insertions(+), 409 deletions(-) mode change 100755 => 100644 third_party/stablehlo/temporary.patch diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch old mode 100755 new mode 100644 index f54e8341ba..6b25e78467 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,7 +1,47 @@ -diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td ---- stablehlo/stablehlo/dialect/VhloAttrs.td -+++ stablehlo/stablehlo/dialect/VhloAttrs.td -@@ -45,14 +45,6 @@ +From c48aa0650bf27b81c42799f9218dd90779e36e3b Mon Sep 17 00:00:00 2001 +From: Corentin Godeau +Date: Mon, 16 Jun 2025 14:08:48 +0000 +Subject: [PATCH] Remove default argument that is not valid C + +--- + stablehlo/dialect/Version.h | 2 +- + stablehlo/dialect/VhloAttrs.td | 14 +- + stablehlo/dialect/VhloDialect.td | 1 + + stablehlo/dialect/VhloTypes.h | 24 +- + .../integrations/c/StablehloDialectApi.h | 2 +- + .../stablehlo_aggressive_folder.mlir | 37 +- + .../stablehlo_aggressive_simplification.mlir | 2 +- + .../transforms/stablehlo_refine_shapes.mlir | 10 +- + .../stablehlo_legalize_to_vhlo_mixed.mlir | 38 +- + .../tests/vhlo/vhlo_attributes_invalid.mlir | 13 +- + stablehlo/tools/StablehloTranslateMain.cpp | 8 +- + stablehlo/transforms/Passes.h | 7 +- + .../transforms/StablehloLegalizeToVhlo.cpp | 48 +- + .../transforms/StablehloRefineShapes.cpp | 4 + + .../transforms/VhloLegalizeToStablehlo.cpp | 100 ++- + stablehlo/transforms/VhloToVersion.cpp | 38 +- + .../StablehloAggressiveFolder.cpp | 816 ++++++++++-------- + ...ablehloAggressiveSimplificationPatterns.td | 6 +- + 18 files changed, 696 insertions(+), 474 deletions(-) + +diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h +index eb35ad52..6aaef985 100644 +--- a/stablehlo/dialect/Version.h ++++ b/stablehlo/dialect/Version.h +@@ -38,7 +38,7 @@ class Version { + static FailureOr fromString(llvm::StringRef versionRef); + + /// Return a Version representing the current VHLO dialect version. +- static Version getCurrentVersion() { return Version(1, 10, 10); } ++ static Version getCurrentVersion() { return Version(1, 11, 0); } + + /// Return a Version representing the minimum supported VHLO dialect version. + static Version getMinimumVersion() { return Version(0, 9, 0); } +diff --git a/stablehlo/dialect/VhloAttrs.td b/stablehlo/dialect/VhloAttrs.td +index bf75ad27..ab48630a 100644 +--- a/stablehlo/dialect/VhloAttrs.td ++++ b/stablehlo/dialect/VhloAttrs.td +@@ -45,14 +45,6 @@ class VHLO_AttrDef def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1", "0.9.0", "current"> { let mnemonic = "array_v1"; let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value); @@ -16,7 +56,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial let assemblyFormat = "`<` custom($value) `>`"; } -@@ -75,9 +67,9 @@ +@@ -75,9 +67,9 @@ def VHLO_DictionaryAttrV1 : VHLO_AttrDef<"DictionaryV1", "0.9.0", "current"> { LogicalResult DictionaryV1Attr::verify( llvm::function_ref errFn, ArrayRef> value) { @@ -29,36 +69,48 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial return success(); } }]; -diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/dialect/VhloTypes.h ---- stablehlo/stablehlo/dialect/VhloTypes.h -+++ stablehlo/stablehlo/dialect/VhloTypes.h -@@ -27,20 +27,23 @@ +diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td +index 4e50cb32..edd220e3 100644 +--- a/stablehlo/dialect/VhloDialect.td ++++ b/stablehlo/dialect/VhloDialect.td +@@ -49,6 +49,7 @@ def VHLO_Dialect : Dialect { + 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types. + 1.9.0: Add `ResultAccuracy` attribute to `exp` op. + 1.10.0: Add `ResultAccuracy` attribute to `cbrt`, `cosine`, `exponential`, `exponential_minus_one`, `log`, `log_plus_one`, `logistic`, `rsqrt`, `sine`, `sqrt`, `tan` and `tanh` ops. ++ 1.11.0: Allow (de)serializing VHLO programs mixed with potentially unstable dialects. + }]; + + let useDefaultAttributePrinterParser = 0; +diff --git a/stablehlo/dialect/VhloTypes.h b/stablehlo/dialect/VhloTypes.h +index e5aee254..f3a8d68d 100644 +--- a/stablehlo/dialect/VhloTypes.h ++++ b/stablehlo/dialect/VhloTypes.h +@@ -27,20 +27,23 @@ limitations under the License. namespace mlir { namespace vhlo { -class VhloTypeConverterBase : public TypeConverter { -- public: -- VhloTypeConverterBase() : TypeConverter(){}; -- -- virtual ~VhloTypeConverterBase() = default; -- -- virtual Attribute convertEncoding(Attribute attr) const = 0; --}; - - // This class is used to manage conversions between VHLO and Builtin - // dialects. --class VhloTypeConverter : public VhloTypeConverterBase { ++ ++// This class is used to manage conversions between VHLO and Builtin ++// dialects. +class VhloTypeConverter : public TypeConverter { public: -- VhloTypeConverter() : VhloTypeConverterBase() {} +- VhloTypeConverterBase() : TypeConverter(){}; + VhloTypeConverter() : TypeConverter(), allowOtherDialects(false) {} + VhloTypeConverter(bool allowOtherDialects) + : TypeConverter(), allowOtherDialects(allowOtherDialects) {} -+ + +- virtual ~VhloTypeConverterBase() = default; + virtual ~VhloTypeConverter() = default; -+ -+ virtual Attribute convertEncoding(Attribute attr) const = 0; -+ + + virtual Attribute convertEncoding(Attribute attr) const = 0; +-}; + +-// This class is used to manage conversions between VHLO and Builtin +-// dialects. +-class VhloTypeConverter : public VhloTypeConverterBase { +- public: +- VhloTypeConverter() : VhloTypeConverterBase() {} + Attribute convertUnknownAttribute(Attribute attr) const { + if (allowOtherDialects) return attr; + return {}; @@ -66,7 +118,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale // A subclass can call this method to add conversions from VHLO -> Builtin // types. Note that conversions are applied in reverse order, with the most -@@ -58,6 +61,9 @@ +@@ -58,6 +61,9 @@ class VhloTypeConverter : public VhloTypeConverterBase { // Mark unrealized casts as legal. Useful for dialect mixing. void addUnrealizedMaterializations(); @@ -76,19 +128,34 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale }; // Autogenerated VHLO type printers and parsers. -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h +index 385156bf..24d11c1d 100644 +--- a/stablehlo/integrations/c/StablehloDialectApi.h ++++ b/stablehlo/integrations/c/StablehloDialectApi.h +@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr, + MlirStringRef targetVersion, + MlirStringCallback callback, + void* userData, +- bool allowOtherDialects = false); ++ bool allowOtherDialects); + + // Read a StableHLO program from a portable artifact, returning the module as + // MLIR bytecode. Note, this bytecode returned is not a portable artifact, +diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +index 758965d0..c5239671 100644 +--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir ++++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: stablehlo-opt --stablehlo-aggressive-folder=fold-op-element-limit=100 --split-input-file --verify-diagnostics %s | FileCheck %s //////// // AddOp -@@ -42,6 +42,21 @@ +@@ -41,6 +41,21 @@ func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>) + // ----- - //////// ++//////// +// ClampOp + +// CHECK-LABEL: func.func @clamp_fold @@ -103,18 +170,13 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml + +// ----- + -+//////// + //////// // CompareOp - // CHECK-LABEL: func.func @compare_folds -@@ -98,6 +113,26 @@ - // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32> - // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]] - return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32> -+} -+ -+// ----- -+ +@@ -102,6 +117,26 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, + + // ----- + +//////// +// DivOp + @@ -131,32 +193,37 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml + %1 = stablehlo.divide %cst_1, %cst_1 : tensor + %2 = stablehlo.divide %cst_2, %cst_2 : tensor + return %0, %1, %2 : tensor, tensor, tensor - } ++} ++ ++// ----- ++ + //////// + // MulOp - // ----- -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +index 4921a224..a1eed944 100644 +--- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ++++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s +// RUN: stablehlo-opt --stablehlo-aggressive-simplification=fold-op-element-limit=100 --allow-unregistered-dialect --split-input-file %s | FileCheck %s ///////// // AddOp -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir -@@ -521,16 +521,16 @@ +diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +index b262bf09..01b5dd8f 100644 +--- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir ++++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +@@ -521,16 +521,16 @@ func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> { // ----- // CHECK-LABEL: func @eval_slice_zerorank -func.func @eval_slice_zerorank() -> tensor { - // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor -- // CHECK: return [[RESULT]] -- %0 = stablehlo.constant dense<33.0> : tensor +func.func @eval_slice_zerorank() -> tensor { + // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<33> : tensor -+ // CHECK: return [[RESULT]] + // CHECK: return [[RESULT]] +- %0 = stablehlo.constant dense<33.0> : tensor + %0 = stablehlo.constant dense<33> : tensor %1 = "stablehlo.slice"(%0) { start_indices = array, @@ -169,9 +236,10 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b } // ----- -diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir ---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir -+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir +diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir +index 6340898e..a5d344ac 100644 +--- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir ++++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir @@ -5,15 +5,9 @@ // about what constitutes a good test! The CHECK should be // minimized and named to reflect the test intent. @@ -190,14 +258,10 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli // RUN: diff %t.0 %t.1 // CHECK-LABEL: vhlo.func_v1 @op_other( -@@ -26,6 +20,34 @@ - func.func @op_other(%arg0: tensor) -> tensor { - %0 = arith.addf %arg0, %arg0 : tensor - return %0 : tensor -+} -+ -+// ----- -+ +@@ -30,6 +24,34 @@ func.func @op_other(%arg0: tensor) -> tensor { + + // ----- + +// CHECK-LABEL: vhlo.func_v1 @func_attributes( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { +// CHECK: "vhlo.return_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1) -> () @@ -222,12 +286,17 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli + vhlo.mixed_dict = {affine_map = affine_map<(d0) -> (d0)>, str_attr = "STR_ATTR"} +} { + return %arg0 : tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir ---- stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir -+++ stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir ++} ++ ++// ----- ++ + // CHECK-LABEL: vhlo.func_v1 @op_shlo( + // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1) -> (!vhlo.tensor_v1) { + // CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 +diff --git a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir +index b73d1005..f4330e67 100644 +--- a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir ++++ b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir @@ -1,17 +1,8 @@ // RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s @@ -248,10 +317,11 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stabl } { return } -diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp ---- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp -+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp -@@ -76,6 +76,11 @@ +diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp +index e3171b2d..fdf0d6a9 100644 +--- a/stablehlo/tools/StablehloTranslateMain.cpp ++++ b/stablehlo/tools/StablehloTranslateMain.cpp +@@ -76,6 +76,11 @@ llvm::cl::opt targetOption( "target", llvm::cl::desc("Target version for serialization"), llvm::cl::init("")); @@ -263,7 +333,7 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st llvm::cl::opt argsOption( "args", llvm::cl::desc("Arguments to pass to the interpreter"), llvm::cl::init("")); -@@ -317,7 +322,8 @@ +@@ -317,7 +322,8 @@ TranslateFromMLIRRegistration serializeRegistration( return module.emitError("failed to strip debuginfo"); } @@ -273,10 +343,11 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st }, [](DialectRegistry ®istry) { mlir::registerAllDialects(registry); -diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h ---- stablehlo/stablehlo/transforms/Passes.h -+++ stablehlo/stablehlo/transforms/Passes.h -@@ -35,6 +35,7 @@ +diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h +index 0e4406d2..a5a4aa27 100644 +--- a/stablehlo/transforms/Passes.h ++++ b/stablehlo/transforms/Passes.h +@@ -35,6 +35,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/Version.h" @@ -284,7 +355,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans namespace mlir { namespace stablehlo { -@@ -54,17 +55,17 @@ +@@ -54,17 +55,17 @@ void populateStablehloRefineShapesPatterns(MLIRContext *context, // Populates StableHLO ops to VHLO ops rewriting patterns. void populateStablehloToVhloPatterns(MLIRContext *context, RewritePatternSet *patterns, @@ -305,10 +376,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans /// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and /// Shape ops. -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -@@ -57,7 +57,7 @@ +diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +index e768ded6..25903399 100644 +--- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ++++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +@@ -57,7 +57,7 @@ namespace { class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { public: StablehloToVhloTypeConverter(bool allowOtherDialects) @@ -317,7 +389,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable LLVM_DEBUG( llvm::dbgs() << "[StablehloToVhloTypeConverter] Creating with allowOtherDialects: " -@@ -82,7 +82,15 @@ +@@ -82,7 +82,15 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { return vhlo::TokenV1Type::get(token.getContext()); }); addBuiltinToVhloConversions(); @@ -334,7 +406,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable } Attribute convertEncoding(Attribute attr) const final { -@@ -114,7 +122,7 @@ +@@ -114,7 +122,7 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter { return vhlo::Name##Version##Attr::get(attr.getContext(), vhloValue.value()) Attribute convertGeneric(Attribute stablehloAttr, @@ -343,7 +415,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable LLVM_DEBUG(llvm::dbgs() << "Convert generic: " << stablehloAttr << '\n'); // Handle StableHLO attributes. -@@ -241,6 +249,10 @@ +@@ -241,6 +249,10 @@ Attribute convertGeneric(Attribute stablehloAttr, return vhlo::TypeV1Attr::get(attr.getContext(), vhloType); } @@ -354,7 +426,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable LLVM_DEBUG(llvm::dbgs() << "Failed to convert: " << stablehloAttr << '\n'); return {}; // Failed to convert attribute. } -@@ -268,13 +280,15 @@ +@@ -268,13 +280,15 @@ SpecialResult notSpecial() { return SpecialResult::NOT_SPECIAL; } Attribute convertBool(const ConversionPattern& pattern, int64_t stablehloDim) { auto stablehloType = IntegerType::get(pattern.getContext(), 1); auto stablehloAttr = IntegerAttr::get(stablehloType, stablehloDim); @@ -372,7 +444,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable } Attribute convertInts(const ConversionPattern& pattern, -@@ -282,7 +296,8 @@ +@@ -282,7 +296,8 @@ Attribute convertInts(const ConversionPattern& pattern, auto stablehloType = RankedTensorType::get( stablehloDims.size(), IntegerType::get(pattern.getContext(), 64)); auto stablehloAttr = DenseIntElementsAttr::get(stablehloType, stablehloDims); @@ -382,7 +454,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable } Attribute convertSymbol(const ConversionPattern& pattern, -@@ -290,7 +305,7 @@ +@@ -290,7 +305,7 @@ Attribute convertSymbol(const ConversionPattern& pattern, auto stablehloSymbolAttr = dyn_cast(stablehloAttr); if (!stablehloSymbolAttr) return {}; return convertGeneric(stablehloSymbolAttr.getAttr(), @@ -391,7 +463,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable } SpecialResult convertChannelHandle(const ConversionPattern& pattern, -@@ -445,15 +460,15 @@ +@@ -445,15 +460,15 @@ SpecialResult convertDotAlgorithm(const ConversionPattern& pattern, vhloAttrs.emplace_back( StringAttr::get(pattern.getContext(), "lhs_precision_type"), convertGeneric(TypeAttr::get(attr.getLhsPrecisionType()), @@ -410,7 +482,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable // Components auto vhloLhsComponentCount = convertInt(pattern, attr.getLhsComponentCount()); -@@ -712,7 +727,9 @@ +@@ -712,7 +727,9 @@ LogicalResult addDefaults(const OpConversionPattern& pattern, auto addDefaultAttr = [&](StringRef vhloName, Attribute stablehloAttr) { vhloAttrs.emplace_back( StringAttr::get(pattern.getContext(), vhloName), @@ -421,7 +493,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable }; if constexpr (std::is_same::value) { if (!stablehloOp.getSymVisibilityAttr()) -@@ -987,8 +1004,9 @@ +@@ -987,8 +1004,9 @@ class StablehloToVhloOpConverter : public OpConversionPattern { case SpecialResult::SPECIAL_FAILURE: return failure(); case SpecialResult::NOT_SPECIAL: @@ -433,7 +505,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable if (!vhloAttr) return failure(); vhloAttrs.push_back({stablehloAttr.getName(), vhloAttr}); break; -@@ -1075,14 +1093,14 @@ +@@ -1075,14 +1093,14 @@ struct StablehloLegalizeToVhloPass } private: @@ -450,10 +522,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable populateStablehloToVhloPatterns< #define GET_OP_LIST #include "stablehlo/dialect/StablehloOps.cpp.inc" -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -16,6 +16,7 @@ +diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp +index 64eaa3c0..4b585b46 100644 +--- a/stablehlo/transforms/StablehloRefineShapes.cpp ++++ b/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -16,6 +16,7 @@ limitations under the License. #include #include @@ -461,7 +534,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl #include #include -@@ -1038,8 +1039,11 @@ +@@ -1038,8 +1039,11 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func, // Populate additional patterns for StableHLO extensions. state.addAdditionalPatterns(patterns); @@ -473,10 +546,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl // The folding patterns implement partial evaluation of shape computations // which is a critical part of implementing type refinement for ops like -diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -@@ -57,7 +57,10 @@ +diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +index 5662ad9d..41973a0e 100644 +--- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ++++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +@@ -57,7 +57,10 @@ namespace { class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { public: @@ -488,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable addConversion([](Type type) -> Type { return type; }); addConversion([](vhlo::TokenV1Type token) -> Type { LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n"); -@@ -90,7 +93,7 @@ +@@ -90,7 +93,7 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) Attribute convertGeneric(Attribute vhloAttr, @@ -497,18 +571,18 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable LLVM_DEBUG(llvm::dbgs() << "Converting attr " << vhloAttr); if (auto vhloAttrs = dyn_cast(vhloAttr)) { SmallVector stablehloAttrs; -@@ -189,6 +192,10 @@ - // All VHLO attributes must have counterparts in StableHLO. +@@ -190,6 +193,10 @@ Attribute convertGeneric(Attribute vhloAttr, return {}; } -+ + + // Fall back to type converter for unknown attributes. + auto unknownAttr = typeConverter->convertUnknownAttribute(vhloAttr); + if (unknownAttr) return unknownAttr; - ++ // This should be unreachable unless program is a mix of VHLO and other // dialects, e.g. due to user edits to textual assembly format. -@@ -229,7 +236,7 @@ + return {}; +@@ -229,7 +236,7 @@ bool isNoneType(Attribute vhloAttr) { } LogicalResult convertTypeAttr(Attribute vhloAttr, Type& result, @@ -517,7 +591,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable auto stablehloAttr = convertGeneric(vhloAttr, typeConverter); if (!stablehloAttr || !isa(stablehloAttr)) return failure(); result = cast(stablehloAttr).getValue(); -@@ -244,7 +251,7 @@ +@@ -244,7 +251,7 @@ LogicalResult convertInt(Attribute vhloAttr, int64_t& result) { } LogicalResult convertInts(Attribute vhloAttr, @@ -526,7 +600,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable SmallVector& result) { auto vhloTensorAttr = dyn_cast(vhloAttr); if (!vhloTensorAttr) return failure(); -@@ -256,7 +263,7 @@ +@@ -256,7 +263,7 @@ LogicalResult convertInts(Attribute vhloAttr, } Attribute convertSymbol(Attribute vhloAttr, @@ -535,7 +609,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable auto vhloStringAttr = dyn_cast(vhloAttr); if (!vhloStringAttr) return {}; auto stablehloStringAttr = dyn_cast_or_null( -@@ -267,7 +274,7 @@ +@@ -267,7 +274,7 @@ Attribute convertSymbol(Attribute vhloAttr, template Attribute convertChannelHandle(OpType vhloOp, @@ -544,7 +618,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable int64_t channelId, channelType; if (failed(convertInt(vhloOp.getChannelId(), channelId)) || failed(convertInt(vhloOp.getChannelType(), channelType))) -@@ -277,7 +284,7 @@ +@@ -277,7 +284,7 @@ Attribute convertChannelHandle(OpType vhloOp, } Attribute convertChannelId(Attribute vhloAttr, @@ -553,7 +627,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable int64_t channelId; if (failed(convertInt(vhloAttr, channelId))) return {}; return stablehlo::ChannelHandleAttr::get(vhloAttr.getContext(), channelId, -@@ -285,8 +292,8 @@ +@@ -285,8 +292,8 @@ Attribute convertChannelId(Attribute vhloAttr, } template @@ -564,7 +638,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable int64_t stablehloInputBatchDimension, stablehloInputFeatureDimension; SmallVector stablehloInputSpatialDimensions; int64_t stablehloKernelInputFeatureDimension, -@@ -323,7 +330,7 @@ +@@ -323,7 +330,7 @@ Attribute convertConvDimensionNumbers(OpType vhloOp, } Attribute convertCustomCallCalledComputations( @@ -573,7 +647,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (auto vhloArrayAttr = dyn_cast(vhloAttr)) { SmallVector stablehloAttrs; for (auto vhloAttr : vhloArrayAttr.getValue()) { -@@ -336,8 +343,8 @@ +@@ -336,8 +343,8 @@ Attribute convertCustomCallCalledComputations( return {}; } @@ -584,7 +658,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable Type lhsPrecisionType, rhsPrecisionType, accumulationType; if (isNoneType(vhloOp.getLhsComponentCount())) { // All must be nonetype -@@ -373,8 +380,8 @@ +@@ -373,8 +380,8 @@ FailureOr convertDotAlgorithm(vhlo::DotGeneralOpV2 vhloOp, numPrimitiveOperations, allowImpreciseAccumulation); } @@ -595,7 +669,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable SmallVector stablehloLhsBatchingDimensions, stablehloRhsBatchingDimensions, stablehloLhsContractingDimensions, stablehloRhsContractingDimensions; -@@ -394,13 +401,13 @@ +@@ -394,13 +401,13 @@ Attribute convertDotDimensionNumbers(vhlo::DotGeneralOpV2 vhloOp, } Attribute convertFuncCallee(Attribute vhloAttr, @@ -612,7 +686,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable SmallVector stablehloOffsetDims, stablehloCollapsedSliceDims, stablehloOperandBatchingDims, stablehloStartIndicesBatchingDims, stablehloStartIndexMap; -@@ -423,8 +430,8 @@ +@@ -423,8 +430,8 @@ Attribute convertGatherDimensionNumbers(OpType vhloOp, stablehloStartIndexMap, stablehloIndexVectorDim); } @@ -623,7 +697,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable SmallVector stablehloUpdateWindowDims, stablehloInsertedWindowDims, stablehloInputBatchingDims, stablehloScatterIndicesBatchingDims, stablehloScatterDimsToOperandDims; -@@ -463,10 +470,11 @@ +@@ -463,10 +470,11 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, VhloOpTy vhloOp, SmallVector& vhloAttrs, SmallVector& stablehloAttrs) { @@ -637,7 +711,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return failure(); stablehloAttrs.emplace_back( StringAttr::get(pattern.getContext(), "dimension_numbers"), -@@ -480,7 +488,7 @@ +@@ -480,7 +488,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, if constexpr (std::is_same::value) { // Dot Dimension Numbers auto stablehloDotDimAttr = @@ -646,7 +720,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloDotDimAttr) return failure(); stablehloAttrs.emplace_back( StringAttr::get(pattern.getContext(), "dot_dimension_numbers"), -@@ -489,8 +497,7 @@ +@@ -489,8 +497,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, "lhs_contracting_dimensions", "rhs_contracting_dimensions"); // Dot Algorithm @@ -656,7 +730,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (failed(stablehloDotAlgorithmAttr)) return failure(); if (stablehloDotAlgorithmAttr.value()) stablehloAttrs.emplace_back( -@@ -503,8 +510,7 @@ +@@ -503,8 +510,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, } if constexpr (std::is_same::value || std::is_same::value) { @@ -666,7 +740,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return failure(); stablehloAttrs.emplace_back( StringAttr::get(pattern.getContext(), "dimension_numbers"), -@@ -514,8 +520,7 @@ +@@ -514,8 +520,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, "start_index_map", "index_vector_dim"); } if constexpr (std::is_same::value) { @@ -676,7 +750,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return failure(); stablehloAttrs.emplace_back( StringAttr::get(pattern.getContext(), "scatter_dimension_numbers"), -@@ -526,8 +531,7 @@ +@@ -526,8 +531,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, } if constexpr (std::is_same::value || std::is_same::value) { @@ -686,7 +760,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return failure(); stablehloAttrs.emplace_back( StringAttr::get(pattern.getContext(), "channel_handle"), stablehloAttr); -@@ -537,7 +541,7 @@ +@@ -537,7 +541,7 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, } template @@ -695,7 +769,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable StringAttr vhloName, Attribute vhloAttr, SmallVector& stablehloAttrs) { auto tensorAttr = dyn_cast(vhloAttr); -@@ -556,15 +560,15 @@ +@@ -556,15 +560,15 @@ SpecialResult convertDenseArray(const TypeConverter* typeConverter, } SpecialResult convertDenseBoolArray( @@ -715,7 +789,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable return convertDenseArray( typeConverter, vhloName, vhloAttr, stablehloAttrs); } -@@ -575,7 +579,8 @@ +@@ -575,7 +579,8 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, SmallVector& stablehloAttrs) { StringAttr stablehloName = vhloName; Attribute stablehloAttr; @@ -725,7 +799,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if constexpr (std::is_same::value || std::is_same::value || -@@ -585,7 +590,7 @@ +@@ -585,7 +590,7 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, std::is_same::value) { if (vhloName == "channel_id") { stablehloName = StringAttr::get(pattern.getContext(), "channel_handle"); @@ -734,7 +808,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return specialFailure(); } if (vhloName == "use_global_device_ids") { -@@ -597,20 +602,20 @@ +@@ -597,20 +602,20 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, } if constexpr (std::is_same::value) { if (vhloName == "called_computations") { @@ -759,7 +833,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return specialFailure(); } } -@@ -760,8 +765,8 @@ +@@ -760,8 +765,8 @@ bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) { template bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr, T splatValue) { @@ -770,7 +844,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable return attr && attr.isSplat() && attr.template getSplatValue() == splatValue; } -@@ -977,8 +982,9 @@ +@@ -977,8 +982,9 @@ class VhloToStablehloOpConverter : public OpConversionPattern { case SpecialResult::SPECIAL_FAILURE: return failure(); case SpecialResult::NOT_SPECIAL: @@ -782,7 +856,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable if (!stablehloAttr) return failure(); stablehloAttrs.push_back({vhloAttr.getName(), stablehloAttr}); break; -@@ -1056,7 +1062,7 @@ +@@ -1056,7 +1062,7 @@ struct ReconcileUnrealizedConversionCasts template void populateVhloToStablehloPatterns(MLIRContext* context, RewritePatternSet* patterns, @@ -791,7 +865,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable patterns ->add>...>( *converter, context); -@@ -1104,7 +1110,7 @@ +@@ -1104,7 +1110,7 @@ struct VhloLegalizeToStablehloPass void populateVhloToStablehloPatterns(MLIRContext* context, RewritePatternSet* patterns, @@ -800,10 +874,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable populateVhloToStablehloPatterns< #define GET_OP_LIST #include "stablehlo/dialect/StablehloOps.cpp.inc" -diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp ---- stablehlo/stablehlo/transforms/VhloToVersion.cpp -+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp -@@ -56,15 +56,23 @@ +diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp +index dfbf2877..7818e601 100644 +--- a/stablehlo/transforms/VhloToVersion.cpp ++++ b/stablehlo/transforms/VhloToVersion.cpp +@@ -56,15 +56,23 @@ namespace { // Currently there are no type-to-version conversions so this class // simply validates that all types are from the VHLO dialect. @@ -833,7 +908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable }; // Check user-specified target version. Emit error if invalid. -@@ -111,6 +119,10 @@ +@@ -111,6 +119,10 @@ bool isLegalVersion(VersionedInterface& interface, const Version& target) { LogicalResult isLegalType(Type type, const Version& targetVersion); LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { @@ -844,7 +919,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable auto attrInterface = dyn_cast(attr); if (!attrInterface || !isLegalVersion(attrInterface, targetVersion)) { LLVM_DEBUG(llvm::dbgs() << "failed to legalize attribute " << attr -@@ -119,10 +131,11 @@ +@@ -119,10 +131,11 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { } // Recursively check attrs if VHLO attr is a container @@ -857,7 +932,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable if (auto arrAttr = dyn_cast(attr)) { return success(llvm::all_of( arrAttr.getValue(), [&](std::pair entry) { -@@ -146,6 +159,10 @@ +@@ -146,6 +159,10 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { } LogicalResult isLegalType(Type type, const Version& targetVersion) { @@ -868,7 +943,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable // All valid VHLO types must have versioned type interface. auto typeInterface = dyn_cast(type); if (!typeInterface || !isLegalVersion(typeInterface, targetVersion)) { -@@ -170,10 +187,11 @@ +@@ -170,10 +187,11 @@ LogicalResult isLegalType(Type type, const Version& targetVersion) { return failure(); return isLegalType(ranked.getElementType(), targetVersion); } @@ -881,7 +956,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable if (auto quant = dyn_cast(type)) return success( succeeded(isLegalType(quant.getStorageType(), targetVersion)) && -@@ -213,7 +231,7 @@ +@@ -213,7 +231,7 @@ bool isLegalLocation(Location loc, const Version& targetVersion) { bool isLegalOperation(Operation* op, const Version& targetVersion) { // Validate op auto opInterface = dyn_cast(op); @@ -890,7 +965,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable if (!isLegalVersion(opInterface, targetVersion)) return false; LLVM_DEBUG(llvm::dbgs() << "Legal op version for target. " << op << '\n'); -@@ -454,7 +472,7 @@ +@@ -454,7 +472,7 @@ struct AllReduceOpV2ToV1 : public OpRewritePattern { namespace stablehlo { void populateVhloToVersionPatterns(MLIRContext* context, RewritePatternSet* patterns, @@ -899,10 +974,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable vhlo::populateWithGenerated(*patterns); patterns->add(context); patterns->add(context); -diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp ---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp -+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp -@@ -19,6 +19,7 @@ +diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +index 9cfac0ba..09bfb783 100644 +--- a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp ++++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +@@ -19,6 +19,7 @@ limitations under the License. #include #include #include @@ -910,7 +986,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold #include #include "llvm/ADT/APInt.h" -@@ -47,6 +48,7 @@ +@@ -47,6 +48,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -918,7 +994,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" -@@ -98,29 +100,88 @@ +@@ -98,29 +100,88 @@ LogicalResult validateStaticShapeResult(PatternRewriter& rewriter, return success(); } @@ -982,8 +1058,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold if (res) return cast(res); return nullptr; -+} -+ + } + + +/// Binary constant folder that used a generic folder function to handle both +/// ints and floats. @@ -1005,8 +1081,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (!res) return rewriter.notifyMatchFailure(op, "folding failed"); + + return res; - } - ++} ++ template -LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op, @@ -1014,7 +1090,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold DenseIntOrFPElementsAttr elements, Type resType, CalculationT&& calculate) { auto result = constFoldCastOp( elements, resType, calculate); @@ -1036,7 +1112,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold DenseIntOrFPElementsAttr elements, RankedTensorType resultType) { auto oldType = getElementTypeOrSelf(elements); -@@ -153,7 +215,7 @@ +@@ -153,7 +215,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, if (auto newFloatType = dyn_cast(newType)) { // Float -> Float const auto& targetSemantics = newFloatType.getFloatSemantics(); @@ -1045,7 +1121,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold rewriter, op, elements, resultType, [&targetSemantics](const APFloat& operand, bool& castStatus) { bool losesInfo; -@@ -167,7 +229,7 @@ +@@ -167,7 +229,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, } // Float -> Int @@ -1054,7 +1130,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold rewriter, op, elements, resultType, [&newBitWidth, &isNewTypeUnsigned](const APFloat& operand, bool& castStatus) { -@@ -186,7 +248,7 @@ +@@ -186,7 +248,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, if (auto newFloatType = dyn_cast(newType)) { // Int -> Float @@ -1063,7 +1139,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold rewriter, op, elements, resultType, [&newFloatType, &isOldTypeUnsigned](const APInt& operand, bool& /*castStatus*/) { -@@ -199,64 +261,12 @@ +@@ -199,7 +261,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, } // Int -> Int @@ -1072,10 +1148,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold rewriter, op, elements, resultType, [&newBitWidth, &isOldTypeUnsigned](const APInt& operand, bool& /*castStatus*/) { - return APSInt(operand, isOldTypeUnsigned).extOrTrunc(newBitWidth); +@@ -207,58 +269,6 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, }); --} -- + } + -// The patterns below implement partial evaluation of shape computations which -// is a critical part of implementing type refinement for ops like -// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape @@ -1126,10 +1202,12 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - rewriter.replaceOpWithNewOp(op, - getTensorAttr(resultType, result)); - return success(); - } - +-} +- template -@@ -275,29 +285,18 @@ + struct FoldOpRewritePattern : OpRewritePattern { + FoldOpRewritePattern(MLIRContext* context, +@@ -275,29 +285,18 @@ struct FoldOpRewritePattern : OpRewritePattern { ArrayRef generatedNames = {}) = delete; const StablehloAggressiveFolderPassOptions& options; @@ -1154,9 +1232,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - rewriter.replaceOpWithNewOp(op, res); - return success(); - } -- + - return failure(); -+ + LogicalResult validateElementCountForFold(PatternRewriter& rewriter, + Operation* op, + ShapedType resultType) const { @@ -1171,31 +1248,21 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold } }; -@@ -318,100 +317,102 @@ +@@ -318,100 +317,102 @@ struct ShapeOpRewritePattern : public FoldOpRewritePattern { } }; -struct EvalAddOpShapePattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; -- -- LogicalResult matchAndRewrite(AddOp op, -- PatternRewriter& rewriter) const override { -- return evalElementwise(rewriter, op, -- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); -- } --}; -- --struct EvalAndOpPattern : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; -- -- LogicalResult matchAndRewrite(AndOp op, -- PatternRewriter& rewriter) const override { +struct FoldAddOpPattern final + : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; -+ + +- LogicalResult matchAndRewrite(AddOp op, + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, -+ PatternRewriter& rewriter) const override { + PatternRewriter& rewriter) const override { +- return evalElementwise(rewriter, op, +- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + return failure(); + @@ -1203,14 +1270,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } -+}; -+ + } + }; + +-struct EvalAndOpPattern : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldAndOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; -+ + +- LogicalResult matchAndRewrite(AndOp op, + LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op, -+ PatternRewriter& rewriter) const override { + PatternRewriter& rewriter) const override { + // TODO: Support more int types auto resultType = op.getType(); if (!resultType.getElementType().isInteger(1)) @@ -1219,24 +1289,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { - return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); - }); -- } + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldAnd{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } -+ -+ struct FoldAnd { -+ APInt operator()(APInt lhs, APInt rhs) const { -+ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero()); -+ } -+ std::optional operator()(APFloat lhs, APFloat rhs) const { -+ return std::nullopt; -+ } -+ }; - }; + } +-}; - // Pattern: broadcast_in_dim(splat, _) -> constant(splat) +-// Pattern: broadcast_in_dim(splat, _) -> constant(splat) -struct FoldBroadcastInDimSplatPattern final - : FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; @@ -1251,14 +1311,22 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - op, SplatElementsAttr::get(op.getType(), - cstAttr.getSplatValue())); - return success(); -- } ++ struct FoldAnd { ++ APInt operator()(APInt lhs, APInt rhs) const { ++ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero()); + } - return failure(); - } --}; -- ++ std::optional operator()(APFloat lhs, APFloat rhs) const { ++ return std::nullopt; ++ } ++ }; + }; + -struct EvalBroadcastInDimOpPattern - : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; ++// Pattern: broadcast_in_dim(splat, _) -> constant(splat) +struct FoldBroadcastInDimOpSplatPattern + : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1267,8 +1335,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) -- return failure(); -- ++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || ++ failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + - auto operandType = op.getOperand().getType(); - if (operandType.getRank() != 0) - return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); @@ -1277,52 +1347,34 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - if (failed(hlo::matchInts(op.getOperand(), operand))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - auto scalar = operand[0]; -- ++ SplatElementsAttr cstAttr; ++ matchPattern(op.getOperand(), m_Constant(&cstAttr)); ++ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat"); + - rewriter.replaceOpWithNewOp( - op, getTensorAttr(op.getType(), scalar)); -- return success(); -- } --}; -- ++ rewriter.replaceOpWithNewOp( ++ op, SplatElementsAttr::get(op.getType(), ++ cstAttr.getSplatValue())); + return success(); + } + }; + -struct EvalClampOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; -- ++struct FoldCompareOpPattern : public ShapeOpRewritePattern { ++ using ShapeOpRewritePattern::ShapeOpRewritePattern; + - LogicalResult matchAndRewrite(ClampOp op, -- PatternRewriter& rewriter) const override { ++ LogicalResult matchAndRewrite(CompareOp op, + PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt min, APSInt operand, APSInt max) { - if (operand < min) return min; - if (max < operand) return max; - return operand; - }); -- } --}; -- --struct EvalCompareOpPattern : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; -+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || -+ failed(validateShapeFoldDtype(rewriter, op, resultType))) -+ return failure(); -+ -+ SplatElementsAttr cstAttr; -+ matchPattern(op.getOperand(), m_Constant(&cstAttr)); -+ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat"); -+ -+ rewriter.replaceOpWithNewOp( -+ op, SplatElementsAttr::get(op.getType(), -+ cstAttr.getSplatValue())); -+ return success(); -+ } -+}; -+ -+struct FoldCompareOpPattern : public ShapeOpRewritePattern { -+ using ShapeOpRewritePattern::ShapeOpRewritePattern; - - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); -- auto kind = op.getCompareType(); -- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { ++ auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + @@ -1332,15 +1384,23 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } -+ + } +-}; + +-struct EvalCompareOpPattern : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; + struct FoldCompare { + FoldCompare(ComparisonDirection direction, + std::optional kind) + : direction(direction), kind(kind) {} + ComparisonDirection direction; + std::optional kind; -+ + +- LogicalResult matchAndRewrite(CompareOp op, +- PatternRewriter& rewriter) const override { +- auto resultType = op.getType(); +- auto kind = op.getCompareType(); +- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { + // TODO: Enable float folding. + std::optional operator()(APFloat lhs, APFloat rhs) { + return std::nullopt; @@ -1352,7 +1412,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold case ComparisonDirection::EQ: result = lhs == rhs; break; -@@ -431,9 +432,9 @@ +@@ -431,9 +432,9 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern { result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs); break; } @@ -1365,7 +1425,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold }; ////////////////////////////////// -@@ -441,16 +442,15 @@ +@@ -441,16 +442,15 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern { ///////////////////////////////// struct FoldConcatenateOpPattern final @@ -1387,7 +1447,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold return failure(); // Fold concatenate when all inputs are constants. -@@ -466,6 +466,7 @@ +@@ -466,6 +466,7 @@ struct FoldConcatenateOpPattern final int64_t{1}, std::multiplies<>{}); SmallVector newElems; @@ -1395,7 +1455,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold newElems.reserve(numElems); for (int64_t i = 0; i != topSize; ++i) { -@@ -485,31 +486,7 @@ +@@ -485,31 +486,7 @@ struct FoldConcatenateOpPattern final int64_t foldOpElementLimit; }; @@ -1428,20 +1488,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(ConvertOp op, -@@ -532,28 +509,50 @@ +@@ -532,28 +509,50 @@ struct EvalConvertOpPattern : public ShapeOpRewritePattern { return rewriter.notifyMatchFailure( op, "expected constant integer or float operand"); - return evalConvert(rewriter, op, elements, resultType); -- } --}; -- ++ return foldConvert(rewriter, op, elements, resultType); + } + }; + -struct EvalDivOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; -+ return foldConvert(rewriter, op, elements, resultType); -+ } -+}; -+ +struct FoldDivOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1449,12 +1506,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); -- } --}; -- --struct EvalGetDimensionSizeOpPattern -- : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); @@ -1464,7 +1515,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } + } + + struct FoldDivide { + FoldDivide(bool isUnsignedInt) @@ -1479,8 +1530,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + static APInt foldUint(APInt lhs, APInt rhs) { return lhs.udiv(rhs); } + static APInt foldSint(APInt lhs, APInt rhs) { return lhs.sdiv(rhs); } + }; -+}; -+ + }; + +-struct EvalGetDimensionSizeOpPattern +- : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldGetDimensionSizeOpPattern + : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1494,7 +1548,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold return failure(); auto operandType = op.getOperand().getType(); -@@ -567,86 +566,187 @@ +@@ -567,86 +566,187 @@ struct EvalGetDimensionSizeOpPattern } }; @@ -1549,11 +1603,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { - return lhs >= rhs ? lhs : rhs; - }); -- } --}; -- --struct EvalMinOpPattern : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); @@ -1563,9 +1612,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } -+}; -+ + } + }; + +-struct EvalMinOpPattern : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldMinOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1574,11 +1625,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { - return lhs <= rhs ? lhs : rhs; - }); -- } --}; -- --struct FoldMulOpPattern final : FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); @@ -1588,19 +1634,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } -+}; -+ + } + }; + +-struct FoldMulOpPattern final : FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +// Clamp is folded using Min and Max folders. +struct FoldClampOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; -+ + +- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + LogicalResult matchAndRewrite(ClampOp op, -+ PatternRewriter& rewriter) const override { + PatternRewriter& rewriter) const override { +- TypedAttr lhsAttr; +- matchPattern(op.getLhs(), m_Constant(&lhsAttr)); +- +- TypedAttr rhsAttr; +- matchPattern(op.getRhs(), m_Constant(&rhsAttr)); +- +- if (TypedAttr res; +- lhsAttr && rhsAttr && +- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { +- rewriter.replaceOpWithNewOp(op, res); +- return success(); +- } + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); -+ + +- return failure(); + TypedAttr minAttr, operandAttr, maxAttr; + matchPattern(op.getMin(), m_Constant(&minAttr)); + matchPattern(op.getOperand(), m_Constant(&operandAttr)); @@ -1620,43 +1682,19 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (!res) return rewriter.notifyMatchFailure(op, "failed to fold clamp"); + rewriter.replaceOpWithNewOp(op, res); + return success(); -+ } -+}; -+ -+struct FoldMulOpPattern final : ShapeOpRewritePattern { -+ using ShapeOpRewritePattern::ShapeOpRewritePattern; + } + }; - LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, - PatternRewriter& rewriter) const override { -- TypedAttr lhsAttr; -- matchPattern(op.getLhs(), m_Constant(&lhsAttr)); -- -- TypedAttr rhsAttr; -- matchPattern(op.getRhs(), m_Constant(&rhsAttr)); -- -- if (TypedAttr res; -- lhsAttr && rhsAttr && -- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { -- rewriter.replaceOpWithNewOp(op, res); -- return success(); -- } -- -- return failure(); -- } --}; -- -struct EvalMulOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; -- ++struct FoldMulOpPattern final : ShapeOpRewritePattern { ++ using ShapeOpRewritePattern::ShapeOpRewritePattern; + - LogicalResult matchAndRewrite(MulOp op, -- PatternRewriter& rewriter) const override { ++ LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); -- } --}; -- --struct EvalOrOpPattern : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + return failure(); + @@ -1664,9 +1702,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } -+}; -+ + } + }; + +-struct EvalOrOpPattern : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldOrOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1680,16 +1720,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { - return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); - }); -- } --}; -- --struct EvalRemOpPattern : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldOr{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } + } + + struct FoldOr { + APInt operator()(APInt lhs, APInt rhs) const { @@ -1699,8 +1734,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return std::nullopt; + } + }; -+}; -+ + }; + +-struct EvalRemOpPattern : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldRemOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1708,10 +1745,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); -- } --}; -- --struct EvalReshapeOpPattern : public ShapeOpRewritePattern { + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); @@ -1721,7 +1754,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); -+ } + } + + struct FoldRem { + FoldRem(bool isUnsignedInt) @@ -1736,14 +1769,15 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + static APInt foldUint(APInt lhs, APInt rhs) { return lhs.urem(rhs); } + static APInt foldSint(APInt lhs, APInt rhs) { return lhs.srem(rhs); } + }; -+}; -+ + }; + +-struct EvalReshapeOpPattern : public ShapeOpRewritePattern { +// Pattern: reshape(cst, shape) -> cst +struct FoldReshapeOpPattern : public ShapeOpRewritePattern { using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(ReshapeOp op, -@@ -656,7 +756,6 @@ +@@ -656,7 +756,6 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern { failed(validateShapeFoldDtype(rewriter, op, resultType))) return failure(); @@ -1751,7 +1785,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold DenseIntOrFPElementsAttr attr; if (!matchPattern(op.getOperand(), m_Constant(&attr))) return rewriter.notifyMatchFailure(op, "expected constant operand"); -@@ -665,53 +764,98 @@ +@@ -665,53 +764,98 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern { } }; @@ -1764,16 +1798,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) -- return failure(); -- ++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || ++ failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + - SmallVector pred, onTrue, onFalse; - if (failed(hlo::matchInts(op.getPred(), pred)) || - failed(hlo::matchInts(op.getOnTrue(), onTrue)) || - failed(hlo::matchInts(op.getOnFalse(), onFalse))) -+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) || -+ failed(validateShapeFoldDtype(rewriter, op, resultType))) -+ return failure(); -+ + DenseIntElementsAttr predAttr; + DenseElementsAttr onTrueAttr, onFalseAttr; + matchPattern(op.getPred(), m_Constant(&predAttr)); @@ -1783,14 +1815,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold return rewriter.notifyMatchFailure(op, "expected constant operands"); - SmallVector result; +- for (auto [predEl, onTrueEl, onFalseEl] : +- llvm::zip(pred, onTrue, onFalse)) { +- result.push_back(predEl != 0 ? onTrueEl : onFalseEl); + // Optimization, handle splat predicate + if (isa(predAttr)) { + auto pred = predAttr.getSplatValue(); + rewriter.replaceOpWithNewOp( + op, pred.isZero() ? onFalseAttr : onTrueAttr); + return success(); -+ } -+ + } + + // TODO: Enable float folding. + if (op.getType().getElementType().isFloat()) + return rewriter.notifyMatchFailure(op, "float select not supported yet"); @@ -1800,27 +1835,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return failure(); + + SmallVector result; - for (auto [predEl, onTrueEl, onFalseEl] : -- llvm::zip(pred, onTrue, onFalse)) { -- result.push_back(predEl != 0 ? onTrueEl : onFalseEl); -- } -- ++ for (auto [predEl, onTrueEl, onFalseEl] : + llvm::zip(predAttr.getValues(), onTrueAttr.getValues(), + onFalseAttr.getValues())) { + result.push_back(!predEl.isZero() ? onTrueEl : onFalseEl); + } rewriter.replaceOpWithNewOp( - op, getTensorAttr(op.getType(), result)); -- return success(); -- } --}; -- --struct EvalSignOpPattern : public FoldOpRewritePattern { -- using FoldOpRewritePattern::FoldOpRewritePattern; + op, DenseIntElementsAttr::get(resultType, result)); + -+ return success(); -+ } + return success(); + } + + struct FoldSelect { + std::optional operator()(APFloat pred, APFloat onTrue, @@ -1832,8 +1857,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold + return pred != 0 ? onTrue : onFalse; + } + }; -+}; -+ + }; + +-struct EvalSignOpPattern : public FoldOpRewritePattern { +- using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldSignOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; @@ -1881,7 +1908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold }; template -@@ -749,13 +893,14 @@ +@@ -749,13 +893,14 @@ DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) { ArrayRef(result)); } @@ -1899,7 +1926,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold return failure(); auto operand = op.getOperand(); -@@ -784,36 +929,18 @@ +@@ -784,36 +929,18 @@ struct EvalSliceOpPattern : public FoldOpRewritePattern { }; struct FoldSubtractOpPattern final @@ -1930,14 +1957,13 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - -struct EvalSubtractOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; -- ++ if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) ++ return failure(); + - LogicalResult matchAndRewrite(SubtractOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); -+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) -+ return failure(); -+ + auto res = foldBinaryOpIntOrFloat(rewriter, op, std::minus<>{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); @@ -1945,23 +1971,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold } }; -@@ -823,42 +950,38 @@ +@@ -823,42 +950,38 @@ struct FoldSqrtOpPattern LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op, PatternRewriter& rewriter) const final { - TypedAttr lhsAttr; - matchPattern(op.getOperand(), m_Constant(&lhsAttr)); -- ++ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt()); ++ if (failed(res)) return failure(); ++ rewriter.replaceOpWithNewOp(op, res.value()); ++ return success(); ++ } + - if (!lhsAttr) - return rewriter.notifyMatchFailure(op, "operand not constant"); -- ++ struct FoldSqrt { ++ std::optional operator()(APFloat operand) { ++ if (operand.getSizeInBits(operand.getSemantics()) == 64) ++ return APFloat(std::sqrt(operand.convertToDouble())); + - if (auto res = constFoldUnaryOp( - lhsAttr, foldSqrt)) { - rewriter.replaceOpWithNewOp( - op, op.getType(), llvm::cast(res)); - return success(); -- } -- ++ if (operand.getSizeInBits(operand.getSemantics()) == 32) ++ return APFloat(sqrtf(operand.convertToFloat())); ++ return std::nullopt; + } + - return rewriter.notifyMatchFailure(op, "unable to fold sqrt"); - } - @@ -1973,32 +2011,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold - return APFloat(sqrtf(a.convertToFloat())); - return {}; - } --}; -- --struct EvalIotaOpPattern : public FoldOpRewritePattern { -+ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt()); -+ if (failed(res)) return failure(); -+ rewriter.replaceOpWithNewOp(op, res.value()); -+ return success(); -+ } -+ -+ struct FoldSqrt { -+ std::optional operator()(APFloat operand) { -+ if (operand.getSizeInBits(operand.getSemantics()) == 64) -+ return APFloat(std::sqrt(operand.convertToDouble())); -+ -+ if (operand.getSizeInBits(operand.getSemantics()) == 32) -+ return APFloat(sqrtf(operand.convertToFloat())); -+ return std::nullopt; -+ } -+ + // TODO: Enable int folding. + std::optional operator()(APInt operand) { + return std::nullopt; + } + }; -+}; -+ + }; + +-struct EvalIotaOpPattern : public FoldOpRewritePattern { +struct FoldIotaOpPattern : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; @@ -2015,7 +2035,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold auto elementType = resultType.getElementType(); -@@ -929,7 +1052,7 @@ +@@ -929,7 +1052,7 @@ DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) { // transpose(constant) => constant with permuted dimensions // This covers ranked tensor types with 0 dimensions(zero elements) and 0 // rank(scalar), as well as splat values. @@ -2024,7 +2044,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold using FoldOpRewritePattern::FoldOpRewritePattern; LogicalResult matchAndRewrite(TransposeOp op, -@@ -943,6 +1066,7 @@ +@@ -943,6 +1066,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern { return rewriter.notifyMatchFailure( op, "expected constant integer or float operand"); @@ -2032,7 +2052,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold DenseElementsAttr resAttr; if (auto data = els.tryGetValues()) resAttr = transposeType(op, *data); -@@ -957,6 +1081,7 @@ +@@ -957,6 +1081,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern { } }; @@ -2040,7 +2060,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold struct LowerBoolSplatConstantsIntoReduceOpRegion : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; -@@ -1160,7 +1285,7 @@ +@@ -1160,7 +1285,7 @@ bool hasNoDeclaredSideEffects(Operation* op) { return true; } @@ -2049,7 +2069,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; -@@ -1233,23 +1358,16 @@ +@@ -1233,23 +1358,16 @@ void populateStablehloAggressiveFolderPatterns( PatternBenefit benefit) { populateStablehloShapeFolderPatterns(context, patterns, options, benefit); @@ -2083,7 +2103,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold } class StablehloTargetIndependentOptimizationPass { -@@ -1266,25 +1384,25 @@ +@@ -1266,25 +1384,25 @@ void populateStablehloShapeFolderPatterns( MLIRContext* context, RewritePatternSet* patterns, const StablehloAggressiveFolderPassOptions& options, PatternBenefit benefit) { @@ -2128,28 +2148,29 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold } void populateStablehloShapeFolderPatterns(MLIRContext* context, -diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -@@ -48,6 +48,8 @@ - def RankEqual : Constraint< +diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +index 8adb6dbc..3a8edc61 100644 +--- a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td ++++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +@@ -49,6 +49,8 @@ def RankEqual : Constraint< CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, "same rank">; -+ -+def TensorDimsAllOne : Constraint, "all tensor dims are 1">; ++def TensorDimsAllOne : Constraint, "all tensor dims are 1">; ++ def TypesEqual : Constraint, "operands are equal">; -@@ -100,6 +102,8 @@ - def ZeroExtent : AttrConstraint< + /////////// +@@ -101,6 +103,8 @@ def ZeroExtent : AttrConstraint< CPred<"cast($_self).getNumElements() == 0">, "is zero extent">; -+ -+def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>; ++def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>; ++ /////////// //// Native Code Call Utilities -@@ -503,7 +507,7 @@ + +@@ -503,7 +507,7 @@ def SelectOp_InvertBroadcastPredicateAndSwap // Must be static shape, otherwise would require broadcasting via // CHLO_ConstantLike. def SubtractOp_FoldToZero @@ -2158,4 +2179,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp (StableHLO_ConstantLike<"0"> $operand)>; // Pattern: subtract(X, 0) -> X +-- +2.48.1 -- 2.48.1