diff --git a/shardy/dialect/mpmd/transforms/common/BUILD b/shardy/dialect/mpmd/transforms/common/BUILD index e7c5204a2..2efe64a20 100644 --- a/shardy/dialect/mpmd/transforms/common/BUILD +++ b/shardy/dialect/mpmd/transforms/common/BUILD @@ -31,12 +31,14 @@ cc_library( name = "passes", srcs = [ "absorb_inferred_fragments.cc", + "add_side_effect_to_avoid_cse.cc", "call_rewrites.cc", "copy_constants.cc", "fragment_dce.cc", "fragment_dedup.cc", "merge_fragments.cc", "merge_transfers.cc", + "remove_side_effect_after_cse.cc", "remove_transfer_cycles.cc", "rule_based_merge.cc", "split_bwd_fragments.cc", diff --git a/shardy/dialect/mpmd/transforms/common/add_side_effect_to_avoid_cse.cc b/shardy/dialect/mpmd/transforms/common/add_side_effect_to_avoid_cse.cc new file mode 100644 index 000000000..8a903e447 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/add_side_effect_to_avoid_cse.cc @@ -0,0 +1,57 @@ +/* Copyright 2025 The MPMD Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "shardy/dialect/mpmd/transforms/common/passes.h" +#include "shardy/dialect/mpmd/transforms/common/utils.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::mpmd { + +namespace { + +#define GEN_PASS_DEF_ADDSIDEEFFECTTOAVOIDCSEPASS +#include "shardy/dialect/mpmd/transforms/common/passes.h.inc" + +// Sets `has_side_effect = true` on every `stablehlo.custom_call` that carries +// `mhlo.no_cse`, so that CSE leaves those ops alone. Other ops are untouched. +struct AddSideEffectToAvoidCSEPass + : public impl::AddSideEffectToAvoidCSEPassBase< + AddSideEffectToAvoidCSEPass> { + using impl::AddSideEffectToAvoidCSEPassBase< + AddSideEffectToAvoidCSEPass>::AddSideEffectToAvoidCSEPassBase; + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + funcOp.walk([&](stablehlo::CustomCallOp customCallOp) { + if (customCallOp->hasAttr(kMhloNoCseAttr)) { + customCallOp.setHasSideEffectAttr( + BoolAttr::get(customCallOp.getContext(), true)); + } + }); + } +}; + +} // namespace + +std::unique_ptr<Pass> createAddSideEffectToAvoidCSEPass() { + return std::make_unique<AddSideEffectToAvoidCSEPass>(); +} + +} // namespace mlir::mpmd diff --git a/shardy/dialect/mpmd/transforms/common/passes.td b/shardy/dialect/mpmd/transforms/common/passes.td index 7316c85a0..6fd4d9429 100644 --- a/shardy/dialect/mpmd/transforms/common/passes.td +++ b/shardy/dialect/mpmd/transforms/common/passes.td @@ -15,6 +15,30 @@ limitations under the License. include "mlir/Pass/PassBase.td" +def AddSideEffectToAvoidCSEPass : + PassBase<"mpmd-add-side-effect-to-avoid-cse", "OperationPass<func::FuncOp>"> { + let summary = "Adds a side effect attribute to custom_call ops with " + "{mhlo.no_cse} to avoid CSE."; + let description = [{ + For `stablehlo.custom_call` operations that have the `{mhlo.no_cse}` + attribute, this pass adds a `{has_side_effect = true}` attribute. 
+ This prevents MLIR's CSE pass from eliminating these operations, because + CSE skips operations with side effects. + }]; +} + +def RemoveSideEffectAfterCSEPass : + PassBase<"mpmd-remove-side-effect-after-cse", "OperationPass<func::FuncOp>"> { + let summary = "Removes side effect attribute from custom_call ops with " + "{mhlo.no_cse}."; + let description = [{ + For `stablehlo.custom_call` operations that have the `{mhlo.no_cse}` + attribute, this pass removes the `has_side_effect` attribute if it exists, + whether or not it was added by `mpmd-add-side-effect-to-avoid-cse`. Run it + after CSE, once the attribute is no longer needed to block CSE. + }]; +} + // TODO: b/374694825 - This pass is not complete yet. In particular, we also // need to consider: (a) side-ways merging. We need to be careful with this as // it may have performance and jitting time implications. (b) relax the diff --git a/shardy/dialect/mpmd/transforms/common/remove_side_effect_after_cse.cc b/shardy/dialect/mpmd/transforms/common/remove_side_effect_after_cse.cc new file mode 100644 index 000000000..62b13ae5c --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/remove_side_effect_after_cse.cc @@ -0,0 +1,56 @@ +/* Copyright 2025 The MPMD Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "shardy/dialect/mpmd/transforms/common/passes.h" +#include "shardy/dialect/mpmd/transforms/common/utils.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::mpmd { + +namespace { + +#define GEN_PASS_DEF_REMOVESIDEEFFECTAFTERCSEPASS +#include "shardy/dialect/mpmd/transforms/common/passes.h.inc" + +// Drops `has_side_effect` from every `stablehlo.custom_call` that carries +// `mhlo.no_cse`, regardless of what set it. Other ops are left untouched. +struct RemoveSideEffectAfterCSEPass + : public impl::RemoveSideEffectAfterCSEPassBase< + RemoveSideEffectAfterCSEPass> { + using impl::RemoveSideEffectAfterCSEPassBase< + RemoveSideEffectAfterCSEPass>::RemoveSideEffectAfterCSEPassBase; + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + funcOp.walk([&](stablehlo::CustomCallOp customCallOp) { + if (customCallOp->hasAttr(kMhloNoCseAttr)) { + customCallOp->removeAttr(kHasSideEffectAttr); + } + }); + } +}; + +} // namespace + +std::unique_ptr<Pass> createRemoveSideEffectAfterCSEPass() { + return std::make_unique<RemoveSideEffectAfterCSEPass>(); +} + +} // namespace mlir::mpmd diff --git a/shardy/dialect/mpmd/transforms/common/test/add_side_effect_to_avoid_cse.mlir b/shardy/dialect/mpmd/transforms/common/test/add_side_effect_to_avoid_cse.mlir new file mode 100644 index 000000000..ccb7bebe5 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/add_side_effect_to_avoid_cse.mlir @@ -0,0 +1,30 @@ +// RUN: mpmd_opt %s -mpmd-add-side-effect-to-avoid-cse | FileCheck %s + +// CHECK-LABEL: func @custom_call_with_no_cse_should_add_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_with_no_cse_should_add_side_effect(%arg0: tensor) -> tensor { + // CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-SAME: has_side_effect = true + // CHECK-SAME: 
mhlo.no_cse + // CHECK-SAME: : (tensor) -> tensor + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @custom_call_without_no_cse_should_not_add_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_without_no_cse_should_not_add_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: has_side_effect + // CHECK: stablehlo.custom_call @Sharding(%arg0) : (tensor) -> tensor + %0 = stablehlo.custom_call @Sharding(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @other_op_with_no_cse_should_not_add_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @other_op_with_no_cse_should_not_add_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: has_side_effect + // CHECK: stablehlo.add %arg0, %arg0 {mhlo.no_cse} : tensor + %0 = stablehlo.add %arg0, %arg0 {mhlo.no_cse} : tensor + func.return %0 : tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/test/avoid_cse_on_custom_calls_marked_with_no_cse.mlir b/shardy/dialect/mpmd/transforms/common/test/avoid_cse_on_custom_calls_marked_with_no_cse.mlir new file mode 100644 index 000000000..63f280c67 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/avoid_cse_on_custom_calls_marked_with_no_cse.mlir @@ -0,0 +1,16 @@ +// RUN: mpmd_opt %s -mpmd-add-side-effect-to-avoid-cse -cse -mpmd-remove-side-effect-after-cse | FileCheck %s + +// CHECK-LABEL: func @duplicate_custom_call_with_no_cse_should_not_be_csed +// CHECK-SAME: (%arg0: tensor) -> (tensor, tensor) +func.func @duplicate_custom_call_with_no_cse_should_not_be_csed(%arg0: tensor) -> (tensor, tensor) { + // CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-NOT: has_side_effect + // CHECK-SAME: mhlo.no_cse + // CHECK: %[[RES1:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-NOT: has_side_effect + // CHECK-SAME: mhlo.no_cse + // CHECK: return %[[RES0]], %[[RES1]] + %0 = stablehlo.custom_call 
@Sharding(%arg0) {mhlo.no_cse} : (tensor) -> tensor + %1 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor) -> tensor + func.return %0, %1 : tensor, tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/test/remove_side_effect_after_cse.mlir b/shardy/dialect/mpmd/transforms/common/test/remove_side_effect_after_cse.mlir new file mode 100644 index 000000000..7fc14072e --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/remove_side_effect_after_cse.mlir @@ -0,0 +1,28 @@ +// RUN: mpmd_opt %s -mpmd-remove-side-effect-after-cse | FileCheck %s + +// CHECK-LABEL: func @custom_call_with_no_cse_should_remove_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_with_no_cse_should_remove_side_effect(%arg0: tensor) -> tensor { + // CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-NOT: has_side_effect + // CHECK-SAME: mhlo.no_cse + // CHECK-SAME: : (tensor) -> tensor + %0 = stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true,mhlo.no_cse} : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @custom_call_without_no_cse_should_do_nothing +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_without_no_cse_should_do_nothing(%arg0: tensor) -> tensor { + // CHECK: stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true} + %0 = stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true}: (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @other_op_with_no_cse_should_do_nothing +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @other_op_with_no_cse_should_do_nothing(%arg0: tensor) -> tensor { + // CHECK: stablehlo.add %arg0, %arg0 {has_side_effect = true, mhlo.no_cse} : tensor + %0 = stablehlo.add %arg0, %arg0 {has_side_effect = true, mhlo.no_cse} : tensor + func.return %0 : tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/utils.h b/shardy/dialect/mpmd/transforms/common/utils.h index eaa85200f..baac976d0 100644 
--- a/shardy/dialect/mpmd/transforms/common/utils.h +++ b/shardy/dialect/mpmd/transforms/common/utils.h @@ -34,6 +34,11 @@ limitations under the License. namespace mlir::mpmd { +// Name of the attribute that marks an op as one that must not be CSE'd. +inline constexpr StringRef kMhloNoCseAttr = "mhlo.no_cse"; +// Name of the attribute that indicates that an op has side effects. +inline constexpr StringRef kHasSideEffectAttr = "has_side_effect"; + // The name of the attribute that keeps track of how many times a loop has been // unrolled. constexpr StringRef kUnrollCounterAttrName = "unroll_counter";