187 lines
8.8 KiB
Diff
187 lines
8.8 KiB
Diff
|
|
From 2a62bc5df9774810313142eb0c9390aab3cd18f8 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Hugo Mano <hugo@zml.ai>
|
||
|
|
Date: Thu, 29 May 2025 08:02:32 +0200
|
||
|
|
Subject: [PATCH] Revert "Add optional allowOtherDialects field to
|
||
|
|
stablehlo.serialize_portable_artifact."
|
||
|
|
|
||
|
|
Commit: https://github.com/openxla/xla/commit/e7137a383809a24875a95237b1d1f6485acdf710
|
||
|
|
|
||
|
|
Issue: C has no default arguments, so the `bool allowOtherDialects = false` parameter added to StablehloDialectApi.h cannot be imported through Zig's translate-c
|
||
|
|
---
|
||
|
|
third_party/stablehlo/temporary.patch | 151 ++++----------------------
|
||
|
|
1 file changed, 21 insertions(+), 130 deletions(-)
|
||
|
|
|
||
|
|
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
|
||
|
|
index 6e1fd159f9..d17c141b18 100755
|
||
|
|
--- a/third_party/stablehlo/temporary.patch
|
||
|
|
+++ b/third_party/stablehlo/temporary.patch
|
||
|
|
@@ -1,3 +1,23 @@
|
||
|
|
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
|
||
|
|
+--- stablehlo/stablehlo/dialect/StablehloOps.cpp
|
||
|
|
++++ stablehlo/stablehlo/dialect/StablehloOps.cpp
|
||
|
|
+@@ -511,12 +511,10 @@
|
||
|
|
+ void CustomCallOp::getEffects(
|
||
|
|
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
|
||
|
|
+ effects) {
|
||
|
|
+- // Note: `has_side_effect` "defaults" to `false` but isn't required to exist.
|
||
|
|
+- // This semantic contradiction means, in practical terms, that the attribute
|
||
|
|
+- // won't exist by default but should be *treated* as `false` if missing.
|
||
|
|
+- // `getHasSideEffect()` abstracts this nuance away and returns `false` by
|
||
|
|
+- // default, whereas `getHasSideEffectAttr()` may return a null attribute.
|
||
|
|
+- if (!getHasSideEffect()) return;
|
||
|
|
++ // CustomCall has "all possible effects" unless the has_side_effect is present
|
||
|
|
++ // and set to false.
|
||
|
|
++ auto hasSideEffect = getHasSideEffectAttr();
|
||
|
|
++ if (hasSideEffect && !hasSideEffect.getValue()) return;
|
||
|
|
+ effects.emplace_back(MemoryEffects::Allocate::get());
|
||
|
|
+ effects.emplace_back(MemoryEffects::Free::get());
|
||
|
|
+ effects.emplace_back(MemoryEffects::Write::get());
|
||
|
|
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.h b/stablehlo/stablehlo/dialect/StablehloOps.h
|
||
|
|
--- stablehlo/stablehlo/dialect/StablehloOps.h
|
||
|
|
+++ stablehlo/stablehlo/dialect/StablehloOps.h
|
||
|
|
@@ -82,135 +102,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d
|
||
|
|
]> {
|
||
|
|
let summary = "Recv operation";
|
||
|
|
let description = [{
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
|
||
|
|
---- stablehlo/stablehlo/dialect/TypeInference.cpp
|
||
|
|
-+++ stablehlo/stablehlo/dialect/TypeInference.cpp
|
||
|
|
-@@ -879,7 +879,8 @@
|
||
|
|
-
|
||
|
|
- auto replicaIds = replicaGroups.getValues<int64_t>();
|
||
|
|
-
|
||
|
|
-- llvm::SmallSet<int64_t, 8> replicaIdsSeen;
|
||
|
|
-+ // Large programs can have many replicas, use a set with efficient lookup.
|
||
|
|
-+ llvm::DenseSet<int64_t> replicaIdsSeen;
|
||
|
|
- for (int64_t replicaId : replicaIds) {
|
||
|
|
- // Replica groups are stored in a 2D tensor. If the op supports non-uniform
|
||
|
|
- // groups, null replica IDs are stored as -1.
|
||
|
|
-@@ -1841,6 +1842,7 @@
|
||
|
|
- /*allGroupsMustHaveSameSize=*/true,
|
||
|
|
- /*useGlobalDeviceIds=*/false, splitCount)))
|
||
|
|
- return failure();
|
||
|
|
-+
|
||
|
|
- for (const Value& operand : operands) {
|
||
|
|
- auto operandType = cast<RankedTensorType>(operand.getType());
|
||
|
|
-
|
||
|
|
-@@ -3562,6 +3564,19 @@
|
||
|
|
- DenseIntElementsAttr replicaGroups,
|
||
|
|
- int64_t channelId, bool useGlobalDeviceIds,
|
||
|
|
- ValueRange results) {
|
||
|
|
-+ // all_gather_i3, all_gather_c2, all_gather_c4
|
||
|
|
-+ if (failed(verifyReplicaGroups(location, replicaGroups,
|
||
|
|
-+ /*allGroupsMustHaveSameSize=*/true,
|
||
|
|
-+ useGlobalDeviceIds,
|
||
|
|
-+ /*expectedGroupSize=*/std::nullopt)))
|
||
|
|
-+ return failure();
|
||
|
|
-+
|
||
|
|
-+ // all_gather_c5
|
||
|
|
-+ if (useGlobalDeviceIds && channelId < 0)
|
||
|
|
-+ return emitOptionalError(
|
||
|
|
-+ location,
|
||
|
|
-+ "channel_id cannot be negative when useGlobalDeviceIds is set");
|
||
|
|
-+
|
||
|
|
- for (const auto& [operand, result] : llvm::zip(operands, results)) {
|
||
|
|
- auto operandType = cast<RankedTensorType>(operand.getType());
|
||
|
|
- auto resultType = cast<RankedTensorType>(result.getType());
|
||
|
|
-@@ -3576,19 +3591,6 @@
|
||
|
|
- return emitOptionalError(
|
||
|
|
- location,
|
||
|
|
- "dimension size of operand at 'all_gather_dim' cannot be zero");
|
||
|
|
--
|
||
|
|
-- // all_gather_i3, all_gather_c2, all_gather_c4
|
||
|
|
-- if (failed(verifyReplicaGroups(location, replicaGroups,
|
||
|
|
-- /*allGroupsMustHaveSameSize=*/true,
|
||
|
|
-- useGlobalDeviceIds,
|
||
|
|
-- /*expectedGroupSize=*/std::nullopt)))
|
||
|
|
-- return failure();
|
||
|
|
--
|
||
|
|
-- // all_gather_c5
|
||
|
|
-- if (useGlobalDeviceIds && channelId < 0)
|
||
|
|
-- return emitOptionalError(
|
||
|
|
-- location,
|
||
|
|
-- "channel_id cannot be negative when useGlobalDeviceIds is set");
|
||
|
|
-
|
||
|
|
- // all_gather_c6
|
||
|
|
- if (resultType.getRank() != operandType.getRank())
|
||
|
|
-@@ -3788,7 +3790,7 @@
|
||
|
|
- "but instead it is of rank ", replicaGroupType.getRank());
|
||
|
|
-
|
||
|
|
- auto replicaIds = replicaGroups.getValues<int64_t>();
|
||
|
|
-- llvm::SmallSet<int64_t, 8> replicaIdsSeen;
|
||
|
|
-+ llvm::DenseSet<int64_t> replicaIdsSeen;
|
||
|
|
- for (int64_t replicaId : replicaIds) {
|
||
|
|
- // collective_broadcast_c2
|
||
|
|
- // We only check that is is not negative, as it is impossible
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
|
||
|
|
---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
|
||
|
|
-+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
|
||
|
|
-@@ -78,10 +78,11 @@
|
||
|
|
-
|
||
|
|
- MlirLogicalResult stablehloSerializePortableArtifactFromModule(
|
||
|
|
- MlirModule moduleStr, MlirStringRef targetVersion,
|
||
|
|
-- MlirStringCallback callback, void *userData) {
|
||
|
|
-+ MlirStringCallback callback, void *userData, bool allowOtherDialects) {
|
||
|
|
- mlir::detail::CallbackOstream stream(callback, userData);
|
||
|
|
- if (failed(mlir::stablehlo::serializePortableArtifact(
|
||
|
|
-- unwrap(moduleStr), unwrap(targetVersion), stream)))
|
||
|
|
-+ unwrap(moduleStr), unwrap(targetVersion), stream,
|
||
|
|
-+ allowOtherDialects)))
|
||
|
|
- return mlirLogicalResultFailure();
|
||
|
|
- return mlirLogicalResultSuccess();
|
||
|
|
- }
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
|
||
|
|
---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
|
||
|
|
-+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
|
||
|
|
-@@ -92,7 +92,8 @@
|
||
|
|
- stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
|
||
|
|
- MlirStringRef targetVersion,
|
||
|
|
- MlirStringCallback callback,
|
||
|
|
-- void* userData);
|
||
|
|
-+ void* userData,
|
||
|
|
-+ bool allowOtherDialects = false);
|
||
|
|
-
|
||
|
|
- // Read a StableHLO program from a portable artifact, returning the module as
|
||
|
|
- // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
|
||
|
|
-diff --ruN a/stablehlo/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/stablehlo/integrations/python/StablehloApi.cpp
|
||
|
|
---- stablehlo/stablehlo/integrations/python/StablehloApi.cpp
|
||
|
|
-+++ stablehlo/stablehlo/integrations/python/StablehloApi.cpp
|
||
|
|
-@@ -102,20 +102,22 @@
|
||
|
|
- //
|
||
|
|
- m.def(
|
||
|
|
- "serialize_portable_artifact",
|
||
|
|
-- [](MlirModule module, std::string_view target) -> nb::bytes {
|
||
|
|
-+ [](MlirModule module, std::string_view target,
|
||
|
|
-+ bool allowOtherDialects) -> nb::bytes {
|
||
|
|
- StringWriterHelper accumulator;
|
||
|
|
- if (mlirLogicalResultIsFailure(
|
||
|
|
- stablehloSerializePortableArtifactFromModule(
|
||
|
|
- module, toMlirStringRef(target),
|
||
|
|
- accumulator.getMlirStringCallback(),
|
||
|
|
-- accumulator.getUserData()))) {
|
||
|
|
-+ accumulator.getUserData(), allowOtherDialects))) {
|
||
|
|
- throw nb::value_error("failed to serialize module");
|
||
|
|
- }
|
||
|
|
-
|
||
|
|
- std::string serialized = accumulator.toString();
|
||
|
|
- return nb::bytes(serialized.data(), serialized.size());
|
||
|
|
- },
|
||
|
|
-- nb::arg("module"), nb::arg("target"));
|
||
|
|
-+ nb::arg("module"), nb::arg("target"),
|
||
|
|
-+ nb::arg("allow_other_dialects") = false);
|
||
|
|
-
|
||
|
|
- m.def(
|
||
|
|
- "deserialize_portable_artifact",
|
||
|
|
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
|
||
|
|
--- stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
|
||
|
|
+++ stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
|
||
|
|
@@ -223,4 +114,4 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.
|
||
|
|
%7 = builtin.unrealized_conversion_cast %6 : memref<i16> to memref<ui16>
|
||
|
|
func.return %7 : memref<ui16>
|
||
|
|
}
|
||
|
|
-
|
||
|
|
+
|
||
|
|
--
|
||
|
|
2.39.5 (Apple Git-154)
|
||
|
|
|