From 2a62bc5df9774810313142eb0c9390aab3cd18f8 Mon Sep 17 00:00:00 2001 From: Hugo Mano Date: Thu, 29 May 2025 08:02:32 +0200 Subject: [PATCH] Revert "Add optional allowOtherDialects field to stablehlo.serialize_portable_artifact." Commit: https://github.com/openxla/xla/commit/e7137a383809a24875a95237b1d1f6485acdf710 Issue: C does not support default arguments with ZigTranslateC --- third_party/stablehlo/temporary.patch | 151 ++++---------------------- 1 file changed, 21 insertions(+), 130 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 6e1fd159f9..d17c141b18 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,3 +1,23 @@ +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -511,12 +511,10 @@ + void CustomCallOp::getEffects( + SmallVectorImpl>& + effects) { +- // Note: `has_side_effect` "defaults" to `false` but isn't required to exist. +- // This semantic contradiction means, in practical terms, that the attribute +- // won't exist by default but should be *treated* as `false` if missing. +- // `getHasSideEffect()` abstracts this nuance away and returns `false` by +- // default, whereas `getHasSideEffectAttr()` may return a null attribute. +- if (!getHasSideEffect()) return; ++ // CustomCall has "all possible effects" unless the has_side_effect is present ++ // and set to false. ++ auto hasSideEffect = getHasSideEffectAttr(); ++ if (hasSideEffect && !hasSideEffect.getValue()) return; + effects.emplace_back(MemoryEffects::Allocate::get()); + effects.emplace_back(MemoryEffects::Free::get()); + effects.emplace_back(MemoryEffects::Write::get()); diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.h b/stablehlo/stablehlo/dialect/StablehloOps.h --- stablehlo/stablehlo/dialect/StablehloOps.h +++ stablehlo/stablehlo/dialect/StablehloOps.h @@ -82,135 +102,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d ]> { let summary = "Recv operation"; let description = [{ -diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp ---- stablehlo/stablehlo/dialect/TypeInference.cpp -+++ stablehlo/stablehlo/dialect/TypeInference.cpp -@@ -879,7 +879,8 @@ - - auto replicaIds = replicaGroups.getValues(); - -- llvm::SmallSet replicaIdsSeen; -+ // Large programs can have many replicas, use a set with efficient lookup. -+ llvm::DenseSet replicaIdsSeen; - for (int64_t replicaId : replicaIds) { - // Replica groups are stored in a 2D tensor. If the op supports non-uniform - // groups, null replica IDs are stored as -1. -@@ -1841,6 +1842,7 @@ - /*allGroupsMustHaveSameSize=*/true, - /*useGlobalDeviceIds=*/false, splitCount))) - return failure(); -+ - for (const Value& operand : operands) { - auto operandType = cast(operand.getType()); - -@@ -3562,6 +3564,19 @@ - DenseIntElementsAttr replicaGroups, - int64_t channelId, bool useGlobalDeviceIds, - ValueRange results) { -+ // all_gather_i3, all_gather_c2, all_gather_c4 -+ if (failed(verifyReplicaGroups(location, replicaGroups, -+ /*allGroupsMustHaveSameSize=*/true, -+ useGlobalDeviceIds, -+ /*expectedGroupSize=*/std::nullopt))) -+ return failure(); -+ -+ // all_gather_c5 -+ if (useGlobalDeviceIds && channelId < 0) -+ return emitOptionalError( -+ location, -+ "channel_id cannot be negative when useGlobalDeviceIds is set"); -+ - for (const auto& [operand, result] : llvm::zip(operands, results)) { - auto operandType = cast(operand.getType()); - auto resultType = cast(result.getType()); -@@ -3576,19 +3591,6 @@ - return emitOptionalError( - location, - "dimension size of operand at 'all_gather_dim' cannot be zero"); -- -- // all_gather_i3, all_gather_c2, all_gather_c4 -- if (failed(verifyReplicaGroups(location, replicaGroups, -- /*allGroupsMustHaveSameSize=*/true, -- useGlobalDeviceIds, -- /*expectedGroupSize=*/std::nullopt))) -- return failure(); -- -- // all_gather_c5 -- if (useGlobalDeviceIds && channelId < 0) -- return emitOptionalError( -- location, -- "channel_id cannot be negative when useGlobalDeviceIds is set"); - - // all_gather_c6 - if (resultType.getRank() != operandType.getRank()) -@@ -3788,7 +3790,7 @@ - "but instead it is of rank ", replicaGroupType.getRank()); - - auto replicaIds = replicaGroups.getValues(); -- llvm::SmallSet replicaIdsSeen; -+ llvm::DenseSet replicaIdsSeen; - for (int64_t replicaId : replicaIds) { - // collective_broadcast_c2 - // We only check that is is not negative, as it is impossible -diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp ---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp -+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp -@@ -78,10 +78,11 @@ - - MlirLogicalResult stablehloSerializePortableArtifactFromModule( - MlirModule moduleStr, MlirStringRef targetVersion, -- MlirStringCallback callback, void *userData) { -+ MlirStringCallback callback, void *userData, bool allowOtherDialects) { - mlir::detail::CallbackOstream stream(callback, userData); - if (failed(mlir::stablehlo::serializePortableArtifact( -- unwrap(moduleStr), unwrap(targetVersion), stream))) -+ unwrap(moduleStr), unwrap(targetVersion), stream, -+ allowOtherDialects))) - return mlirLogicalResultFailure(); - return mlirLogicalResultSuccess(); - } -diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h ---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.h -+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.h -@@ -92,7 +92,8 @@ - stablehloSerializePortableArtifactFromModule(MlirModule moduleStr, - MlirStringRef targetVersion, - MlirStringCallback callback, -- void* userData); -+ void* userData, -+ bool allowOtherDialects = false); - - // Read a StableHLO program from a portable artifact, returning the module as - // MLIR bytecode. Note, this bytecode returned is not a portable artifact, -diff --ruN a/stablehlo/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/stablehlo/integrations/python/StablehloApi.cpp ---- stablehlo/stablehlo/integrations/python/StablehloApi.cpp -+++ stablehlo/stablehlo/integrations/python/StablehloApi.cpp -@@ -102,20 +102,22 @@ - // - m.def( - "serialize_portable_artifact", -- [](MlirModule module, std::string_view target) -> nb::bytes { -+ [](MlirModule module, std::string_view target, -+ bool allowOtherDialects) -> nb::bytes { - StringWriterHelper accumulator; - if (mlirLogicalResultIsFailure( - stablehloSerializePortableArtifactFromModule( - module, toMlirStringRef(target), - accumulator.getMlirStringCallback(), -- accumulator.getUserData()))) { -+ accumulator.getUserData(), allowOtherDialects))) { - throw nb::value_error("failed to serialize module"); - } - - std::string serialized = accumulator.toString(); - return nb::bytes(serialized.data(), serialized.size()); - }, -- nb::arg("module"), nb::arg("target")); -+ nb::arg("module"), nb::arg("target"), -+ nb::arg("allow_other_dialects") = false); - - m.def( - "deserialize_portable_artifact", diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir --- stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir +++ stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir @@ -223,4 +114,4 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless. %7 = builtin.unrealized_conversion_cast %6 : memref to memref func.return %7 : memref } - + -- 2.39.5 (Apple Git-154)