402 lines
17 KiB
Diff
402 lines
17 KiB
Diff
|
|
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
|
||
|
|
--- stablehlo/BUILD.bazel
|
||
|
|
+++ stablehlo/BUILD.bazel
|
||
|
|
@@ -340,6 +340,21 @@
|
||
|
|
],
|
||
|
|
)
|
||
|
|
|
||
|
|
+# Runs mlir-tblgen with `--gen-rewriters` on
+# StablehloCreateCompatibilityExpanderPatterns.td to produce the `.h.inc`
+# header listed in `tbl_outs`.
+gentbl_cc_library(
|
||
|
|
+ name = "stablehlo_create_compatibility_expander_inc_gen",
|
||
|
|
+ tbl_outs = [
|
||
|
|
+ (
|
||
|
|
+ ["--gen-rewriters"],
|
||
|
|
+ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
|
||
|
|
+ ),
|
||
|
|
+ ],
|
||
|
|
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||
|
|
+ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
|
||
|
|
+ deps = [
|
||
|
|
+ ":stablehlo_ops_td_files",
|
||
|
|
+ ],
|
||
|
|
+)
|
||
|
|
+
|
||
|
|
cc_library(
|
||
|
|
name = "interpreter_ops",
|
||
|
|
srcs = [
|
||
|
|
@@ -1086,6 +1101,7 @@
|
||
|
|
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
|
||
|
|
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
|
||
|
|
"stablehlo/transforms/StablehloConvertToSignless.cpp",
|
||
|
|
+ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
|
||
|
|
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
|
||
|
|
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
|
||
|
|
"stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
|
||
|
|
@@ -1109,6 +1125,7 @@
|
||
|
|
":chlo_ops",
|
||
|
|
":chlo_rewriters_inc_gen",
|
||
|
|
":linalg_passes",
|
||
|
|
+ ":stablehlo_create_compatibility_expander_inc_gen",
|
||
|
|
":stablehlo_legalize_deprecated_ops_inc_gen",
|
||
|
|
":stablehlo_ops",
|
||
|
|
":stablehlo_ops_inc_gen",
|
||
|
|
diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
|
||
|
|
--- stablehlo/stablehlo/dialect/Version.cpp
|
||
|
|
+++ stablehlo/stablehlo/dialect/Version.cpp
|
||
|
|
@@ -82,7 +82,7 @@
|
||
|
|
case CompatibilityRequirement::WEEK_4:
|
||
|
|
return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024
|
||
|
|
case CompatibilityRequirement::WEEK_12:
|
||
|
|
- return Version(1, 0, 0); // v1.0.0 - May 14, 2024
|
||
|
|
+ return Version(1, 1, 0); // v1.1.0 - May 30, 2024
|
||
|
|
case CompatibilityRequirement::MAX:
|
||
|
|
return Version::getMinimumVersion();
|
||
|
|
}
|
||
|
|
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
|
||
|
|
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
|
||
|
|
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
|
||
|
|
@@ -0,0 +1,43 @@
|
||
|
|
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
|
||
|
|
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
|
||
|
|
+
|
||
|
|
+// -----
|
||
|
|
+
|
||
|
|
+// Real-valued tan is expanded to sine / cosine when the target predates v1.4.0
+// and is left untouched otherwise.
+// CHECK-LABEL: @tan_op_non_complex
+// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
+// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
+func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+  // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
+  %1 = stablehlo.tan %arg0 : tensor<4xf64>
+  func.return %1 : tensor<4xf64>
+}
|
||
|
|
+
|
||
|
|
+// -----
|
||
|
|
+
|
||
|
|
+// CHECK-LABEL: @tan_op_complex
+// CHECK: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
+// CHECK: %[[INPUT:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+// CHECK: %[[RE:.*]] = stablehlo.real %[[INPUT]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[SIN:.*]] = stablehlo.sine %[[RE]] : tensor<4xf64>
+// CHECK: %[[COS:.*]] = stablehlo.cosine %[[RE]] : tensor<4xf64>
+// CHECK: %[[TAN:.*]] = stablehlo.divide %[[SIN]], %[[COS]] : tensor<4xf64>
+// CHECK: %[[IM:.*]] = stablehlo.imag %[[INPUT]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[TANH:.*]] = stablehlo.tanh %[[IM]] : tensor<4xf64>
+// CHECK: %[[NUM:.*]] = stablehlo.complex %[[TAN]], %[[TANH]] : tensor<4xcomplex<f64>>
+// CHECK: %[[MUL:.*]] = stablehlo.multiply %[[TAN]], %[[TANH]] : tensor<4xf64>
+// CHECK: %[[NEG:.*]] = stablehlo.negate %[[MUL]] : tensor<4xf64>
+// CHECK: %[[DEN:.*]] = stablehlo.complex %[[ONE]], %[[NEG]] : tensor<4xcomplex<f64>>
+// CHECK: %[[RESULT:.*]] = stablehlo.divide %[[NUM]], %[[DEN]] : tensor<4xcomplex<f64>>
+// CHECK: %[[RES_RE:.*]] = stablehlo.real %[[RESULT]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[RES_IM:.*]] = stablehlo.imag %[[RESULT]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: return %[[RES_RE]], %[[RES_IM]] : tensor<4xf64>, tensor<4xf64>
+func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
+  %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+  // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
+  %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
+  %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+  %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+  func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
+}
|
||
|
|
diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
|
||
|
|
--- stablehlo/stablehlo/transforms/CMakeLists.txt
|
||
|
|
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
|
||
|
|
@@ -20,6 +20,10 @@
|
||
|
|
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
|
||
|
|
add_public_tablegen_target(ChloDecompositionPatternsIncGen)
|
||
|
|
|
||
|
|
+# Generate the compatibility-expander rewrite patterns (`--gen-rewriters`) from
+# their .td file and expose them as a public tablegen target.
+set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
|
||
|
|
+mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
|
||
|
|
+add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)
|
||
|
|
+
|
||
|
|
set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
|
||
|
|
mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
|
||
|
|
add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
|
||
|
|
@@ -27,6 +31,7 @@
|
||
|
|
set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
|
||
|
|
mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
|
||
|
|
add_public_tablegen_target(VhloToVersionPatterns)
|
||
|
|
+
|
||
|
|
|
||
|
|
add_mlir_dialect_library(StablehloPasses
|
||
|
|
PARTIAL_SOURCES_INTENDED
|
||
|
|
@@ -37,6 +42,7 @@
|
||
|
|
StablehloAggressiveSimplification.cpp
|
||
|
|
StablehloCanonicalizeDynamism.cpp
|
||
|
|
StablehloConvertToSignless.cpp
|
||
|
|
+ StablehloCreateCompatibilityExpander.cpp
|
||
|
|
StablehloLegalizeCompositeToCall.cpp
|
||
|
|
StablehloLegalizeDeprecatedOps.cpp
|
||
|
|
StablehloLegalizeQuantToMath.cpp
|
||
|
|
@@ -53,6 +59,7 @@
|
||
|
|
StablehloLegalizeDeprecatedOpsPatternsIncGen
|
||
|
|
PassesIncGen
|
||
|
|
VhloToVersionPatterns
|
||
|
|
+ StablehloCreateCompatibilityExpanderPatternsIncGen
|
||
|
|
|
||
|
|
LINK_LIBS PUBLIC
|
||
|
|
ChloOps
|
||
|
|
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
|
||
|
|
--- stablehlo/stablehlo/transforms/Passes.h
|
||
|
|
+++ stablehlo/stablehlo/transforms/Passes.h
|
||
|
|
@@ -25,6 +25,7 @@
|
||
|
|
#include "mlir/Pass/Pass.h"
|
||
|
|
#include "mlir/Support/LogicalResult.h"
|
||
|
|
#include "mlir/Transforms/DialectConversion.h"
|
||
|
|
+#include "stablehlo/dialect/Version.h"
|
||
|
|
|
||
|
|
namespace mlir {
|
||
|
|
namespace stablehlo {
|
||
|
|
@@ -96,6 +97,12 @@
|
||
|
|
void populateShapeToStablehloPatterns(MLIRContext *context,
|
||
|
|
RewritePatternSet *patterns);
|
||
|
|
|
||
|
|
+/// Populates `patterns` with the compatibility-expander rewrites needed for
|
||
|
|
+/// `targetVersion`, i.e. expansions of ops that version does not support.
|
||
|
|
+void populateStablehloCreateCompatibilityExpanderPatterns(
|
||
|
|
+ RewritePatternSet *patterns, MLIRContext *context,
|
||
|
|
+ vhlo::Version targetVersion);
|
||
|
|
+
|
||
|
|
//// Additional pass constructors ////
|
||
|
|
|
||
|
|
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
|
||
|
|
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
|
||
|
|
--- stablehlo/stablehlo/transforms/Passes.td
|
||
|
|
+++ stablehlo/stablehlo/transforms/Passes.td
|
||
|
|
@@ -292,3 +292,51 @@
|
||
|
|
"mlir::stablehlo::StablehloDialect",
|
||
|
|
];
|
||
|
|
}
|
||
|
|
+
|
||
|
|
+def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
|
||
|
|
+ let summary = "Create compatibility expander for StableHLO operations.";
|
||
|
|
+
|
||
|
|
+ let description = [{
|
||
|
|
+ StableHLO ops are updated, or new ops are introduced, in the latest versions.
|
||
|
|
+ This opt-in pass expands backward compatibility with older StableHLO
|
||
|
|
+ versions by decomposing newer StableHLO operations into equivalent
|
||
|
|
+ operations supported by those older versions.
|
||
|
|
+
|
||
|
|
+ Why is this an opt-in pass?
|
||
|
|
+
|
||
|
|
+ Occasionally, StableHLO op enhancements are used to greatly simplify the
|
||
|
|
+ handling of certain common patterns in the OpenXLA ecosystem. This
|
||
|
|
+ includes things like TanOp, which has high framework and compiler support,
|
||
|
|
+ as well as gather/scatter batching dimensions, which can be represented
|
||
|
|
+ using slices, but makes sharding much more difficult. For this category of
|
||
|
|
+ new features, we do not offer automatic downgrade, since it may throw away
|
||
|
|
+ important information used in subsequent optimizations. This pass can be
|
||
|
|
+ used to expand these ops based on a target version to maximize compatibility
|
||
|
|
+ at the expense of potentially less optimal compilation.
|
||
|
|
+
|
||
|
|
+ ```mlir
|
||
|
|
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
|
||
|
|
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
|
||
|
|
+ func.return %1 : tensor<4xf64>
|
||
|
|
+ }
|
||
|
|
+ ```
|
||
|
|
+
|
||
|
|
+ will become:
|
||
|
|
+
|
||
|
|
+ ```mlir
|
||
|
|
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
|
||
|
|
+ %0 = stablehlo.sine %arg0 : tensor<4xf64>
|
||
|
|
+ %1 = stablehlo.cosine %arg0 : tensor<4xf64>
|
||
|
|
+ %2 = stablehlo.divide %0, %1 : tensor<4xf64>
|
||
|
|
+ return %2 : tensor<4xf64>
|
||
|
|
+ }
|
||
|
|
+ ```
|
||
|
|
+ }];
|
||
|
|
+ let options = [
|
||
|
|
+ Option<"targetVersionOption", "target", "std::string", "",
|
||
|
|
+ "The target version. Must be a version of the form #.#.#.">,
|
||
|
|
+ ];
|
||
|
|
+ let dependentDialects = [
|
||
|
|
+ "mlir::stablehlo::StablehloDialect",
|
||
|
|
+ ];
|
||
|
|
+}
|
||
|
|
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
|
||
|
|
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
|
||
|
|
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
|
||
|
|
@@ -0,0 +1,128 @@
|
||
|
|
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
|
||
|
|
+Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
+you may not use this file except in compliance with the License.
|
||
|
|
+You may obtain a copy of the License at
|
||
|
|
+ http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
+Unless required by applicable law or agreed to in writing, software
|
||
|
|
+distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
+See the License for the specific language governing permissions and
|
||
|
|
+limitations under the License.
|
||
|
|
+==============================================================================*/
|
||
|
|
+
|
||
|
|
+#include <fcntl.h>
|
||
|
|
+
|
||
|
|
+#include <cassert>
|
||
|
|
+
|
||
|
|
+#include "llvm/ADT/APFloat.h"
|
||
|
|
+#include "llvm/Support/ErrorHandling.h"
|
||
|
|
+#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||
|
|
+#include "mlir/IR/BuiltinAttributes.h"
|
||
|
|
+#include "mlir/IR/PatternMatch.h"
|
||
|
|
+#include "mlir/Support/LLVM.h"
|
||
|
|
+#include "mlir/Transforms/DialectConversion.h"
|
||
|
|
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||
|
|
+#include "stablehlo/dialect/StablehloOps.h"
|
||
|
|
+#include "stablehlo/dialect/Version.h"
|
||
|
|
+#include "stablehlo/transforms/Passes.h"
|
||
|
|
+
|
||
|
|
+namespace mlir {
|
||
|
|
+namespace stablehlo {
|
||
|
|
+#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
|
||
|
|
+#include "stablehlo/transforms/Passes.h.inc"
|
||
|
|
+
|
||
|
|
+namespace {
|
||
|
|
+
|
||
|
|
+//===----------------------------------------------------------------------===//
|
||
|
|
+// Helpers.
|
||
|
|
+//===----------------------------------------------------------------------===//
|
||
|
|
+
|
||
|
|
+// Creates a `stablehlo.constant` with the same type as `val` whose elements
+// are all one.
+//
+// `val` must have a shaped type with a floating-point element type. The splat
+// value is built as a FloatAttr of that element type so that it is converted
+// to the right float semantics; handing a raw `double` to
+// DenseElementsAttr::get is only valid when the element type is 64 bits wide,
+// which broke the expansion for e.g. complex<f32> inputs.
+static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
+  auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType());
+  if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
+
+  auto floatTy = dyn_cast<mlir::FloatType>(shapedTy.getElementType());
+  if (!floatTy) llvm_unreachable("Unsupported element type, expected float.");
+
+  // A single attribute is interpreted as a splat over the whole shape.
+  mlir::Attribute one = b.getFloatAttr(floatTy, 1.0);
+  mlir::DenseElementsAttr elementsAttr =
+      mlir::DenseElementsAttr::get(shapedTy, one);
+
+  return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
+                                               elementsAttr);
+}
|
||
|
|
+
|
||
|
|
+// Parses and validates the user-specified target version.
+//
+// Returns the parsed version when `versionRef` has the form `#.#.#` and lies
+// within [minimum supported version, current version]. Any other input is
+// reported through llvm::report_fatal_error, which does not return.
+//
+// The checks deliberately do not use `assert`: asserts are compiled out under
+// NDEBUG, which previously let release builds dereference a failed parse
+// result and accept out-of-range versions.
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
+  auto failOrVersion = vhlo::Version::fromString(versionRef);
+  if (failed(failOrVersion)) {
+    if (versionRef.empty())
+      llvm::report_fatal_error(
+          "No target version specified. Target version must be of the form "
+          "`#.#.#`.",
+          /*gen_crash_diag=*/false);
+    llvm::report_fatal_error(
+        "Invalid target version argument. Target version must be of the "
+        "form `#.#.#`.",
+        /*gen_crash_diag=*/false);
+  }
+
+  vhlo::Version targetVersion = *failOrVersion;
+  if (!(vhlo::Version::getMinimumVersion() <= targetVersion))
+    llvm::report_fatal_error("target version is less than minimum supported.",
+                             /*gen_crash_diag=*/false);
+  if (!(targetVersion <= vhlo::Version::getCurrentVersion()))
+    llvm::report_fatal_error("target version is greater than current version.",
+                             /*gen_crash_diag=*/false);
+  return targetVersion;
+}
|
||
|
|
+
|
||
|
|
+//===----------------------------------------------------------------------===//
|
||
|
|
+// Pass
|
||
|
|
+//===----------------------------------------------------------------------===//
|
||
|
|
+
|
||
|
|
+// Pass that greedily applies the compatibility-expander patterns selected for
+// the version given by the `target` option.
+struct StablehloCreateCompatibilityExpanderPass
+    : public impl::StablehloCreateCompatibilityExpanderPassBase<
+          StablehloCreateCompatibilityExpanderPass> {
+  StablehloCreateCompatibilityExpanderPass()
+      : StablehloCreateCompatibilityExpanderPassBase<
+            StablehloCreateCompatibilityExpanderPass>() {}
+  StablehloCreateCompatibilityExpanderPass(
+      const StablehloCreateCompatibilityExpanderPassOptions &opts)
+      : StablehloCreateCompatibilityExpanderPassBase<
+            StablehloCreateCompatibilityExpanderPass>(opts) {}
+
+ public:
+  // Validates the `target` option, collects the patterns required for that
+  // version and stores them (frozen) together with the driver configuration.
+  LogicalResult initialize(MLIRContext *context) override {
+    RewritePatternSet expanderPatterns(context);
+    populateStablehloCreateCompatibilityExpanderPatterns(
+        &expanderPatterns, context, validateTargetVersion(targetVersionOption));
+
+    config.useTopDownTraversal = true;
+    patterns = std::move(expanderPatterns);
+    return success();
+  }
+
+  // Runs the greedy driver on the current operation and fails the pass if the
+  // driver reports failure.
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    if (succeeded(applyPatternsAndFoldGreedily(funcOp, patterns, config)))
+      return;
+
+    funcOp.emitError(
+        "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
+        << config.maxIterations << " iterations";
+    signalPassFailure();
+  }
+
+ private:
+  FrozenRewritePatternSet patterns;
+  GreedyRewriteConfig config;
+};
|
||
|
|
+
|
||
|
|
+#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
|
||
|
|
+
|
||
|
|
+} // namespace
|
||
|
|
+
|
||
|
|
+// Adds to `patterns` the expansions required when targeting `targetVersion`.
+void populateStablehloCreateCompatibilityExpanderPatterns(
+    RewritePatternSet *patterns, MLIRContext *context,
+    vhlo::Version targetVersion) {
+  // StableHLO TanOp is introduced in v1.4.0, so targets at or above that
+  // version need no expansion.
+  const vhlo::Version tanOpIntroduced(1, 4, 0);
+  if (!(targetVersion < tanOpIntroduced)) return;
+
+  patterns->add<TanOp_ComplexElementType_CompatiblityExpander,
+                TanOp_CompatiblityExpander>(context);
+}
|
||
|
|
+
|
||
|
|
+} // namespace stablehlo
|
||
|
|
+} // namespace mlir
|
||
|
|
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
|
||
|
|
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
|
||
|
|
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
|
||
|
|
@@ -0,0 +1,47 @@
|
||
|
|
+/* Copyright 2022 The StableHLO Authors.
|
||
|
|
+
|
||
|
|
+Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
+you may not use this file except in compliance with the License.
|
||
|
|
+You may obtain a copy of the License at
|
||
|
|
+
|
||
|
|
+ http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
+
|
||
|
|
+Unless required by applicable law or agreed to in writing, software
|
||
|
|
+distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
+See the License for the specific language governing permissions and
|
||
|
|
+limitations under the License.
|
||
|
|
+==============================================================================*/
|
||
|
|
+
|
||
|
|
+include "mlir/IR/OpBase.td"
|
||
|
|
+include "stablehlo/dialect/StablehloOps.td"
|
||
|
|
+
|
||
|
|
+// Type constraint satisfied when the element type of the constrained type
+// (cast to ShapedType) is a ComplexType.
+def ComplexElementType : Type<
|
||
|
|
+ CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
|
||
|
|
+ "Complex element type">;
|
||
|
|
+
|
||
|
|
+// Type constraint satisfied when the element type of the constrained type
+// (cast to ShapedType) is not a ComplexType.
+def NonComplexElementType : Type<
|
||
|
|
+ CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
|
||
|
|
+ "Non-complex element type">;
|
||
|
|
+
|
||
|
|
+// Invokes the C++ helper `createConstantWithAllOnes` with the pattern's
+// builder and location, passing the bound argument as `$0`.
+def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
|
||
|
|
+
|
||
|
|
+// Express `tan(x)` for operands with a non-complex element type as
|
||
|
|
+// sine(x) / cosine(x)
|
||
|
|
+def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input),
|
||
|
|
+ (StableHLO_DivOp
|
||
|
|
+ (StableHLO_SineOp $input),
|
||
|
|
+ (StableHLO_CosineOp $input)
|
||
|
|
+ )>;
|
||
|
|
+
|
||
|
|
+// Express `tan(a + bi)` for operands with a complex element type as
|
||
|
|
+// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
+// The inner `tan(a)` is emitted as a StableHLO_TanOp on the real part, i.e. on
+// a non-complex operand, which TanOp_CompatiblityExpander can then match.
|
||
|
|
+def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input),
|
||
|
|
+ (StableHLO_DivOp
|
||
|
|
+ (StableHLO_ComplexOp
|
||
|
|
+ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
|
||
|
|
+ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
|
||
|
|
+ (StableHLO_ComplexOp
|
||
|
|
+ (createConstantWithAllOnes $tan),
|
||
|
|
+ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
|
||
|
|
+ )>;
|
||
|
|
|