Radix/third_party/modules/xla/20250527.0-cb67f2f/patches/0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch

187 lines
8.8 KiB
Diff

From 2a62bc5df9774810313142eb0c9390aab3cd18f8 Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Thu, 29 May 2025 08:02:32 +0200
Subject: [PATCH] Revert "Add optional allowOtherDialects field to
stablehlo.serialize_portable_artifact."
Commit: https://github.com/openxla/xla/commit/e7137a383809a24875a95237b1d1f6485acdf710
Issue: C does not support default arguments, so the `allowOtherDialects = false` parameter in the C API header cannot be handled by ZigTranslateC
---
third_party/stablehlo/temporary.patch | 151 ++++----------------------
1 file changed, 21 insertions(+), 130 deletions(-)
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 6e1fd159f9..d17c141b18 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,3 +1,23 @@
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
+--- stablehlo/stablehlo/dialect/StablehloOps.cpp
++++ stablehlo/stablehlo/dialect/StablehloOps.cpp
+@@ -511,12 +511,10 @@
+ void CustomCallOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
+ effects) {
+- // Note: `has_side_effect` "defaults" to `false` but isn't required to exist.
+- // This semantic contradiction means, in practical terms, that the attribute
+- // won't exist by default but should be *treated* as `false` if missing.
+- // `getHasSideEffect()` abstracts this nuance away and returns `false` by
+- // default, whereas `getHasSideEffectAttr()` may return a null attribute.
+- if (!getHasSideEffect()) return;
++ // CustomCall has "all possible effects" unless the has_side_effect is present
++ // and set to false.
++ auto hasSideEffect = getHasSideEffectAttr();
++ if (hasSideEffect && !hasSideEffect.getValue()) return;
+ effects.emplace_back(MemoryEffects::Allocate::get());
+ effects.emplace_back(MemoryEffects::Free::get());
+ effects.emplace_back(MemoryEffects::Write::get());
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.h b/stablehlo/stablehlo/dialect/StablehloOps.h
--- stablehlo/stablehlo/dialect/StablehloOps.h
+++ stablehlo/stablehlo/dialect/StablehloOps.h
@@ -82,135 +102,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d
]> {
let summary = "Recv operation";
let description = [{
-diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
---- stablehlo/stablehlo/dialect/TypeInference.cpp
-+++ stablehlo/stablehlo/dialect/TypeInference.cpp
-@@ -879,7 +879,8 @@
-
- auto replicaIds = replicaGroups.getValues<int64_t>();
-
-- llvm::SmallSet<int64_t, 8> replicaIdsSeen;
-+ // Large programs can have many replicas, use a set with efficient lookup.
-+ llvm::DenseSet<int64_t> replicaIdsSeen;
- for (int64_t replicaId : replicaIds) {
- // Replica groups are stored in a 2D tensor. If the op supports non-uniform
- // groups, null replica IDs are stored as -1.
-@@ -1841,6 +1842,7 @@
- /*allGroupsMustHaveSameSize=*/true,
- /*useGlobalDeviceIds=*/false, splitCount)))
- return failure();
-+
- for (const Value& operand : operands) {
- auto operandType = cast<RankedTensorType>(operand.getType());
-
-@@ -3562,6 +3564,19 @@
- DenseIntElementsAttr replicaGroups,
- int64_t channelId, bool useGlobalDeviceIds,
- ValueRange results) {
-+ // all_gather_i3, all_gather_c2, all_gather_c4
-+ if (failed(verifyReplicaGroups(location, replicaGroups,
-+ /*allGroupsMustHaveSameSize=*/true,
-+ useGlobalDeviceIds,
-+ /*expectedGroupSize=*/std::nullopt)))
-+ return failure();
-+
-+ // all_gather_c5
-+ if (useGlobalDeviceIds && channelId < 0)
-+ return emitOptionalError(
-+ location,
-+ "channel_id cannot be negative when useGlobalDeviceIds is set");
-+
- for (const auto& [operand, result] : llvm::zip(operands, results)) {
- auto operandType = cast<RankedTensorType>(operand.getType());
- auto resultType = cast<RankedTensorType>(result.getType());
-@@ -3576,19 +3591,6 @@
- return emitOptionalError(
- location,
- "dimension size of operand at 'all_gather_dim' cannot be zero");
--
-- // all_gather_i3, all_gather_c2, all_gather_c4
-- if (failed(verifyReplicaGroups(location, replicaGroups,
-- /*allGroupsMustHaveSameSize=*/true,
-- useGlobalDeviceIds,
-- /*expectedGroupSize=*/std::nullopt)))
-- return failure();
--
-- // all_gather_c5
-- if (useGlobalDeviceIds && channelId < 0)
-- return emitOptionalError(
-- location,
-- "channel_id cannot be negative when useGlobalDeviceIds is set");
-
- // all_gather_c6
- if (resultType.getRank() != operandType.getRank())
-@@ -3788,7 +3790,7 @@
- "but instead it is of rank ", replicaGroupType.getRank());
-
- auto replicaIds = replicaGroups.getValues<int64_t>();
-- llvm::SmallSet<int64_t, 8> replicaIdsSeen;
-+ llvm::DenseSet<int64_t> replicaIdsSeen;
- for (int64_t replicaId : replicaIds) {
- // collective_broadcast_c2
- // We only check that is is not negative, as it is impossible
-diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
-+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
-@@ -78,10 +78,11 @@
-
- MlirLogicalResult stablehloSerializePortableArtifactFromModule(
- MlirModule moduleStr, MlirStringRef targetVersion,
-- MlirStringCallback callback, void *userData) {
-+ MlirStringCallback callback, void *userData, bool allowOtherDialects) {
- mlir::detail::CallbackOstream stream(callback, userData);
- if (failed(mlir::stablehlo::serializePortableArtifact(
-- unwrap(moduleStr), unwrap(targetVersion), stream)))
-+ unwrap(moduleStr), unwrap(targetVersion), stream,
-+ allowOtherDialects)))
- return mlirLogicalResultFailure();
- return mlirLogicalResultSuccess();
- }
-diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
-+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
-@@ -92,7 +92,8 @@
- stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
- MlirStringRef targetVersion,
- MlirStringCallback callback,
-- void* userData);
-+ void* userData,
-+ bool allowOtherDialects = false);
-
- // Read a StableHLO program from a portable artifact, returning the module as
- // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
-diff --ruN a/stablehlo/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/stablehlo/integrations/python/StablehloApi.cpp
---- stablehlo/stablehlo/integrations/python/StablehloApi.cpp
-+++ stablehlo/stablehlo/integrations/python/StablehloApi.cpp
-@@ -102,20 +102,22 @@
- //
- m.def(
- "serialize_portable_artifact",
-- [](MlirModule module, std::string_view target) -> nb::bytes {
-+ [](MlirModule module, std::string_view target,
-+ bool allowOtherDialects) -> nb::bytes {
- StringWriterHelper accumulator;
- if (mlirLogicalResultIsFailure(
- stablehloSerializePortableArtifactFromModule(
- module, toMlirStringRef(target),
- accumulator.getMlirStringCallback(),
-- accumulator.getUserData()))) {
-+ accumulator.getUserData(), allowOtherDialects))) {
- throw nb::value_error("failed to serialize module");
- }
-
- std::string serialized = accumulator.toString();
- return nb::bytes(serialized.data(), serialized.size());
- },
-- nb::arg("module"), nb::arg("target"));
-+ nb::arg("module"), nb::arg("target"),
-+ nb::arg("allow_other_dialects") = false);
-
- m.def(
- "deserialize_portable_artifact",
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
@@ -223,4 +114,4 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.
%7 = builtin.unrealized_conversion_cast %6 : memref<i16> to memref<ui16>
func.return %7 : memref<ui16>
}
-
+
--
2.39.5 (Apple Git-154)