1856 lines
81 KiB
Diff
1856 lines
81 KiB
Diff
|
|
From 4c0819ac9fb9dfc6156ae4de83fb29e987ade780 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Corentin Godeau <corentin@zml.ai>
|
||
|
|
Date: Mon, 16 Jun 2025 14:27:47 +0000
|
||
|
|
Subject: [PATCH] Update Stablehlo patch
|
||
|
|
|
||
|
|
---
|
||
|
|
third_party/stablehlo/temporary.patch | 841 +++++++++++++-------------
|
||
|
|
1 file changed, 432 insertions(+), 409 deletions(-)
|
||
|
|
mode change 100755 => 100644 third_party/stablehlo/temporary.patch
|
||
|
|
|
||
|
|
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
|
||
|
|
old mode 100755
|
||
|
|
new mode 100644
|
||
|
|
index f54e8341ba..6b25e78467
|
||
|
|
--- a/third_party/stablehlo/temporary.patch
|
||
|
|
+++ b/third_party/stablehlo/temporary.patch
|
||
|
|
@@ -1,7 +1,47 @@
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td
|
||
|
|
---- stablehlo/stablehlo/dialect/VhloAttrs.td
|
||
|
|
-+++ stablehlo/stablehlo/dialect/VhloAttrs.td
|
||
|
|
-@@ -45,14 +45,6 @@
|
||
|
|
+From c48aa0650bf27b81c42799f9218dd90779e36e3b Mon Sep 17 00:00:00 2001
|
||
|
|
+From: Corentin Godeau <corentin@zml.ai>
|
||
|
|
+Date: Mon, 16 Jun 2025 14:08:48 +0000
|
||
|
|
+Subject: [PATCH] Remove default argument that is not valid C
|
||
|
|
+
|
||
|
|
+---
|
||
|
|
+ stablehlo/dialect/Version.h | 2 +-
|
||
|
|
+ stablehlo/dialect/VhloAttrs.td | 14 +-
|
||
|
|
+ stablehlo/dialect/VhloDialect.td | 1 +
|
||
|
|
+ stablehlo/dialect/VhloTypes.h | 24 +-
|
||
|
|
+ .../integrations/c/StablehloDialectApi.h | 2 +-
|
||
|
|
+ .../stablehlo_aggressive_folder.mlir | 37 +-
|
||
|
|
+ .../stablehlo_aggressive_simplification.mlir | 2 +-
|
||
|
|
+ .../transforms/stablehlo_refine_shapes.mlir | 10 +-
|
||
|
|
+ .../stablehlo_legalize_to_vhlo_mixed.mlir | 38 +-
|
||
|
|
+ .../tests/vhlo/vhlo_attributes_invalid.mlir | 13 +-
|
||
|
|
+ stablehlo/tools/StablehloTranslateMain.cpp | 8 +-
|
||
|
|
+ stablehlo/transforms/Passes.h | 7 +-
|
||
|
|
+ .../transforms/StablehloLegalizeToVhlo.cpp | 48 +-
|
||
|
|
+ .../transforms/StablehloRefineShapes.cpp | 4 +
|
||
|
|
+ .../transforms/VhloLegalizeToStablehlo.cpp | 100 ++-
|
||
|
|
+ stablehlo/transforms/VhloToVersion.cpp | 38 +-
|
||
|
|
+ .../StablehloAggressiveFolder.cpp | 816 ++++++++++--------
|
||
|
|
+ ...ablehloAggressiveSimplificationPatterns.td | 6 +-
|
||
|
|
+ 18 files changed, 696 insertions(+), 474 deletions(-)
|
||
|
|
+
|
||
|
|
+diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h
|
||
|
|
+index eb35ad52..6aaef985 100644
|
||
|
|
+--- a/stablehlo/dialect/Version.h
|
||
|
|
++++ b/stablehlo/dialect/Version.h
|
||
|
|
+@@ -38,7 +38,7 @@ class Version {
|
||
|
|
+ static FailureOr<Version> fromString(llvm::StringRef versionRef);
|
||
|
|
+
|
||
|
|
+ /// Return a Version representing the current VHLO dialect version.
|
||
|
|
+- static Version getCurrentVersion() { return Version(1, 10, 10); }
|
||
|
|
++ static Version getCurrentVersion() { return Version(1, 11, 0); }
|
||
|
|
+
|
||
|
|
+ /// Return a Version representing the minimum supported VHLO dialect version.
|
||
|
|
+ static Version getMinimumVersion() { return Version(0, 9, 0); }
|
||
|
|
+diff --git a/stablehlo/dialect/VhloAttrs.td b/stablehlo/dialect/VhloAttrs.td
|
||
|
|
+index bf75ad27..ab48630a 100644
|
||
|
|
+--- a/stablehlo/dialect/VhloAttrs.td
|
||
|
|
++++ b/stablehlo/dialect/VhloAttrs.td
|
||
|
|
+@@ -45,14 +45,6 @@ class VHLO_AttrDef<string name, string minVersion, string maxVersion>
|
||
|
|
def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1", "0.9.0", "current"> {
|
||
|
|
let mnemonic = "array_v1";
|
||
|
|
let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value);
|
||
|
|
@@ -16,7 +56,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial
|
||
|
|
let assemblyFormat = "`<` custom<AttributeArray>($value) `>`";
|
||
|
|
}
|
||
|
|
|
||
|
|
-@@ -75,9 +67,9 @@
|
||
|
|
+@@ -75,9 +67,9 @@ def VHLO_DictionaryAttrV1 : VHLO_AttrDef<"DictionaryV1", "0.9.0", "current"> {
|
||
|
|
LogicalResult DictionaryV1Attr::verify(
|
||
|
|
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn,
|
||
|
|
ArrayRef<std::pair<mlir::Attribute, mlir::Attribute>> value) {
|
||
|
|
@@ -29,36 +69,48 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dial
|
||
|
|
return success();
|
||
|
|
}
|
||
|
|
}];
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/dialect/VhloTypes.h
|
||
|
|
---- stablehlo/stablehlo/dialect/VhloTypes.h
|
||
|
|
-+++ stablehlo/stablehlo/dialect/VhloTypes.h
|
||
|
|
-@@ -27,20 +27,23 @@
|
||
|
|
+diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td
|
||
|
|
+index 4e50cb32..edd220e3 100644
|
||
|
|
+--- a/stablehlo/dialect/VhloDialect.td
|
||
|
|
++++ b/stablehlo/dialect/VhloDialect.td
|
||
|
|
+@@ -49,6 +49,7 @@ def VHLO_Dialect : Dialect {
|
||
|
|
+ 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types.
|
||
|
|
+ 1.9.0: Add `ResultAccuracy` attribute to `exp` op.
|
||
|
|
+ 1.10.0: Add `ResultAccuracy` attribute to `cbrt`, `cosine`, `exponential`, `exponential_minus_one`, `log`, `log_plus_one`, `logistic`, `rsqrt`, `sine`, `sqrt`, `tan` and `tanh` ops.
|
||
|
|
++ 1.11.0: Allow (de)serializing VHLO programs mixed with potentially unstable dialects.
|
||
|
|
+ }];
|
||
|
|
+
|
||
|
|
+ let useDefaultAttributePrinterParser = 0;
|
||
|
|
+diff --git a/stablehlo/dialect/VhloTypes.h b/stablehlo/dialect/VhloTypes.h
|
||
|
|
+index e5aee254..f3a8d68d 100644
|
||
|
|
+--- a/stablehlo/dialect/VhloTypes.h
|
||
|
|
++++ b/stablehlo/dialect/VhloTypes.h
|
||
|
|
+@@ -27,20 +27,23 @@ limitations under the License.
|
||
|
|
namespace mlir {
|
||
|
|
namespace vhlo {
|
||
|
|
|
||
|
|
-class VhloTypeConverterBase : public TypeConverter {
|
||
|
|
-- public:
|
||
|
|
-- VhloTypeConverterBase() : TypeConverter(){};
|
||
|
|
--
|
||
|
|
-- virtual ~VhloTypeConverterBase() = default;
|
||
|
|
--
|
||
|
|
-- virtual Attribute convertEncoding(Attribute attr) const = 0;
|
||
|
|
--};
|
||
|
|
-
|
||
|
|
- // This class is used to manage conversions between VHLO and Builtin
|
||
|
|
- // dialects.
|
||
|
|
--class VhloTypeConverter : public VhloTypeConverterBase {
|
||
|
|
++
|
||
|
|
++// This class is used to manage conversions between VHLO and Builtin
|
||
|
|
++// dialects.
|
||
|
|
+class VhloTypeConverter : public TypeConverter {
|
||
|
|
public:
|
||
|
|
-- VhloTypeConverter() : VhloTypeConverterBase() {}
|
||
|
|
+- VhloTypeConverterBase() : TypeConverter(){};
|
||
|
|
+ VhloTypeConverter() : TypeConverter(), allowOtherDialects(false) {}
|
||
|
|
+ VhloTypeConverter(bool allowOtherDialects)
|
||
|
|
+ : TypeConverter(), allowOtherDialects(allowOtherDialects) {}
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+- virtual ~VhloTypeConverterBase() = default;
|
||
|
|
+ virtual ~VhloTypeConverter() = default;
|
||
|
|
-+
|
||
|
|
-+ virtual Attribute convertEncoding(Attribute attr) const = 0;
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+ virtual Attribute convertEncoding(Attribute attr) const = 0;
|
||
|
|
+-};
|
||
|
|
+
|
||
|
|
+-// This class is used to manage conversions between VHLO and Builtin
|
||
|
|
+-// dialects.
|
||
|
|
+-class VhloTypeConverter : public VhloTypeConverterBase {
|
||
|
|
+- public:
|
||
|
|
+- VhloTypeConverter() : VhloTypeConverterBase() {}
|
||
|
|
+ Attribute convertUnknownAttribute(Attribute attr) const {
|
||
|
|
+ if (allowOtherDialects) return attr;
|
||
|
|
+ return {};
|
||
|
|
@@ -66,7 +118,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale
|
||
|
|
|
||
|
|
// A subclass can call this method to add conversions from VHLO -> Builtin
|
||
|
|
// types. Note that conversions are applied in reverse order, with the most
|
||
|
|
-@@ -58,6 +61,9 @@
|
||
|
|
+@@ -58,6 +61,9 @@ class VhloTypeConverter : public VhloTypeConverterBase {
|
||
|
|
|
||
|
|
// Mark unrealized casts as legal. Useful for dialect mixing.
|
||
|
|
void addUnrealizedMaterializations();
|
||
|
|
@@ -76,19 +128,34 @@ diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.h b/stablehlo/stablehlo/diale
|
||
|
|
};
|
||
|
|
|
||
|
|
// Autogenerated VHLO type printers and parsers.
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
|
||
|
|
---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
|
||
|
|
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
|
||
|
|
+diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h
|
||
|
|
+index 385156bf..24d11c1d 100644
|
||
|
|
+--- a/stablehlo/integrations/c/StablehloDialectApi.h
|
||
|
|
++++ b/stablehlo/integrations/c/StablehloDialectApi.h
|
||
|
|
+@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
|
||
|
|
+ MlirStringRef targetVersion,
|
||
|
|
+ MlirStringCallback callback,
|
||
|
|
+ void* userData,
|
||
|
|
+- bool allowOtherDialects = false);
|
||
|
|
++ bool allowOtherDialects);
|
||
|
|
+
|
||
|
|
+ // Read a StableHLO program from a portable artifact, returning the module as
|
||
|
|
+ // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
|
||
|
|
+diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
|
||
|
|
+index 758965d0..c5239671 100644
|
||
|
|
+--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
|
||
|
|
++++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
|
||
|
|
@@ -1,4 +1,4 @@
|
||
|
|
-// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s
|
||
|
|
+// RUN: stablehlo-opt --stablehlo-aggressive-folder=fold-op-element-limit=100 --split-input-file --verify-diagnostics %s | FileCheck %s
|
||
|
|
|
||
|
|
////////
|
||
|
|
// AddOp
|
||
|
|
-@@ -42,6 +42,21 @@
|
||
|
|
+@@ -41,6 +41,21 @@ func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>)
|
||
|
|
+
|
||
|
|
// -----
|
||
|
|
|
||
|
|
- ////////
|
||
|
|
++////////
|
||
|
|
+// ClampOp
|
||
|
|
+
|
||
|
|
+// CHECK-LABEL: func.func @clamp_fold
|
||
|
|
@@ -103,18 +170,13 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
|
||
|
|
+
|
||
|
|
+// -----
|
||
|
|
+
|
||
|
|
-+////////
|
||
|
|
+ ////////
|
||
|
|
// CompareOp
|
||
|
|
|
||
|
|
- // CHECK-LABEL: func.func @compare_folds
|
||
|
|
-@@ -98,6 +113,26 @@
|
||
|
|
- // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32>
|
||
|
|
- // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]]
|
||
|
|
- return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>
|
||
|
|
-+}
|
||
|
|
-+
|
||
|
|
-+// -----
|
||
|
|
-+
|
||
|
|
+@@ -102,6 +117,26 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>,
|
||
|
|
+
|
||
|
|
+ // -----
|
||
|
|
+
|
||
|
|
+////////
|
||
|
|
+// DivOp
|
||
|
|
+
|
||
|
|
@@ -131,32 +193,37 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
|
||
|
|
+ %1 = stablehlo.divide %cst_1, %cst_1 : tensor<ui32>
|
||
|
|
+ %2 = stablehlo.divide %cst_2, %cst_2 : tensor<f32>
|
||
|
|
+ return %0, %1, %2 : tensor<i32>, tensor<ui32>, tensor<f32>
|
||
|
|
- }
|
||
|
|
++}
|
||
|
|
++
|
||
|
|
++// -----
|
||
|
|
++
|
||
|
|
+ ////////
|
||
|
|
+ // MulOp
|
||
|
|
|
||
|
|
- // -----
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||
|
|
---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||
|
|
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||
|
|
+diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||
|
|
+index 4921a224..a1eed944 100644
|
||
|
|
+--- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||
|
|
++++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
|
||
|
|
@@ -1,4 +1,4 @@
|
||
|
|
-// RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s
|
||
|
|
+// RUN: stablehlo-opt --stablehlo-aggressive-simplification=fold-op-element-limit=100 --allow-unregistered-dialect --split-input-file %s | FileCheck %s
|
||
|
|
|
||
|
|
/////////
|
||
|
|
// AddOp
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
|
||
|
|
---- stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
|
||
|
|
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
|
||
|
|
-@@ -521,16 +521,16 @@
|
||
|
|
+diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
|
||
|
|
+index b262bf09..01b5dd8f 100644
|
||
|
|
+--- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
|
||
|
|
++++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
|
||
|
|
+@@ -521,16 +521,16 @@ func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> {
|
||
|
|
// -----
|
||
|
|
|
||
|
|
// CHECK-LABEL: func @eval_slice_zerorank
|
||
|
|
-func.func @eval_slice_zerorank() -> tensor<f32> {
|
||
|
|
- // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor<f32>
|
||
|
|
-- // CHECK: return [[RESULT]]
|
||
|
|
-- %0 = stablehlo.constant dense<33.0> : tensor<f32>
|
||
|
|
+func.func @eval_slice_zerorank() -> tensor<i32> {
|
||
|
|
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<33> : tensor<i32>
|
||
|
|
-+ // CHECK: return [[RESULT]]
|
||
|
|
+ // CHECK: return [[RESULT]]
|
||
|
|
+- %0 = stablehlo.constant dense<33.0> : tensor<f32>
|
||
|
|
+ %0 = stablehlo.constant dense<33> : tensor<i32>
|
||
|
|
%1 = "stablehlo.slice"(%0) {
|
||
|
|
start_indices = array<i64>,
|
||
|
|
@@ -169,9 +236,10 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b
|
||
|
|
}
|
||
|
|
|
||
|
|
// -----
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
|
||
|
|
---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
|
||
|
|
-+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
|
||
|
|
+diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
|
||
|
|
+index 6340898e..a5d344ac 100644
|
||
|
|
+--- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
|
||
|
|
++++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mlir
|
||
|
|
@@ -5,15 +5,9 @@
|
||
|
|
// about what constitutes a good test! The CHECK should be
|
||
|
|
// minimized and named to reflect the test intent.
|
||
|
|
@@ -190,14 +258,10 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli
|
||
|
|
// RUN: diff %t.0 %t.1
|
||
|
|
|
||
|
|
// CHECK-LABEL: vhlo.func_v1 @op_other(
|
||
|
|
-@@ -26,6 +20,34 @@
|
||
|
|
- func.func @op_other(%arg0: tensor<f32>) -> tensor<f32> {
|
||
|
|
- %0 = arith.addf %arg0, %arg0 : tensor<f32>
|
||
|
|
- return %0 : tensor<f32>
|
||
|
|
-+}
|
||
|
|
-+
|
||
|
|
-+// -----
|
||
|
|
-+
|
||
|
|
+@@ -30,6 +24,34 @@ func.func @op_other(%arg0: tensor<f32>) -> tensor<f32> {
|
||
|
|
+
|
||
|
|
+ // -----
|
||
|
|
+
|
||
|
|
+// CHECK-LABEL: vhlo.func_v1 @func_attributes(
|
||
|
|
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
|
||
|
|
+// CHECK: "vhlo.return_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
|
||
|
|
@@ -222,12 +286,17 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo_mixed.mli
|
||
|
|
+ vhlo.mixed_dict = {affine_map = affine_map<(d0) -> (d0)>, str_attr = "STR_ATTR"}
|
||
|
|
+} {
|
||
|
|
+ return %arg0 : tensor<f32>
|
||
|
|
- }
|
||
|
|
-
|
||
|
|
- // -----
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
|
||
|
|
---- stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
|
||
|
|
-+++ stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
|
||
|
|
++}
|
||
|
|
++
|
||
|
|
++// -----
|
||
|
|
++
|
||
|
|
+ // CHECK-LABEL: vhlo.func_v1 @op_shlo(
|
||
|
|
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
|
||
|
|
+ // CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
|
||
|
|
+diff --git a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
|
||
|
|
+index b73d1005..f4330e67 100644
|
||
|
|
+--- a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
|
||
|
|
++++ b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
|
||
|
|
@@ -1,17 +1,8 @@
|
||
|
|
// RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s
|
||
|
|
|
||
|
|
@@ -248,10 +317,11 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stabl
|
||
|
|
} {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
|
||
|
|
---- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
|
||
|
|
-+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
|
||
|
|
-@@ -76,6 +76,11 @@
|
||
|
|
+diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp
|
||
|
|
+index e3171b2d..fdf0d6a9 100644
|
||
|
|
+--- a/stablehlo/tools/StablehloTranslateMain.cpp
|
||
|
|
++++ b/stablehlo/tools/StablehloTranslateMain.cpp
|
||
|
|
+@@ -76,6 +76,11 @@ llvm::cl::opt<std::string> targetOption(
|
||
|
|
"target", llvm::cl::desc("Target version for serialization"),
|
||
|
|
llvm::cl::init(""));
|
||
|
|
|
||
|
|
@@ -263,7 +333,7 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st
|
||
|
|
llvm::cl::opt<std::string> argsOption(
|
||
|
|
"args", llvm::cl::desc("Arguments to pass to the interpreter"),
|
||
|
|
llvm::cl::init(""));
|
||
|
|
-@@ -317,7 +322,8 @@
|
||
|
|
+@@ -317,7 +322,8 @@ TranslateFromMLIRRegistration serializeRegistration(
|
||
|
|
return module.emitError("failed to strip debuginfo");
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -273,10 +343,11 @@ diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/st
|
||
|
|
},
|
||
|
|
[](DialectRegistry &registry) {
|
||
|
|
mlir::registerAllDialects(registry);
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
|
||
|
|
---- stablehlo/stablehlo/transforms/Passes.h
|
||
|
|
-+++ stablehlo/stablehlo/transforms/Passes.h
|
||
|
|
-@@ -35,6 +35,7 @@
|
||
|
|
+diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h
|
||
|
|
+index 0e4406d2..a5a4aa27 100644
|
||
|
|
+--- a/stablehlo/transforms/Passes.h
|
||
|
|
++++ b/stablehlo/transforms/Passes.h
|
||
|
|
+@@ -35,6 +35,7 @@ limitations under the License.
|
||
|
|
#include "mlir/Transforms/DialectConversion.h"
|
||
|
|
#include "stablehlo/dialect/StablehloOps.h"
|
||
|
|
#include "stablehlo/dialect/Version.h"
|
||
|
|
@@ -284,7 +355,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans
|
||
|
|
|
||
|
|
namespace mlir {
|
||
|
|
namespace stablehlo {
|
||
|
|
-@@ -54,17 +55,17 @@
|
||
|
|
+@@ -54,17 +55,17 @@ void populateStablehloRefineShapesPatterns(MLIRContext *context,
|
||
|
|
// Populates StableHLO ops to VHLO ops rewriting patterns.
|
||
|
|
void populateStablehloToVhloPatterns(MLIRContext *context,
|
||
|
|
RewritePatternSet *patterns,
|
||
|
|
@@ -305,10 +376,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/trans
|
||
|
|
|
||
|
|
/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and
|
||
|
|
/// Shape ops.
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
|
||
|
|
---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
|
||
|
|
-+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
|
||
|
|
-@@ -57,7 +57,7 @@
|
||
|
|
+diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
|
||
|
|
+index e768ded6..25903399 100644
|
||
|
|
+--- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
|
||
|
|
++++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp
|
||
|
|
+@@ -57,7 +57,7 @@ namespace {
|
||
|
|
class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
|
||
|
|
public:
|
||
|
|
StablehloToVhloTypeConverter(bool allowOtherDialects)
|
||
|
|
@@ -317,7 +389,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
LLVM_DEBUG(
|
||
|
|
llvm::dbgs()
|
||
|
|
<< "[StablehloToVhloTypeConverter] Creating with allowOtherDialects: "
|
||
|
|
-@@ -82,7 +82,15 @@
|
||
|
|
+@@ -82,7 +82,15 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
|
||
|
|
return vhlo::TokenV1Type::get(token.getContext());
|
||
|
|
});
|
||
|
|
addBuiltinToVhloConversions();
|
||
|
|
@@ -334,7 +406,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertEncoding(Attribute attr) const final {
|
||
|
|
-@@ -114,7 +122,7 @@
|
||
|
|
+@@ -114,7 +122,7 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
|
||
|
|
return vhlo::Name##Version##Attr::get(attr.getContext(), vhloValue.value())
|
||
|
|
|
||
|
|
Attribute convertGeneric(Attribute stablehloAttr,
|
||
|
|
@@ -343,7 +415,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
LLVM_DEBUG(llvm::dbgs() << "Convert generic: " << stablehloAttr << '\n');
|
||
|
|
|
||
|
|
// Handle StableHLO attributes.
|
||
|
|
-@@ -241,6 +249,10 @@
|
||
|
|
+@@ -241,6 +249,10 @@ Attribute convertGeneric(Attribute stablehloAttr,
|
||
|
|
return vhlo::TypeV1Attr::get(attr.getContext(), vhloType);
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -354,7 +426,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
LLVM_DEBUG(llvm::dbgs() << "Failed to convert: " << stablehloAttr << '\n');
|
||
|
|
return {}; // Failed to convert attribute.
|
||
|
|
}
|
||
|
|
-@@ -268,13 +280,15 @@
|
||
|
|
+@@ -268,13 +280,15 @@ SpecialResult notSpecial() { return SpecialResult::NOT_SPECIAL; }
|
||
|
|
Attribute convertBool(const ConversionPattern& pattern, int64_t stablehloDim) {
|
||
|
|
auto stablehloType = IntegerType::get(pattern.getContext(), 1);
|
||
|
|
auto stablehloAttr = IntegerAttr::get(stablehloType, stablehloDim);
|
||
|
|
@@ -372,7 +444,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertInts(const ConversionPattern& pattern,
|
||
|
|
-@@ -282,7 +296,8 @@
|
||
|
|
+@@ -282,7 +296,8 @@ Attribute convertInts(const ConversionPattern& pattern,
|
||
|
|
auto stablehloType = RankedTensorType::get(
|
||
|
|
stablehloDims.size(), IntegerType::get(pattern.getContext(), 64));
|
||
|
|
auto stablehloAttr = DenseIntElementsAttr::get(stablehloType, stablehloDims);
|
||
|
|
@@ -382,7 +454,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertSymbol(const ConversionPattern& pattern,
|
||
|
|
-@@ -290,7 +305,7 @@
|
||
|
|
+@@ -290,7 +305,7 @@ Attribute convertSymbol(const ConversionPattern& pattern,
|
||
|
|
auto stablehloSymbolAttr = dyn_cast<FlatSymbolRefAttr>(stablehloAttr);
|
||
|
|
if (!stablehloSymbolAttr) return {};
|
||
|
|
return convertGeneric(stablehloSymbolAttr.getAttr(),
|
||
|
|
@@ -391,7 +463,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
}
|
||
|
|
|
||
|
|
SpecialResult convertChannelHandle(const ConversionPattern& pattern,
|
||
|
|
-@@ -445,15 +460,15 @@
|
||
|
|
+@@ -445,15 +460,15 @@ SpecialResult convertDotAlgorithm(const ConversionPattern& pattern,
|
||
|
|
vhloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), "lhs_precision_type"),
|
||
|
|
convertGeneric(TypeAttr::get(attr.getLhsPrecisionType()),
|
||
|
|
@@ -410,7 +482,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
|
||
|
|
// Components
|
||
|
|
auto vhloLhsComponentCount = convertInt(pattern, attr.getLhsComponentCount());
|
||
|
|
-@@ -712,7 +727,9 @@
|
||
|
|
+@@ -712,7 +727,9 @@ LogicalResult addDefaults(const OpConversionPattern<StablehloOpTy>& pattern,
|
||
|
|
auto addDefaultAttr = [&](StringRef vhloName, Attribute stablehloAttr) {
|
||
|
|
vhloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), vhloName),
|
||
|
|
@@ -421,7 +493,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
};
|
||
|
|
if constexpr (std::is_same<StablehloOpTy, func::FuncOp>::value) {
|
||
|
|
if (!stablehloOp.getSymVisibilityAttr())
|
||
|
|
-@@ -987,8 +1004,9 @@
|
||
|
|
+@@ -987,8 +1004,9 @@ class StablehloToVhloOpConverter : public OpConversionPattern<StablehloOpTy> {
|
||
|
|
case SpecialResult::SPECIAL_FAILURE:
|
||
|
|
return failure();
|
||
|
|
case SpecialResult::NOT_SPECIAL:
|
||
|
|
@@ -433,7 +505,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
if (!vhloAttr) return failure();
|
||
|
|
vhloAttrs.push_back({stablehloAttr.getName(), vhloAttr});
|
||
|
|
break;
|
||
|
|
-@@ -1075,14 +1093,14 @@
|
||
|
|
+@@ -1075,14 +1093,14 @@ struct StablehloLegalizeToVhloPass
|
||
|
|
}
|
||
|
|
|
||
|
|
private:
|
||
|
|
@@ -450,10 +522,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stable
|
||
|
|
populateStablehloToVhloPatterns<
|
||
|
|
#define GET_OP_LIST
|
||
|
|
#include "stablehlo/dialect/StablehloOps.cpp.inc"
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
|
||
|
|
---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
|
||
|
|
-+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
|
||
|
|
-@@ -16,6 +16,7 @@
|
||
|
|
+diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp
|
||
|
|
+index 64eaa3c0..4b585b46 100644
|
||
|
|
+--- a/stablehlo/transforms/StablehloRefineShapes.cpp
|
||
|
|
++++ b/stablehlo/transforms/StablehloRefineShapes.cpp
|
||
|
|
+@@ -16,6 +16,7 @@ limitations under the License.
|
||
|
|
|
||
|
|
#include <cstddef>
|
||
|
|
#include <cstdint>
|
||
|
|
@@ -461,7 +534,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl
|
||
|
|
#include <tuple>
|
||
|
|
#include <utility>
|
||
|
|
|
||
|
|
-@@ -1038,8 +1039,11 @@
|
||
|
|
+@@ -1038,8 +1039,11 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
|
||
|
|
// Populate additional patterns for StableHLO extensions.
|
||
|
|
state.addAdditionalPatterns(patterns);
|
||
|
|
|
||
|
|
@@ -473,10 +546,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl
|
||
|
|
|
||
|
|
// The folding patterns implement partial evaluation of shape computations
|
||
|
|
// which is a critical part of implementing type refinement for ops like
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
|
||
|
|
---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
|
||
|
|
-+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
|
||
|
|
-@@ -57,7 +57,10 @@
|
||
|
|
+diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
|
||
|
|
+index 5662ad9d..41973a0e 100644
|
||
|
|
+--- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
|
||
|
|
++++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
|
||
|
|
+@@ -57,7 +57,10 @@ namespace {
|
||
|
|
|
||
|
|
class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter {
|
||
|
|
public:
|
||
|
|
@@ -488,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
addConversion([](Type type) -> Type { return type; });
|
||
|
|
addConversion([](vhlo::TokenV1Type token) -> Type {
|
||
|
|
LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n");
|
||
|
|
-@@ -90,7 +93,7 @@
|
||
|
|
+@@ -90,7 +93,7 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter {
|
||
|
|
return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value())
|
||
|
|
|
||
|
|
Attribute convertGeneric(Attribute vhloAttr,
|
||
|
|
@@ -497,18 +571,18 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
LLVM_DEBUG(llvm::dbgs() << "Converting attr " << vhloAttr);
|
||
|
|
if (auto vhloAttrs = dyn_cast<vhlo::ArrayV1Attr>(vhloAttr)) {
|
||
|
|
SmallVector<Attribute> stablehloAttrs;
|
||
|
|
-@@ -189,6 +192,10 @@
|
||
|
|
- // All VHLO attributes must have counterparts in StableHLO.
|
||
|
|
+@@ -190,6 +193,10 @@ Attribute convertGeneric(Attribute vhloAttr,
|
||
|
|
return {};
|
||
|
|
}
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+ // Fall back to type converter for unknown attributes.
|
||
|
|
+ auto unknownAttr = typeConverter->convertUnknownAttribute(vhloAttr);
|
||
|
|
+ if (unknownAttr) return unknownAttr;
|
||
|
|
-
|
||
|
|
++
|
||
|
|
// This should be unreachable unless program is a mix of VHLO and other
|
||
|
|
// dialects, e.g. due to user edits to textual assembly format.
|
||
|
|
-@@ -229,7 +236,7 @@
|
||
|
|
+ return {};
|
||
|
|
+@@ -229,7 +236,7 @@ bool isNoneType(Attribute vhloAttr) {
|
||
|
|
}
|
||
|
|
|
||
|
|
LogicalResult convertTypeAttr(Attribute vhloAttr, Type& result,
|
||
|
|
@@ -517,7 +591,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
auto stablehloAttr = convertGeneric(vhloAttr, typeConverter);
|
||
|
|
if (!stablehloAttr || !isa<TypeAttr>(stablehloAttr)) return failure();
|
||
|
|
result = cast<TypeAttr>(stablehloAttr).getValue();
|
||
|
|
-@@ -244,7 +251,7 @@
|
||
|
|
+@@ -244,7 +251,7 @@ LogicalResult convertInt(Attribute vhloAttr, int64_t& result) {
|
||
|
|
}
|
||
|
|
|
||
|
|
LogicalResult convertInts(Attribute vhloAttr,
|
||
|
|
@@ -526,7 +600,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
SmallVector<int64_t>& result) {
|
||
|
|
auto vhloTensorAttr = dyn_cast<vhlo::TensorV1Attr>(vhloAttr);
|
||
|
|
if (!vhloTensorAttr) return failure();
|
||
|
|
-@@ -256,7 +263,7 @@
|
||
|
|
+@@ -256,7 +263,7 @@ LogicalResult convertInts(Attribute vhloAttr,
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertSymbol(Attribute vhloAttr,
|
||
|
|
@@ -535,7 +609,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
auto vhloStringAttr = dyn_cast<vhlo::StringV1Attr>(vhloAttr);
|
||
|
|
if (!vhloStringAttr) return {};
|
||
|
|
auto stablehloStringAttr = dyn_cast_or_null<StringAttr>(
|
||
|
|
-@@ -267,7 +274,7 @@
|
||
|
|
+@@ -267,7 +274,7 @@ Attribute convertSymbol(Attribute vhloAttr,
|
||
|
|
|
||
|
|
template <typename OpType>
|
||
|
|
Attribute convertChannelHandle(OpType vhloOp,
|
||
|
|
@@ -544,7 +618,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
int64_t channelId, channelType;
|
||
|
|
if (failed(convertInt(vhloOp.getChannelId(), channelId)) ||
|
||
|
|
failed(convertInt(vhloOp.getChannelType(), channelType)))
|
||
|
|
-@@ -277,7 +284,7 @@
|
||
|
|
+@@ -277,7 +284,7 @@ Attribute convertChannelHandle(OpType vhloOp,
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertChannelId(Attribute vhloAttr,
|
||
|
|
@@ -553,7 +627,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
int64_t channelId;
|
||
|
|
if (failed(convertInt(vhloAttr, channelId))) return {};
|
||
|
|
return stablehlo::ChannelHandleAttr::get(vhloAttr.getContext(), channelId,
|
||
|
|
-@@ -285,8 +292,8 @@
|
||
|
|
+@@ -285,8 +292,8 @@ Attribute convertChannelId(Attribute vhloAttr,
|
||
|
|
}
|
||
|
|
|
||
|
|
template <typename OpType>
|
||
|
|
@@ -564,7 +638,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
int64_t stablehloInputBatchDimension, stablehloInputFeatureDimension;
|
||
|
|
SmallVector<int64_t> stablehloInputSpatialDimensions;
|
||
|
|
int64_t stablehloKernelInputFeatureDimension,
|
||
|
|
-@@ -323,7 +330,7 @@
|
||
|
|
+@@ -323,7 +330,7 @@ Attribute convertConvDimensionNumbers(OpType vhloOp,
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertCustomCallCalledComputations(
|
||
|
|
@@ -573,7 +647,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (auto vhloArrayAttr = dyn_cast<vhlo::ArrayV1Attr>(vhloAttr)) {
|
||
|
|
SmallVector<Attribute> stablehloAttrs;
|
||
|
|
for (auto vhloAttr : vhloArrayAttr.getValue()) {
|
||
|
|
-@@ -336,8 +343,8 @@
|
||
|
|
+@@ -336,8 +343,8 @@ Attribute convertCustomCallCalledComputations(
|
||
|
|
return {};
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -584,7 +658,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
Type lhsPrecisionType, rhsPrecisionType, accumulationType;
|
||
|
|
if (isNoneType(vhloOp.getLhsComponentCount())) {
|
||
|
|
// All must be nonetype
|
||
|
|
-@@ -373,8 +380,8 @@
|
||
|
|
+@@ -373,8 +380,8 @@ FailureOr<Attribute> convertDotAlgorithm(vhlo::DotGeneralOpV2 vhloOp,
|
||
|
|
numPrimitiveOperations, allowImpreciseAccumulation);
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -595,7 +669,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
SmallVector<int64_t> stablehloLhsBatchingDimensions,
|
||
|
|
stablehloRhsBatchingDimensions, stablehloLhsContractingDimensions,
|
||
|
|
stablehloRhsContractingDimensions;
|
||
|
|
-@@ -394,13 +401,13 @@
|
||
|
|
+@@ -394,13 +401,13 @@ Attribute convertDotDimensionNumbers(vhlo::DotGeneralOpV2 vhloOp,
|
||
|
|
}
|
||
|
|
|
||
|
|
Attribute convertFuncCallee(Attribute vhloAttr,
|
||
|
|
@@ -612,7 +686,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
SmallVector<int64_t> stablehloOffsetDims, stablehloCollapsedSliceDims,
|
||
|
|
stablehloOperandBatchingDims, stablehloStartIndicesBatchingDims,
|
||
|
|
stablehloStartIndexMap;
|
||
|
|
-@@ -423,8 +430,8 @@
|
||
|
|
+@@ -423,8 +430,8 @@ Attribute convertGatherDimensionNumbers(OpType vhloOp,
|
||
|
|
stablehloStartIndexMap, stablehloIndexVectorDim);
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -623,7 +697,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
SmallVector<int64_t> stablehloUpdateWindowDims, stablehloInsertedWindowDims,
|
||
|
|
stablehloInputBatchingDims, stablehloScatterIndicesBatchingDims,
|
||
|
|
stablehloScatterDimsToOperandDims;
|
||
|
|
-@@ -463,10 +470,11 @@
|
||
|
|
+@@ -463,10 +470,11 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
VhloOpTy vhloOp,
|
||
|
|
SmallVector<NamedAttribute>& vhloAttrs,
|
||
|
|
SmallVector<NamedAttribute>& stablehloAttrs) {
|
||
|
|
@@ -637,7 +711,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return failure();
|
||
|
|
stablehloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), "dimension_numbers"),
|
||
|
|
-@@ -480,7 +488,7 @@
|
||
|
|
+@@ -480,7 +488,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
if constexpr (std::is_same<VhloOpTy, vhlo::DotGeneralOpV2>::value) {
|
||
|
|
// Dot Dimension Numbers
|
||
|
|
auto stablehloDotDimAttr =
|
||
|
|
@@ -646,7 +720,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloDotDimAttr) return failure();
|
||
|
|
stablehloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), "dot_dimension_numbers"),
|
||
|
|
-@@ -489,8 +497,7 @@
|
||
|
|
+@@ -489,8 +497,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
"lhs_contracting_dimensions", "rhs_contracting_dimensions");
|
||
|
|
|
||
|
|
// Dot Algorithm
|
||
|
|
@@ -656,7 +730,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (failed(stablehloDotAlgorithmAttr)) return failure();
|
||
|
|
if (stablehloDotAlgorithmAttr.value())
|
||
|
|
stablehloAttrs.emplace_back(
|
||
|
|
-@@ -503,8 +510,7 @@
|
||
|
|
+@@ -503,8 +510,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
}
|
||
|
|
if constexpr (std::is_same<VhloOpTy, vhlo::DynamicGatherOpV2>::value ||
|
||
|
|
std::is_same<VhloOpTy, vhlo::GatherOpV2>::value) {
|
||
|
|
@@ -666,7 +740,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return failure();
|
||
|
|
stablehloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), "dimension_numbers"),
|
||
|
|
-@@ -514,8 +520,7 @@
|
||
|
|
+@@ -514,8 +520,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
"start_index_map", "index_vector_dim");
|
||
|
|
}
|
||
|
|
if constexpr (std::is_same<VhloOpTy, vhlo::ScatterOpV2>::value) {
|
||
|
|
@@ -676,7 +750,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return failure();
|
||
|
|
stablehloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), "scatter_dimension_numbers"),
|
||
|
|
-@@ -526,8 +531,7 @@
|
||
|
|
+@@ -526,8 +531,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
}
|
||
|
|
if constexpr (std::is_same<VhloOpTy, vhlo::RecvOpV1>::value ||
|
||
|
|
std::is_same<VhloOpTy, vhlo::SendOpV1>::value) {
|
||
|
|
@@ -686,7 +760,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return failure();
|
||
|
|
stablehloAttrs.emplace_back(
|
||
|
|
StringAttr::get(pattern.getContext(), "channel_handle"), stablehloAttr);
|
||
|
|
-@@ -537,7 +541,7 @@
|
||
|
|
+@@ -537,7 +541,7 @@ LogicalResult implodeSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
}
|
||
|
|
|
||
|
|
template <typename T, typename DenseArrayAttr>
|
||
|
|
@@ -695,7 +769,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
StringAttr vhloName, Attribute vhloAttr,
|
||
|
|
SmallVector<NamedAttribute>& stablehloAttrs) {
|
||
|
|
auto tensorAttr = dyn_cast<vhlo::TensorV1Attr>(vhloAttr);
|
||
|
|
-@@ -556,15 +560,15 @@
|
||
|
|
+@@ -556,15 +560,15 @@ SpecialResult convertDenseArray(const TypeConverter* typeConverter,
|
||
|
|
}
|
||
|
|
|
||
|
|
SpecialResult convertDenseBoolArray(
|
||
|
|
@@ -715,7 +789,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
return convertDenseArray<int64_t, DenseI64ArrayAttr>(
|
||
|
|
typeConverter, vhloName, vhloAttr, stablehloAttrs);
|
||
|
|
}
|
||
|
|
-@@ -575,7 +579,8 @@
|
||
|
|
+@@ -575,7 +579,8 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
SmallVector<NamedAttribute>& stablehloAttrs) {
|
||
|
|
StringAttr stablehloName = vhloName;
|
||
|
|
Attribute stablehloAttr;
|
||
|
|
@@ -725,7 +799,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
|
||
|
|
if constexpr (std::is_same<VhloOpTy, vhlo::AllGatherOpV2>::value ||
|
||
|
|
std::is_same<VhloOpTy, vhlo::AllReduceOpV2>::value ||
|
||
|
|
-@@ -585,7 +590,7 @@
|
||
|
|
+@@ -585,7 +590,7 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
std::is_same<VhloOpTy, vhlo::CollectiveBroadcastOpV1>::value) {
|
||
|
|
if (vhloName == "channel_id") {
|
||
|
|
stablehloName = StringAttr::get(pattern.getContext(), "channel_handle");
|
||
|
|
@@ -734,7 +808,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return specialFailure();
|
||
|
|
}
|
||
|
|
if (vhloName == "use_global_device_ids") {
|
||
|
|
-@@ -597,20 +602,20 @@
|
||
|
|
+@@ -597,20 +602,20 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
|
||
|
|
}
|
||
|
|
if constexpr (std::is_same<VhloOpTy, vhlo::CustomCallOpV1>::value) {
|
||
|
|
if (vhloName == "called_computations") {
|
||
|
|
@@ -759,7 +833,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return specialFailure();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
-@@ -760,8 +765,8 @@
|
||
|
|
+@@ -760,8 +765,8 @@ bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) {
|
||
|
|
template <typename T>
|
||
|
|
bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr,
|
||
|
|
T splatValue) {
|
||
|
|
@@ -770,7 +844,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
return attr && attr.isSplat() &&
|
||
|
|
attr.template getSplatValue<T>() == splatValue;
|
||
|
|
}
|
||
|
|
-@@ -977,8 +982,9 @@
|
||
|
|
+@@ -977,8 +982,9 @@ class VhloToStablehloOpConverter : public OpConversionPattern<VhloOpTy> {
|
||
|
|
case SpecialResult::SPECIAL_FAILURE:
|
||
|
|
return failure();
|
||
|
|
case SpecialResult::NOT_SPECIAL:
|
||
|
|
@@ -782,7 +856,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
if (!stablehloAttr) return failure();
|
||
|
|
stablehloAttrs.push_back({vhloAttr.getName(), stablehloAttr});
|
||
|
|
break;
|
||
|
|
-@@ -1056,7 +1062,7 @@
|
||
|
|
+@@ -1056,7 +1062,7 @@ struct ReconcileUnrealizedConversionCasts
|
||
|
|
template <typename... StablehloOpTypes>
|
||
|
|
void populateVhloToStablehloPatterns(MLIRContext* context,
|
||
|
|
RewritePatternSet* patterns,
|
||
|
|
@@ -791,7 +865,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
patterns
|
||
|
|
->add<VhloToStablehloOpConverter<StablehloToVhloOp<StablehloOpTypes>>...>(
|
||
|
|
*converter, context);
|
||
|
|
-@@ -1104,7 +1110,7 @@
|
||
|
|
+@@ -1104,7 +1110,7 @@ struct VhloLegalizeToStablehloPass
|
||
|
|
|
||
|
|
void populateVhloToStablehloPatterns(MLIRContext* context,
|
||
|
|
RewritePatternSet* patterns,
|
||
|
|
@@ -800,10 +874,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stable
|
||
|
|
populateVhloToStablehloPatterns<
|
||
|
|
#define GET_OP_LIST
|
||
|
|
#include "stablehlo/dialect/StablehloOps.cpp.inc"
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp
|
||
|
|
---- stablehlo/stablehlo/transforms/VhloToVersion.cpp
|
||
|
|
-+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp
|
||
|
|
-@@ -56,15 +56,23 @@
|
||
|
|
+diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp
|
||
|
|
+index dfbf2877..7818e601 100644
|
||
|
|
+--- a/stablehlo/transforms/VhloToVersion.cpp
|
||
|
|
++++ b/stablehlo/transforms/VhloToVersion.cpp
|
||
|
|
+@@ -56,15 +56,23 @@ namespace {
|
||
|
|
|
||
|
|
// Currently there are no type-to-version conversions so this class
|
||
|
|
// simply validates that all types are from the VHLO dialect.
|
||
|
|
@@ -833,7 +908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
};
|
||
|
|
|
||
|
|
// Check user-specified target version. Emit error if invalid.
|
||
|
|
-@@ -111,6 +119,10 @@
|
||
|
|
+@@ -111,6 +119,10 @@ bool isLegalVersion(VersionedInterface& interface, const Version& target) {
|
||
|
|
LogicalResult isLegalType(Type type, const Version& targetVersion);
|
||
|
|
|
||
|
|
LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) {
|
||
|
|
@@ -844,7 +919,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
auto attrInterface = dyn_cast<VersionedAttrInterface>(attr);
|
||
|
|
if (!attrInterface || !isLegalVersion(attrInterface, targetVersion)) {
|
||
|
|
LLVM_DEBUG(llvm::dbgs() << "failed to legalize attribute " << attr
|
||
|
|
-@@ -119,10 +131,11 @@
|
||
|
|
+@@ -119,10 +131,11 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) {
|
||
|
|
}
|
||
|
|
|
||
|
|
// Recursively check attrs if VHLO attr is a container
|
||
|
|
@@ -857,7 +932,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
if (auto arrAttr = dyn_cast<DictionaryV1Attr>(attr)) {
|
||
|
|
return success(llvm::all_of(
|
||
|
|
arrAttr.getValue(), [&](std::pair<Attribute, Attribute> entry) {
|
||
|
|
-@@ -146,6 +159,10 @@
|
||
|
|
+@@ -146,6 +159,10 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) {
|
||
|
|
}
|
||
|
|
|
||
|
|
LogicalResult isLegalType(Type type, const Version& targetVersion) {
|
||
|
|
@@ -868,7 +943,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
// All valid VHLO types must have versioned type interface.
|
||
|
|
auto typeInterface = dyn_cast<VersionedTypeInterface>(type);
|
||
|
|
if (!typeInterface || !isLegalVersion(typeInterface, targetVersion)) {
|
||
|
|
-@@ -170,10 +187,11 @@
|
||
|
|
+@@ -170,10 +187,11 @@ LogicalResult isLegalType(Type type, const Version& targetVersion) {
|
||
|
|
return failure();
|
||
|
|
return isLegalType(ranked.getElementType(), targetVersion);
|
||
|
|
}
|
||
|
|
@@ -881,7 +956,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
if (auto quant = dyn_cast<UniformQuantizedV1Type>(type))
|
||
|
|
return success(
|
||
|
|
succeeded(isLegalType(quant.getStorageType(), targetVersion)) &&
|
||
|
|
-@@ -213,7 +231,7 @@
|
||
|
|
+@@ -213,7 +231,7 @@ bool isLegalLocation(Location loc, const Version& targetVersion) {
|
||
|
|
bool isLegalOperation(Operation* op, const Version& targetVersion) {
|
||
|
|
// Validate op
|
||
|
|
auto opInterface = dyn_cast<VersionedOpInterface>(op);
|
||
|
|
@@ -890,7 +965,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
if (!isLegalVersion(opInterface, targetVersion)) return false;
|
||
|
|
LLVM_DEBUG(llvm::dbgs() << "Legal op version for target. " << op << '\n');
|
||
|
|
|
||
|
|
-@@ -454,7 +472,7 @@
|
||
|
|
+@@ -454,7 +472,7 @@ struct AllReduceOpV2ToV1 : public OpRewritePattern<AllReduceOpV2> {
|
||
|
|
namespace stablehlo {
|
||
|
|
void populateVhloToVersionPatterns(MLIRContext* context,
|
||
|
|
RewritePatternSet* patterns,
|
||
|
|
@@ -899,10 +974,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
|
||
|
|
vhlo::populateWithGenerated(*patterns);
|
||
|
|
patterns->add<vhlo::ScatterOpV1ToV2, vhlo::ScatterOpV2ToV1>(context);
|
||
|
|
patterns->add<vhlo::AllReduceOpV1ToV2, vhlo::AllReduceOpV2ToV1>(context);
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
|
||
|
|
---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
|
||
|
|
-+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
|
||
|
|
-@@ -19,6 +19,7 @@
|
||
|
|
+diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
|
||
|
|
+index 9cfac0ba..09bfb783 100644
|
||
|
|
+--- a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
|
||
|
|
++++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
|
||
|
|
+@@ -19,6 +19,7 @@ limitations under the License.
|
||
|
|
#include <memory>
|
||
|
|
#include <numeric>
|
||
|
|
#include <optional>
|
||
|
|
@@ -910,7 +986,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
#include <utility>
|
||
|
|
|
||
|
|
#include "llvm/ADT/APInt.h"
|
||
|
|
-@@ -47,6 +48,7 @@
|
||
|
|
+@@ -47,6 +48,7 @@ limitations under the License.
|
||
|
|
#include "mlir/IR/TypeUtilities.h"
|
||
|
|
#include "mlir/IR/Types.h"
|
||
|
|
#include "mlir/IR/Value.h"
|
||
|
|
@@ -918,7 +994,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||
|
|
#include "mlir/Pass/Pass.h"
|
||
|
|
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||
|
|
-@@ -98,29 +100,88 @@
|
||
|
|
+@@ -98,29 +100,88 @@ LogicalResult validateStaticShapeResult(PatternRewriter& rewriter,
|
||
|
|
return success();
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -982,8 +1058,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
if (res) return cast<TypedAttr>(res);
|
||
|
|
|
||
|
|
return nullptr;
|
||
|
|
-+}
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+
|
||
|
|
+/// Binary constant folder that used a generic folder function to handle both
|
||
|
|
+/// ints and floats.
|
||
|
|
@@ -1005,8 +1081,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (!res) return rewriter.notifyMatchFailure(op, "folding failed");
|
||
|
|
+
|
||
|
|
+ return res;
|
||
|
|
- }
|
||
|
|
-
|
||
|
|
++}
|
||
|
|
++
|
||
|
|
template <class AttrElementT, class TargetAttrElementT, class CalculationT,
|
||
|
|
typename OpType>
|
||
|
|
-LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op,
|
||
|
|
@@ -1014,7 +1090,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
DenseIntOrFPElementsAttr elements, Type resType,
|
||
|
|
CalculationT&& calculate) {
|
||
|
|
auto result = constFoldCastOp<AttrElementT, TargetAttrElementT,
|
||
|
|
-@@ -128,18 +189,19 @@
|
||
|
|
+@@ -128,18 +189,19 @@ LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op,
|
||
|
|
typename TargetAttrElementT::ValueType, void>(
|
||
|
|
elements, resType, calculate);
|
||
|
|
|
||
|
|
@@ -1036,7 +1112,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
DenseIntOrFPElementsAttr elements,
|
||
|
|
RankedTensorType resultType) {
|
||
|
|
auto oldType = getElementTypeOrSelf(elements);
|
||
|
|
-@@ -153,7 +215,7 @@
|
||
|
|
+@@ -153,7 +215,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
|
||
|
|
if (auto newFloatType = dyn_cast<FloatType>(newType)) {
|
||
|
|
// Float -> Float
|
||
|
|
const auto& targetSemantics = newFloatType.getFloatSemantics();
|
||
|
|
@@ -1045,7 +1121,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
rewriter, op, elements, resultType,
|
||
|
|
[&targetSemantics](const APFloat& operand, bool& castStatus) {
|
||
|
|
bool losesInfo;
|
||
|
|
-@@ -167,7 +229,7 @@
|
||
|
|
+@@ -167,7 +229,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Float -> Int
|
||
|
|
@@ -1054,7 +1130,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
rewriter, op, elements, resultType,
|
||
|
|
[&newBitWidth, &isNewTypeUnsigned](const APFloat& operand,
|
||
|
|
bool& castStatus) {
|
||
|
|
-@@ -186,7 +248,7 @@
|
||
|
|
+@@ -186,7 +248,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
|
||
|
|
|
||
|
|
if (auto newFloatType = dyn_cast<FloatType>(newType)) {
|
||
|
|
// Int -> Float
|
||
|
|
@@ -1063,7 +1139,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
rewriter, op, elements, resultType,
|
||
|
|
[&newFloatType, &isOldTypeUnsigned](const APInt& operand,
|
||
|
|
bool& /*castStatus*/) {
|
||
|
|
-@@ -199,64 +261,12 @@
|
||
|
|
+@@ -199,7 +261,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Int -> Int
|
||
|
|
@@ -1072,10 +1148,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
rewriter, op, elements, resultType,
|
||
|
|
[&newBitWidth, &isOldTypeUnsigned](const APInt& operand,
|
||
|
|
bool& /*castStatus*/) {
|
||
|
|
- return APSInt(operand, isOldTypeUnsigned).extOrTrunc(newBitWidth);
|
||
|
|
+@@ -207,58 +269,6 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op,
|
||
|
|
});
|
||
|
|
--}
|
||
|
|
--
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
-// The patterns below implement partial evaluation of shape computations which
|
||
|
|
-// is a critical part of implementing type refinement for ops like
|
||
|
|
-// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
|
||
|
|
@@ -1126,10 +1202,12 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- rewriter.replaceOpWithNewOp<ConstantOp>(op,
|
||
|
|
- getTensorAttr(resultType, result));
|
||
|
|
- return success();
|
||
|
|
- }
|
||
|
|
-
|
||
|
|
+-}
|
||
|
|
+-
|
||
|
|
template <typename OpType>
|
||
|
|
-@@ -275,29 +285,18 @@
|
||
|
|
+ struct FoldOpRewritePattern : OpRewritePattern<OpType> {
|
||
|
|
+ FoldOpRewritePattern(MLIRContext* context,
|
||
|
|
+@@ -275,29 +285,18 @@ struct FoldOpRewritePattern : OpRewritePattern<OpType> {
|
||
|
|
ArrayRef<StringRef> generatedNames = {}) = delete;
|
||
|
|
|
||
|
|
const StablehloAggressiveFolderPassOptions& options;
|
||
|
|
@@ -1154,9 +1232,8 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
|
||
|
|
- return success();
|
||
|
|
- }
|
||
|
|
--
|
||
|
|
+
|
||
|
|
- return failure();
|
||
|
|
-+
|
||
|
|
+ LogicalResult validateElementCountForFold(PatternRewriter& rewriter,
|
||
|
|
+ Operation* op,
|
||
|
|
+ ShapedType resultType) const {
|
||
|
|
@@ -1171,31 +1248,21 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
-@@ -318,100 +317,102 @@
|
||
|
|
+@@ -318,100 +317,102 @@ struct ShapeOpRewritePattern : public FoldOpRewritePattern<OpType> {
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
-struct EvalAddOpShapePattern : public FoldOpRewritePattern<AddOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
--
|
||
|
|
-- LogicalResult matchAndRewrite(AddOp op,
|
||
|
|
-- PatternRewriter& rewriter) const override {
|
||
|
|
-- return evalElementwise(rewriter, op,
|
||
|
|
-- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalAndOpPattern : public FoldOpRewritePattern<AndOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
--
|
||
|
|
-- LogicalResult matchAndRewrite(AndOp op,
|
||
|
|
-- PatternRewriter& rewriter) const override {
|
||
|
|
+struct FoldAddOpPattern final
|
||
|
|
+ : public ShapeOpRewritePattern<mlir::stablehlo::AddOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+- LogicalResult matchAndRewrite(AddOp op,
|
||
|
|
+ LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op,
|
||
|
|
-+ PatternRewriter& rewriter) const override {
|
||
|
|
+ PatternRewriter& rewriter) const override {
|
||
|
|
+- return evalElementwise(rewriter, op,
|
||
|
|
+- [&](APSInt lhs, APSInt rhs) { return lhs + rhs; });
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
|
||
|
|
+ return failure();
|
||
|
|
+
|
||
|
|
@@ -1203,14 +1270,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalAndOpPattern : public FoldOpRewritePattern<AndOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+struct FoldAndOpPattern : public ShapeOpRewritePattern<AndOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+- LogicalResult matchAndRewrite(AndOp op,
|
||
|
|
+ LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op,
|
||
|
|
-+ PatternRewriter& rewriter) const override {
|
||
|
|
+ PatternRewriter& rewriter) const override {
|
||
|
|
+ // TODO: Support more int types
|
||
|
|
auto resultType = op.getType();
|
||
|
|
if (!resultType.getElementType().isInteger(1))
|
||
|
|
@@ -1219,24 +1289,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) {
|
||
|
|
- return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0);
|
||
|
|
- });
|
||
|
|
-- }
|
||
|
|
+ auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldAnd{});
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+
|
||
|
|
-+ struct FoldAnd {
|
||
|
|
-+ APInt operator()(APInt lhs, APInt rhs) const {
|
||
|
|
-+ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero());
|
||
|
|
-+ }
|
||
|
|
-+ std::optional<APFloat> operator()(APFloat lhs, APFloat rhs) const {
|
||
|
|
-+ return std::nullopt;
|
||
|
|
-+ }
|
||
|
|
-+ };
|
||
|
|
- };
|
||
|
|
+ }
|
||
|
|
+-};
|
||
|
|
|
||
|
|
- // Pattern: broadcast_in_dim(splat, _) -> constant(splat)
|
||
|
|
+-// Pattern: broadcast_in_dim(splat, _) -> constant(splat)
|
||
|
|
-struct FoldBroadcastInDimSplatPattern final
|
||
|
|
- : FoldOpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
@@ -1251,14 +1311,22 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- op, SplatElementsAttr::get(op.getType(),
|
||
|
|
- cstAttr.getSplatValue<Attribute>()));
|
||
|
|
- return success();
|
||
|
|
-- }
|
||
|
|
++ struct FoldAnd {
|
||
|
|
++ APInt operator()(APInt lhs, APInt rhs) const {
|
||
|
|
++ return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero());
|
||
|
|
+ }
|
||
|
|
- return failure();
|
||
|
|
- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
++ std::optional<APFloat> operator()(APFloat lhs, APFloat rhs) const {
|
||
|
|
++ return std::nullopt;
|
||
|
|
++ }
|
||
|
|
++ };
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
-struct EvalBroadcastInDimOpPattern
|
||
|
|
- : public FoldOpRewritePattern<BroadcastInDimOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
++// Pattern: broadcast_in_dim(splat, _) -> constant(splat)
|
||
|
|
+struct FoldBroadcastInDimOpSplatPattern
|
||
|
|
+ : public ShapeOpRewritePattern<BroadcastInDimOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
@@ -1267,8 +1335,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
PatternRewriter& rewriter) const override {
|
||
|
|
auto resultType = op.getType();
|
||
|
|
- if (failed(validateStaticShapeResult(rewriter, op, resultType)))
|
||
|
|
-- return failure();
|
||
|
|
--
|
||
|
|
++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
|
||
|
|
++ failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
+
|
||
|
|
- auto operandType = op.getOperand().getType();
|
||
|
|
- if (operandType.getRank() != 0)
|
||
|
|
- return rewriter.notifyMatchFailure(op, "expected 0-dimensional type");
|
||
|
|
@@ -1277,52 +1347,34 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- if (failed(hlo::matchInts(op.getOperand(), operand)))
|
||
|
|
- return rewriter.notifyMatchFailure(op, "expected constant operands");
|
||
|
|
- auto scalar = operand[0];
|
||
|
|
--
|
||
|
|
++ SplatElementsAttr cstAttr;
|
||
|
|
++ matchPattern(op.getOperand(), m_Constant(&cstAttr));
|
||
|
|
++ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat");
|
||
|
|
+
|
||
|
|
- rewriter.replaceOpWithNewOp<ConstantOp>(
|
||
|
|
- op, getTensorAttr(op.getType(), scalar));
|
||
|
|
-- return success();
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
++ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
|
||
|
|
++ op, SplatElementsAttr::get(op.getType(),
|
||
|
|
++ cstAttr.getSplatValue<Attribute>()));
|
||
|
|
+ return success();
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
-struct EvalClampOpPattern : public FoldOpRewritePattern<ClampOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
--
|
||
|
|
++struct FoldCompareOpPattern : public ShapeOpRewritePattern<CompareOp> {
|
||
|
|
++ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
+
|
||
|
|
- LogicalResult matchAndRewrite(ClampOp op,
|
||
|
|
-- PatternRewriter& rewriter) const override {
|
||
|
|
++ LogicalResult matchAndRewrite(CompareOp op,
|
||
|
|
+ PatternRewriter& rewriter) const override {
|
||
|
|
- return evalElementwise(rewriter, op,
|
||
|
|
- [&](APSInt min, APSInt operand, APSInt max) {
|
||
|
|
- if (operand < min) return min;
|
||
|
|
- if (max < operand) return max;
|
||
|
|
- return operand;
|
||
|
|
- });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
-+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
|
||
|
|
-+ failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
-+ return failure();
|
||
|
|
-+
|
||
|
|
-+ SplatElementsAttr cstAttr;
|
||
|
|
-+ matchPattern(op.getOperand(), m_Constant(&cstAttr));
|
||
|
|
-+ if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat");
|
||
|
|
-+
|
||
|
|
-+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
|
||
|
|
-+ op, SplatElementsAttr::get(op.getType(),
|
||
|
|
-+ cstAttr.getSplatValue<Attribute>()));
|
||
|
|
-+ return success();
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
-+struct FoldCompareOpPattern : public ShapeOpRewritePattern<CompareOp> {
|
||
|
|
-+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
-
|
||
|
|
- LogicalResult matchAndRewrite(CompareOp op,
|
||
|
|
- PatternRewriter& rewriter) const override {
|
||
|
|
- auto resultType = op.getType();
|
||
|
|
-- auto kind = op.getCompareType();
|
||
|
|
-- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) {
|
||
|
|
++ auto resultType = op.getType();
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
+
|
||
|
|
@@ -1332,15 +1384,23 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+-};
|
||
|
|
+
|
||
|
|
+-struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ struct FoldCompare {
|
||
|
|
+ FoldCompare(ComparisonDirection direction,
|
||
|
|
+ std::optional<ComparisonType> kind)
|
||
|
|
+ : direction(direction), kind(kind) {}
|
||
|
|
+ ComparisonDirection direction;
|
||
|
|
+ std::optional<ComparisonType> kind;
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+- LogicalResult matchAndRewrite(CompareOp op,
|
||
|
|
+- PatternRewriter& rewriter) const override {
|
||
|
|
+- auto resultType = op.getType();
|
||
|
|
+- auto kind = op.getCompareType();
|
||
|
|
+- return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) {
|
||
|
|
+ // TODO: Enable float folding.
|
||
|
|
+ std::optional<APFloat> operator()(APFloat lhs, APFloat rhs) {
|
||
|
|
+ return std::nullopt;
|
||
|
|
@@ -1352,7 +1412,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
case ComparisonDirection::EQ:
|
||
|
|
result = lhs == rhs;
|
||
|
|
break;
|
||
|
|
-@@ -431,9 +432,9 @@
|
||
|
|
+@@ -431,9 +432,9 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
|
||
|
|
result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs);
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
@@ -1365,7 +1425,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
};
|
||
|
|
|
||
|
|
//////////////////////////////////
|
||
|
|
-@@ -441,16 +442,15 @@
|
||
|
|
+@@ -441,16 +442,15 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern<CompareOp> {
|
||
|
|
/////////////////////////////////
|
||
|
|
|
||
|
|
struct FoldConcatenateOpPattern final
|
||
|
|
@@ -1387,7 +1447,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
return failure();
|
||
|
|
|
||
|
|
// Fold concatenate when all inputs are constants.
|
||
|
|
-@@ -466,6 +466,7 @@
|
||
|
|
+@@ -466,6 +466,7 @@ struct FoldConcatenateOpPattern final
|
||
|
|
int64_t{1}, std::multiplies<>{});
|
||
|
|
|
||
|
|
SmallVector<Attribute> newElems;
|
||
|
|
@@ -1395,7 +1455,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
newElems.reserve(numElems);
|
||
|
|
|
||
|
|
for (int64_t i = 0; i != topSize; ++i) {
|
||
|
|
-@@ -485,31 +486,7 @@
|
||
|
|
+@@ -485,31 +486,7 @@ struct FoldConcatenateOpPattern final
|
||
|
|
int64_t foldOpElementLimit;
|
||
|
|
};
|
||
|
|
|
||
|
|
@@ -1428,20 +1488,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
LogicalResult matchAndRewrite(ConvertOp op,
|
||
|
|
-@@ -532,28 +509,50 @@
|
||
|
|
+@@ -532,28 +509,50 @@ struct EvalConvertOpPattern : public ShapeOpRewritePattern<ConvertOp> {
|
||
|
|
return rewriter.notifyMatchFailure(
|
||
|
|
op, "expected constant integer or float operand");
|
||
|
|
|
||
|
|
- return evalConvert(rewriter, op, elements, resultType);
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
++ return foldConvert(rewriter, op, elements, resultType);
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
-struct EvalDivOpPattern : public FoldOpRewritePattern<DivOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
-+ return foldConvert(rewriter, op, elements, resultType);
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+struct FoldDivOpPattern : public ShapeOpRewritePattern<DivOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
@@ -1449,12 +1506,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
PatternRewriter& rewriter) const override {
|
||
|
|
- return evalElementwise(rewriter, op,
|
||
|
|
- [&](APSInt lhs, APSInt rhs) { return lhs / rhs; });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalGetDimensionSizeOpPattern
|
||
|
|
-- : public FoldOpRewritePattern<GetDimensionSizeOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ auto resultType = op.getType();
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
@@ -1464,7 +1515,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ struct FoldDivide {
|
||
|
|
+ FoldDivide(bool isUnsignedInt)
|
||
|
|
@@ -1479,8 +1530,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ static APInt foldUint(APInt lhs, APInt rhs) { return lhs.udiv(rhs); }
|
||
|
|
+ static APInt foldSint(APInt lhs, APInt rhs) { return lhs.sdiv(rhs); }
|
||
|
|
+ };
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalGetDimensionSizeOpPattern
|
||
|
|
+- : public FoldOpRewritePattern<GetDimensionSizeOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+struct FoldGetDimensionSizeOpPattern
|
||
|
|
+ : public ShapeOpRewritePattern<GetDimensionSizeOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
@@ -1494,7 +1548,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
return failure();
|
||
|
|
|
||
|
|
auto operandType = op.getOperand().getType();
|
||
|
|
-@@ -567,86 +566,187 @@
|
||
|
|
+@@ -567,86 +566,187 @@ struct EvalGetDimensionSizeOpPattern
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
@@ -1549,11 +1603,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) {
|
||
|
|
- return lhs >= rhs ? lhs : rhs;
|
||
|
|
- });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalMinOpPattern : public FoldOpRewritePattern<MinOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ auto resultType = op.getType();
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
@@ -1563,9 +1612,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalMinOpPattern : public FoldOpRewritePattern<MinOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+struct FoldMinOpPattern : public ShapeOpRewritePattern<MinOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
@@ -1574,11 +1625,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) {
|
||
|
|
- return lhs <= rhs ? lhs : rhs;
|
||
|
|
- });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct FoldMulOpPattern final : FoldOpRewritePattern<mlir::stablehlo::MulOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ auto resultType = op.getType();
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
@@ -1588,19 +1634,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct FoldMulOpPattern final : FoldOpRewritePattern<mlir::stablehlo::MulOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+// Clamp is folded using Min and Max folders.
|
||
|
|
+struct FoldClampOpPattern : public ShapeOpRewritePattern<ClampOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
|
||
|
|
+ LogicalResult matchAndRewrite(ClampOp op,
|
||
|
|
-+ PatternRewriter& rewriter) const override {
|
||
|
|
+ PatternRewriter& rewriter) const override {
|
||
|
|
+- TypedAttr lhsAttr;
|
||
|
|
+- matchPattern(op.getLhs(), m_Constant(&lhsAttr));
|
||
|
|
+-
|
||
|
|
+- TypedAttr rhsAttr;
|
||
|
|
+- matchPattern(op.getRhs(), m_Constant(&rhsAttr));
|
||
|
|
+-
|
||
|
|
+- if (TypedAttr res;
|
||
|
|
+- lhsAttr && rhsAttr &&
|
||
|
|
+- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) {
|
||
|
|
+- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
|
||
|
|
+- return success();
|
||
|
|
+- }
|
||
|
|
+ auto resultType = op.getType();
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
-+
|
||
|
|
+
|
||
|
|
+- return failure();
|
||
|
|
+ TypedAttr minAttr, operandAttr, maxAttr;
|
||
|
|
+ matchPattern(op.getMin(), m_Constant(&minAttr));
|
||
|
|
+ matchPattern(op.getOperand(), m_Constant(&operandAttr));
|
||
|
|
@@ -1620,43 +1682,19 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (!res) return rewriter.notifyMatchFailure(op, "failed to fold clamp");
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
-+struct FoldMulOpPattern final : ShapeOpRewritePattern<mlir::stablehlo::MulOp> {
|
||
|
|
-+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
|
||
|
|
- LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
|
||
|
|
- PatternRewriter& rewriter) const override {
|
||
|
|
-- TypedAttr lhsAttr;
|
||
|
|
-- matchPattern(op.getLhs(), m_Constant(&lhsAttr));
|
||
|
|
--
|
||
|
|
-- TypedAttr rhsAttr;
|
||
|
|
-- matchPattern(op.getRhs(), m_Constant(&rhsAttr));
|
||
|
|
--
|
||
|
|
-- if (TypedAttr res;
|
||
|
|
-- lhsAttr && rhsAttr &&
|
||
|
|
-- (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) {
|
||
|
|
-- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
|
||
|
|
-- return success();
|
||
|
|
-- }
|
||
|
|
--
|
||
|
|
-- return failure();
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
-struct EvalMulOpPattern : public FoldOpRewritePattern<MulOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
--
|
||
|
|
++struct FoldMulOpPattern final : ShapeOpRewritePattern<mlir::stablehlo::MulOp> {
|
||
|
|
++ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
+
|
||
|
|
- LogicalResult matchAndRewrite(MulOp op,
|
||
|
|
-- PatternRewriter& rewriter) const override {
|
||
|
|
++ LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
|
||
|
|
+ PatternRewriter& rewriter) const override {
|
||
|
|
- return evalElementwise(rewriter, op,
|
||
|
|
- [&](APSInt lhs, APSInt rhs) { return lhs * rhs; });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalOrOpPattern : public FoldOpRewritePattern<OrOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
|
||
|
|
+ return failure();
|
||
|
|
+
|
||
|
|
@@ -1664,9 +1702,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalOrOpPattern : public FoldOpRewritePattern<OrOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+struct FoldOrOpPattern : public ShapeOpRewritePattern<OrOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
@@ -1680,16 +1720,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) {
|
||
|
|
- return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0);
|
||
|
|
- });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalRemOpPattern : public FoldOpRewritePattern<RemOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldOr{});
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ struct FoldOr {
|
||
|
|
+ APInt operator()(APInt lhs, APInt rhs) const {
|
||
|
|
@@ -1699,8 +1734,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ return std::nullopt;
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalRemOpPattern : public FoldOpRewritePattern<RemOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+struct FoldRemOpPattern : public ShapeOpRewritePattern<RemOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
@@ -1708,10 +1745,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
PatternRewriter& rewriter) const override {
|
||
|
|
- return evalElementwise(rewriter, op,
|
||
|
|
- [&](APSInt lhs, APSInt rhs) { return lhs % rhs; });
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
|
||
|
|
+ auto resultType = op.getType();
|
||
|
|
+ if (failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
@@ -1721,7 +1754,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ struct FoldRem {
|
||
|
|
+ FoldRem(bool isUnsignedInt)
|
||
|
|
@@ -1736,14 +1769,15 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ static APInt foldUint(APInt lhs, APInt rhs) { return lhs.urem(rhs); }
|
||
|
|
+ static APInt foldSint(APInt lhs, APInt rhs) { return lhs.srem(rhs); }
|
||
|
|
+ };
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
|
||
|
|
+// Pattern: reshape(cst, shape) -> cst
|
||
|
|
+struct FoldReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
|
||
|
|
using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
LogicalResult matchAndRewrite(ReshapeOp op,
|
||
|
|
-@@ -656,7 +756,6 @@
|
||
|
|
+@@ -656,7 +756,6 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
|
||
|
|
failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
return failure();
|
||
|
|
|
||
|
|
@@ -1751,7 +1785,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
DenseIntOrFPElementsAttr attr;
|
||
|
|
if (!matchPattern(op.getOperand(), m_Constant(&attr)))
|
||
|
|
return rewriter.notifyMatchFailure(op, "expected constant operand");
|
||
|
|
-@@ -665,53 +764,98 @@
|
||
|
|
+@@ -665,53 +764,98 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
@@ -1764,16 +1798,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
PatternRewriter& rewriter) const override {
|
||
|
|
auto resultType = op.getType();
|
||
|
|
- if (failed(validateStaticShapeResult(rewriter, op, resultType)))
|
||
|
|
-- return failure();
|
||
|
|
--
|
||
|
|
++ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
|
||
|
|
++ failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
+ return failure();
|
||
|
|
+
|
||
|
|
- SmallVector<APSInt> pred, onTrue, onFalse;
|
||
|
|
- if (failed(hlo::matchInts(op.getPred(), pred)) ||
|
||
|
|
- failed(hlo::matchInts(op.getOnTrue(), onTrue)) ||
|
||
|
|
- failed(hlo::matchInts(op.getOnFalse(), onFalse)))
|
||
|
|
-+ if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
|
||
|
|
-+ failed(validateShapeFoldDtype(rewriter, op, resultType)))
|
||
|
|
-+ return failure();
|
||
|
|
-+
|
||
|
|
+ DenseIntElementsAttr predAttr;
|
||
|
|
+ DenseElementsAttr onTrueAttr, onFalseAttr;
|
||
|
|
+ matchPattern(op.getPred(), m_Constant(&predAttr));
|
||
|
|
@@ -1783,14 +1815,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
return rewriter.notifyMatchFailure(op, "expected constant operands");
|
||
|
|
|
||
|
|
- SmallVector<APSInt> result;
|
||
|
|
+- for (auto [predEl, onTrueEl, onFalseEl] :
|
||
|
|
+- llvm::zip(pred, onTrue, onFalse)) {
|
||
|
|
+- result.push_back(predEl != 0 ? onTrueEl : onFalseEl);
|
||
|
|
+ // Optimization, handle splat predicate
|
||
|
|
+ if (isa<SplatElementsAttr>(predAttr)) {
|
||
|
|
+ auto pred = predAttr.getSplatValue<APInt>();
|
||
|
|
+ rewriter.replaceOpWithNewOp<ConstantOp>(
|
||
|
|
+ op, pred.isZero() ? onFalseAttr : onTrueAttr);
|
||
|
|
+ return success();
|
||
|
|
-+ }
|
||
|
|
-+
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ // TODO: Enable float folding.
|
||
|
|
+ if (op.getType().getElementType().isFloat())
|
||
|
|
+ return rewriter.notifyMatchFailure(op, "float select not supported yet");
|
||
|
|
@@ -1800,27 +1835,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ return failure();
|
||
|
|
+
|
||
|
|
+ SmallVector<APInt> result;
|
||
|
|
- for (auto [predEl, onTrueEl, onFalseEl] :
|
||
|
|
-- llvm::zip(pred, onTrue, onFalse)) {
|
||
|
|
-- result.push_back(predEl != 0 ? onTrueEl : onFalseEl);
|
||
|
|
-- }
|
||
|
|
--
|
||
|
|
++ for (auto [predEl, onTrueEl, onFalseEl] :
|
||
|
|
+ llvm::zip(predAttr.getValues<APInt>(), onTrueAttr.getValues<APInt>(),
|
||
|
|
+ onFalseAttr.getValues<APInt>())) {
|
||
|
|
+ result.push_back(!predEl.isZero() ? onTrueEl : onFalseEl);
|
||
|
|
+ }
|
||
|
|
rewriter.replaceOpWithNewOp<ConstantOp>(
|
||
|
|
- op, getTensorAttr(op.getType(), result));
|
||
|
|
-- return success();
|
||
|
|
-- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalSignOpPattern : public FoldOpRewritePattern<SignOp> {
|
||
|
|
-- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+ op, DenseIntElementsAttr::get(resultType, result));
|
||
|
|
+
|
||
|
|
-+ return success();
|
||
|
|
-+ }
|
||
|
|
+ return success();
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ struct FoldSelect {
|
||
|
|
+ std::optional<APFloat> operator()(APFloat pred, APFloat onTrue,
|
||
|
|
@@ -1832,8 +1857,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
+ return pred != 0 ? onTrue : onFalse;
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalSignOpPattern : public FoldOpRewritePattern<SignOp> {
|
||
|
|
+- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
+struct FoldSignOpPattern : public ShapeOpRewritePattern<SignOp> {
|
||
|
|
+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
|
||
|
|
|
||
|
|
@@ -1881,7 +1908,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
};
|
||
|
|
|
||
|
|
template <typename RangeType>
|
||
|
|
-@@ -749,13 +893,14 @@
|
||
|
|
+@@ -749,13 +893,14 @@ DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) {
|
||
|
|
ArrayRef<ElementType>(result));
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -1899,7 +1926,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
return failure();
|
||
|
|
|
||
|
|
auto operand = op.getOperand();
|
||
|
|
-@@ -784,36 +929,18 @@
|
||
|
|
+@@ -784,36 +929,18 @@ struct EvalSliceOpPattern : public FoldOpRewritePattern<SliceOp> {
|
||
|
|
};
|
||
|
|
|
||
|
|
struct FoldSubtractOpPattern final
|
||
|
|
@@ -1930,14 +1957,13 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
-
|
||
|
|
-struct EvalSubtractOpPattern : public FoldOpRewritePattern<SubtractOp> {
|
||
|
|
- using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
--
|
||
|
|
++ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
|
||
|
|
++ return failure();
|
||
|
|
+
|
||
|
|
- LogicalResult matchAndRewrite(SubtractOp op,
|
||
|
|
- PatternRewriter& rewriter) const override {
|
||
|
|
- return evalElementwise(rewriter, op,
|
||
|
|
- [&](APSInt lhs, APSInt rhs) { return lhs - rhs; });
|
||
|
|
-+ if (failed(validateShapeFoldDtype(rewriter, op, op.getType())))
|
||
|
|
-+ return failure();
|
||
|
|
-+
|
||
|
|
+ auto res = foldBinaryOpIntOrFloat(rewriter, op, std::minus<>{});
|
||
|
|
+ if (failed(res)) return failure();
|
||
|
|
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
@@ -1945,23 +1971,35 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
-@@ -823,42 +950,38 @@
|
||
|
|
+@@ -823,42 +950,38 @@ struct FoldSqrtOpPattern
|
||
|
|
|
||
|
|
LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op,
|
||
|
|
PatternRewriter& rewriter) const final {
|
||
|
|
- TypedAttr lhsAttr;
|
||
|
|
- matchPattern(op.getOperand(), m_Constant(&lhsAttr));
|
||
|
|
--
|
||
|
|
++ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt());
|
||
|
|
++ if (failed(res)) return failure();
|
||
|
|
++ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
++ return success();
|
||
|
|
++ }
|
||
|
|
+
|
||
|
|
- if (!lhsAttr)
|
||
|
|
- return rewriter.notifyMatchFailure(op, "operand not constant");
|
||
|
|
--
|
||
|
|
++ struct FoldSqrt {
|
||
|
|
++ std::optional<APFloat> operator()(APFloat operand) {
|
||
|
|
++ if (operand.getSizeInBits(operand.getSemantics()) == 64)
|
||
|
|
++ return APFloat(std::sqrt(operand.convertToDouble()));
|
||
|
|
+
|
||
|
|
- if (auto res = constFoldUnaryOp<FloatAttr, FloatAttr::ValueType, void>(
|
||
|
|
- lhsAttr, foldSqrt)) {
|
||
|
|
- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
|
||
|
|
- op, op.getType(), llvm::cast<ElementsAttr>(res));
|
||
|
|
- return success();
|
||
|
|
-- }
|
||
|
|
--
|
||
|
|
++ if (operand.getSizeInBits(operand.getSemantics()) == 32)
|
||
|
|
++ return APFloat(sqrtf(operand.convertToFloat()));
|
||
|
|
++ return std::nullopt;
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
- return rewriter.notifyMatchFailure(op, "unable to fold sqrt");
|
||
|
|
- }
|
||
|
|
-
|
||
|
|
@@ -1973,32 +2011,14 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
- return APFloat(sqrtf(a.convertToFloat()));
|
||
|
|
- return {};
|
||
|
|
- }
|
||
|
|
--};
|
||
|
|
--
|
||
|
|
--struct EvalIotaOpPattern : public FoldOpRewritePattern<IotaOp> {
|
||
|
|
-+ auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt());
|
||
|
|
-+ if (failed(res)) return failure();
|
||
|
|
-+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res.value());
|
||
|
|
-+ return success();
|
||
|
|
-+ }
|
||
|
|
-+
|
||
|
|
-+ struct FoldSqrt {
|
||
|
|
-+ std::optional<APFloat> operator()(APFloat operand) {
|
||
|
|
-+ if (operand.getSizeInBits(operand.getSemantics()) == 64)
|
||
|
|
-+ return APFloat(std::sqrt(operand.convertToDouble()));
|
||
|
|
-+
|
||
|
|
-+ if (operand.getSizeInBits(operand.getSemantics()) == 32)
|
||
|
|
-+ return APFloat(sqrtf(operand.convertToFloat()));
|
||
|
|
-+ return std::nullopt;
|
||
|
|
-+ }
|
||
|
|
-+
|
||
|
|
+ // TODO: Enable int folding.
|
||
|
|
+ std::optional<APInt> operator()(APInt operand) {
|
||
|
|
+ return std::nullopt;
|
||
|
|
+ }
|
||
|
|
+ };
|
||
|
|
-+};
|
||
|
|
-+
|
||
|
|
+ };
|
||
|
|
+
|
||
|
|
+-struct EvalIotaOpPattern : public FoldOpRewritePattern<IotaOp> {
|
||
|
|
+struct FoldIotaOpPattern : public FoldOpRewritePattern<IotaOp> {
|
||
|
|
using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
|
||
|
|
@@ -2015,7 +2035,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
|
||
|
|
auto elementType = resultType.getElementType();
|
||
|
|
|
||
|
|
-@@ -929,7 +1052,7 @@
|
||
|
|
+@@ -929,7 +1052,7 @@ DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) {
|
||
|
|
// transpose(constant) => constant with permuted dimensions
|
||
|
|
// This covers ranked tensor types with 0 dimensions(zero elements) and 0
|
||
|
|
// rank(scalar), as well as splat values.
|
||
|
|
@@ -2024,7 +2044,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
|
||
|
|
LogicalResult matchAndRewrite(TransposeOp op,
|
||
|
|
-@@ -943,6 +1066,7 @@
|
||
|
|
+@@ -943,6 +1066,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern<TransposeOp> {
|
||
|
|
return rewriter.notifyMatchFailure(
|
||
|
|
op, "expected constant integer or float operand");
|
||
|
|
|
||
|
|
@@ -2032,7 +2052,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
DenseElementsAttr resAttr;
|
||
|
|
if (auto data = els.tryGetValues<APInt>())
|
||
|
|
resAttr = transposeType(op, *data);
|
||
|
|
-@@ -957,6 +1081,7 @@
|
||
|
|
+@@ -957,6 +1081,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern<TransposeOp> {
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
@@ -2040,7 +2060,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
struct LowerBoolSplatConstantsIntoReduceOpRegion
|
||
|
|
: public FoldOpRewritePattern<ReduceOp> {
|
||
|
|
using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
-@@ -1160,7 +1285,7 @@
|
||
|
|
+@@ -1160,7 +1285,7 @@ bool hasNoDeclaredSideEffects(Operation* op) {
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -2049,7 +2069,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
: public FoldOpRewritePattern<WhileOp> {
|
||
|
|
using FoldOpRewritePattern::FoldOpRewritePattern;
|
||
|
|
|
||
|
|
-@@ -1233,23 +1358,16 @@
|
||
|
|
+@@ -1233,23 +1358,16 @@ void populateStablehloAggressiveFolderPatterns(
|
||
|
|
PatternBenefit benefit) {
|
||
|
|
populateStablehloShapeFolderPatterns(context, patterns, options, benefit);
|
||
|
|
|
||
|
|
@@ -2083,7 +2103,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
}
|
||
|
|
|
||
|
|
class StablehloTargetIndependentOptimizationPass {
|
||
|
|
-@@ -1266,25 +1384,25 @@
|
||
|
|
+@@ -1266,25 +1384,25 @@ void populateStablehloShapeFolderPatterns(
|
||
|
|
MLIRContext* context, RewritePatternSet* patterns,
|
||
|
|
const StablehloAggressiveFolderPassOptions& options,
|
||
|
|
PatternBenefit benefit) {
|
||
|
|
@@ -2128,28 +2148,29 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
|
||
|
|
}
|
||
|
|
|
||
|
|
void populateStablehloShapeFolderPatterns(MLIRContext* context,
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||
|
|
---- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||
|
|
-+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||
|
|
-@@ -48,6 +48,8 @@
|
||
|
|
- def RankEqual : Constraint<
|
||
|
|
+diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||
|
|
+index 8adb6dbc..3a8edc61 100644
|
||
|
|
+--- a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||
|
|
++++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
|
||
|
|
+@@ -49,6 +49,8 @@ def RankEqual : Constraint<
|
||
|
|
CPred<"llvm::cast<ShapedType>($0.getType()).getRank() == llvm::cast<ShapedType>($1.getType()).getRank()">,
|
||
|
|
"same rank">;
|
||
|
|
-+
|
||
|
|
-+def TensorDimsAllOne : Constraint<CPred<"tensorDimsAllOne($0, $1)">, "all tensor dims are 1">;
|
||
|
|
|
||
|
|
++def TensorDimsAllOne : Constraint<CPred<"tensorDimsAllOne($0, $1)">, "all tensor dims are 1">;
|
||
|
|
++
|
||
|
|
def TypesEqual : Constraint<CPred<"$0.getType() == $1.getType()">, "operands are equal">;
|
||
|
|
|
||
|
|
-@@ -100,6 +102,8 @@
|
||
|
|
- def ZeroExtent : AttrConstraint<
|
||
|
|
+ ///////////
|
||
|
|
+@@ -101,6 +103,8 @@ def ZeroExtent : AttrConstraint<
|
||
|
|
CPred<"cast<DenseElementsAttr>($_self).getNumElements() == 0">,
|
||
|
|
"is zero extent">;
|
||
|
|
-+
|
||
|
|
-+def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>;
|
||
|
|
|
||
|
|
++def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>;
|
||
|
|
++
|
||
|
|
///////////
|
||
|
|
//// Native Code Call Utilities
|
||
|
|
-@@ -503,7 +507,7 @@
|
||
|
|
+
|
||
|
|
+@@ -503,7 +507,7 @@ def SelectOp_InvertBroadcastPredicateAndSwap
|
||
|
|
// Must be static shape, otherwise would require broadcasting via
|
||
|
|
// CHLO_ConstantLike.
|
||
|
|
def SubtractOp_FoldToZero
|
||
|
|
@@ -2158,4 +2179,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
|
||
|
|
(StableHLO_ConstantLike<"0"> $operand)>;
|
||
|
|
|
||
|
|
// Pattern: subtract(X, 0) -> X
|
||
|
|
+--
|
||
|
|
+2.48.1
|
||
|
|
|
||
|
|
--
|
||
|
|
2.48.1
|
||
|
|
|