Add SDY sharding rule for stablehlo.batch_norm_inference #844

@sshonTT

Description

There is no sharding rule for stablehlo.batch_norm_inference in SDY. This gap results in shape and sharding inference failures when batch norm appears under tensor parallelism.

As I'm currently unable to contribute directly, I've prepared a change to op_sharding_rule_registry.cc that addresses this. Please review it and, if it looks acceptable, open a PR on my behalf.

diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
index 47ce53e..63b2424 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -303,6 +303,45 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
        }
        return builder.build();
      })
+      .Case<stablehlo::BatchNormInferenceOp>(
+          [conservativePropagation](stablehlo::BatchNormInferenceOp bn) {
+            auto inTy =
+                llvm::cast<mlir::RankedTensorType>(bn.getOperand().getType());
+            auto outTy =
+                llvm::cast<mlir::RankedTensorType>(bn.getResult().getType());
+
+            OpShardingRuleBuilder builder(bn);
+
+            // Operands are: input, scale, offset, mean, variance.
+            const int64_t numOperands = static_cast<int64_t>(bn->getNumOperands());
+            llvm::SmallVector<int64_t> opDims(numOperands, kNullDim);
+
+            // One pass-through factor per input dimension, mapped to the
+            // matching result dimension; the 1-D parameters do not take part.
+            for (auto [dU, dimSize] : llvm::enumerate(inTy.getShape())) {
+              const int64_t d = static_cast<int64_t>(dU);
+              std::fill(opDims.begin(), opDims.end(), kNullDim);
+              opDims[0] = d;
+              builder.addFactor(opDims, d, dimSize);
+            }
+
+            const int64_t featAxis = static_cast<int64_t>(bn.getFeatureIndex());
+            const int64_t C = outTy.getDimSize(featAxis);
+
+            // scale, offset, mean, and variance are 1-D tensors of size C.
+            // Under conservative propagation require them to be replicated;
+            // otherwise let shardings pass through.
+            for (int64_t paramIdx : {1, 2, 3, 4}) {
+              std::fill(opDims.begin(), opDims.end(), kNullDim);
+              opDims[paramIdx] = 0;
+              FactorType factorType = conservativePropagation
+                                          ? FactorType::kNeedReplication
+                                          : FactorType::kPassThrough;
+              builder.addFactor(opDims, kNullDim, C, factorType, true);
+            }
+
+            return builder.build();
+          })
