diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
--- stablehlo/BUILD.bazel
+++ stablehlo/BUILD.bazel
@@ -340,6 +340,21 @@
     ],
 )
 
+gentbl_cc_library(
+    name = "stablehlo_create_compatibility_expander_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-rewriters"],
+            "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
+    deps = [
+        ":stablehlo_ops_td_files",
+    ],
+)
+
 cc_library(
     name = "interpreter_ops",
     srcs = [
@@ -1086,6 +1101,7 @@
         "stablehlo/transforms/StablehloAggressiveSimplification.cpp",
         "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
         "stablehlo/transforms/StablehloConvertToSignless.cpp",
+        "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
         "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
         "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
         "stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
@@ -1109,6 +1125,7 @@
         ":chlo_ops",
         ":chlo_rewriters_inc_gen",
         ":linalg_passes",
+        ":stablehlo_create_compatibility_expander_inc_gen",
         ":stablehlo_legalize_deprecated_ops_inc_gen",
         ":stablehlo_ops",
         ":stablehlo_ops_inc_gen",
diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
--- stablehlo/stablehlo/dialect/Version.cpp
+++ stablehlo/stablehlo/dialect/Version.cpp
@@ -82,7 +82,7 @@
     case CompatibilityRequirement::WEEK_4:
       return Version(1, 3, 0);  // v1.3.0 - Jul 15, 2024
     case CompatibilityRequirement::WEEK_12:
-      return Version(1, 0, 0);  // v1.0.0 - May 14, 2024
+      return Version(1, 1, 0);  // v1.1.0 - May 30, 2024
     case CompatibilityRequirement::MAX:
       return Version::getMinimumVersion();
   }
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
@@ -0,0 +1,43 @@
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
+
+// -----
+
+// CHECK-LABEL: @tan_op_non_complex
+// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
+// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
+func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+  // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
+  %1 = stablehlo.tan %arg0 : tensor<4xf64>
+  func.return %1 : tensor<4xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @tan_op_complex
+// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
+// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64>
+// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64>
+// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64>
+// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64>
+// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>>
+// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64>
+// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64>
+// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>>
+// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>>
+// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64>
+func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
+  %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+  // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
+  %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
+  %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+  %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+  func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
+}
diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
--- stablehlo/stablehlo/transforms/CMakeLists.txt
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
@@ -20,6 +20,10 @@
 mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
 add_public_tablegen_target(ChloDecompositionPatternsIncGen)
 
+set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
+mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
+add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)
+
 set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
 mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
 add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
@@ -27,6 +31,7 @@
 set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
 mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
 add_public_tablegen_target(VhloToVersionPatterns)
+
 
 add_mlir_dialect_library(StablehloPasses
   PARTIAL_SOURCES_INTENDED
@@ -37,6 +42,7 @@
   StablehloAggressiveSimplification.cpp
   StablehloCanonicalizeDynamism.cpp
   StablehloConvertToSignless.cpp
+  StablehloCreateCompatibilityExpander.cpp
   StablehloLegalizeCompositeToCall.cpp
   StablehloLegalizeDeprecatedOps.cpp
   StablehloLegalizeQuantToMath.cpp
@@ -53,6 +59,7 @@
   StablehloLegalizeDeprecatedOpsPatternsIncGen
   PassesIncGen
   VhloToVersionPatterns
+  StablehloCreateCompatibilityExpanderPatternsIncGen
 
   LINK_LIBS PUBLIC
   ChloOps
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
--- stablehlo/stablehlo/transforms/Passes.h
+++ stablehlo/stablehlo/transforms/Passes.h
@@ -25,6 +25,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/Version.h"
 
 namespace mlir {
 namespace stablehlo {
@@ -96,6 +97,12 @@
 void populateShapeToStablehloPatterns(MLIRContext *context,
                                       RewritePatternSet *patterns);
 
+/// Collection of patterns to create compatibility expander for StableHLO
+/// operations.
+void populateStablehloCreateCompatibilityExpanderPatterns(
+    RewritePatternSet *patterns, MLIRContext *context,
+    vhlo::Version targetVersion);
+
 //// Additional pass constructors ////
 
 std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
--- stablehlo/stablehlo/transforms/Passes.td
+++ stablehlo/stablehlo/transforms/Passes.td
@@ -292,3 +292,51 @@
     "mlir::stablehlo::StablehloDialect",
   ];
 }
+
+def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
+  let summary = "Create compatibility expander for StableHLO operations.";
+
+  let description = [{
+    StableHLO ops get updated and new ops are introduced in newer versions.
+    This opt-in pass expands backward compatibility with older StableHLO
+    versions by decomposing newer StableHLO operations into equivalent
+    operations supported by those older versions.
+
+    Why is this an opt-in pass?
+
+    Occasionally, StableHLO op enhancements are used to greatly simplify the
+    handling of certain common patterns in the OpenXLA ecosystem. This
+    includes things like TanOp, which has high framework and compiler support,
+    as well as gather/scatter batching dimensions, which can be represented
+    using slices, but makes sharding much more difficult. For this category of
+    new features, we do not offer automatic downgrade, since it may throw away
+    important information used in subsequent optimizations. This pass can be
+    used to expand these ops based on a target version to maximize compatibility
+    at the expense of potentially less optimal compilation.
+
+    ```mlir
+    func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+      %1 = stablehlo.tan %arg0 : tensor<4xf64>
+      func.return %1 : tensor<4xf64>
+    }
+    ```
+
+    will become:
+
+    ```mlir
+    func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+      %0 = stablehlo.sine %arg0 : tensor<4xf64>
+      %1 = stablehlo.cosine %arg0 : tensor<4xf64>
+      %2 = stablehlo.divide %0, %1 : tensor<4xf64>
+      return %2 : tensor<4xf64>
+    }
+    ```
+  }];
+  let options = [
+    Option<"targetVersionOption", "target", "std::string", "",
+           "The target version. Must be a version of the form #.#.#.">,
+  ];
+  let dependentDialects = [
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
@@ -0,0 +1,146 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fcntl.h>
+
+#include <cassert>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/Version.h"
+#include "stablehlo/transforms/Passes.h"
+
+namespace mlir {
+namespace stablehlo {
+#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
+#include "stablehlo/transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helpers.
+//===----------------------------------------------------------------------===//
+
+// Creates a constant of the same type as `val` with every element set to one.
+// `val` must be a shaped value with a floating-point element type.
+static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
+  auto shapedTy = dyn_cast<ShapedType>(val.getType());
+  if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
+
+  auto floatTy = dyn_cast<FloatType>(shapedTy.getElementType());
+  if (!floatTy) llvm_unreachable("Unsupported element type, expected float.");
+
+  // Materialize 1.0 in the element type's own float semantics. Handing a C++
+  // `double` to DenseElementsAttr::get is only valid for f64 element types.
+  llvm::APFloat one(1.0);
+  bool losesInfo = false;
+  (void)one.convert(floatTy.getFloatSemantics(),
+                    llvm::APFloat::rmNearestTiesToEven, &losesInfo);
+  mlir::DenseElementsAttr elementsAttr = mlir::DenseElementsAttr::get(
+      shapedTy, llvm::ArrayRef<llvm::APFloat>(one));
+
+  return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
+                                               elementsAttr);
+}
+
+// Parses and validates the user-specified target version.
+//
+// The target comes from a pass option, i.e. user input, so every check is a
+// hard error instead of an `assert`: asserts vanish in NDEBUG builds, where a
+// failed parse would otherwise be dereferenced below.
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
+  auto failOrVersion = vhlo::Version::fromString(versionRef);
+  if (failed(failOrVersion)) {
+    llvm::report_fatal_error(
+        versionRef.empty()
+            ? "No target version specified. Target version must be of the "
+              "form `#.#.#`."
+            : "Invalid target version argument. Target version must be of "
+              "the form `#.#.#`.",
+        /*gen_crash_diag=*/false);
+  }
+  vhlo::Version targetVersion = *failOrVersion;
+  if (targetVersion < vhlo::Version::getMinimumVersion())
+    llvm::report_fatal_error("target version is less than minimum supported.",
+                             /*gen_crash_diag=*/false);
+  if (vhlo::Version::getCurrentVersion() < targetVersion)
+    llvm::report_fatal_error("target version is greater than current version.",
+                             /*gen_crash_diag=*/false);
+  return targetVersion;
+}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+struct StablehloCreateCompatibilityExpanderPass
+    : public impl::StablehloCreateCompatibilityExpanderPassBase<
+          StablehloCreateCompatibilityExpanderPass> {
+  StablehloCreateCompatibilityExpanderPass()
+      : StablehloCreateCompatibilityExpanderPassBase<
+            StablehloCreateCompatibilityExpanderPass>() {}
+  StablehloCreateCompatibilityExpanderPass(
+      const StablehloCreateCompatibilityExpanderPassOptions &opts)
+      : StablehloCreateCompatibilityExpanderPassBase<
+            StablehloCreateCompatibilityExpanderPass>(opts) {}
+
+ public:
+  LogicalResult initialize(MLIRContext *context) override {
+    auto targetVersion = validateTargetVersion(targetVersionOption);
+
+    config.useTopDownTraversal = true;
+    RewritePatternSet patterns_(context);
+    populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context,
+                                                         targetVersion);
+    patterns = std::move(patterns_);
+    return success();
+  }
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
+      func.emitError(
+          "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
+          << config.maxIterations << " iterations";
+      signalPassFailure();
+    }
+  }
+
+ private:
+  FrozenRewritePatternSet patterns;
+  GreedyRewriteConfig config;
+};
+
+#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
+
+}  // namespace
+
+void populateStablehloCreateCompatibilityExpanderPatterns(
+    RewritePatternSet *patterns, MLIRContext *context,
+    vhlo::Version targetVersion) {
+  // StableHLO TanOp is introduced in v1.4.0.
+  if (targetVersion < vhlo::Version(1, 4, 0)) {
+    patterns->add<TanOp_CompatiblityExpander>(context);
+    patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context);
+  }
+}
+
+}  // namespace stablehlo
+}  // namespace mlir
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
@@ -0,0 +1,47 @@
+/* Copyright 2022 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/IR/OpBase.td"
+include "stablehlo/dialect/StablehloOps.td"
+
+def ComplexElementType : Type<
+  CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
+  "Complex element type">;
+
+def NonComplexElementType : Type<
+  CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
+  "Non-complex element type">;
+
+def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
+
+// Express `tan` as
+//   sine(x) / cosine(x)
+def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input),
+  (StableHLO_DivOp
+    (StableHLO_SineOp $input),
+    (StableHLO_CosineOp $input)
+  )>;
+
+// Express `tan(a + bi)` as
+//   (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
+def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input),
+  (StableHLO_DivOp
+    (StableHLO_ComplexOp
+      (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
+      (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
+    (StableHLO_ComplexOp
+      (createConstantWithAllOnes $tan),
+      (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
+  )>;