-
Notifications
You must be signed in to change notification settings - Fork 29
Open
Description
There is no sharding rule for stablehlo.batch_norm_inference in SDY. This gap results in shape and sharding inference failures when batch norm appears under tensor parallelism.
As I’m currently unable to contribute directly, I prepared a change in op_sharding_rule_registry.cc to address this. Kindly review and, if acceptable, proceed with a PR.
diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
index 47ce53e..63b2424 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -303,6 +303,37 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
}
return builder.build();
})
+ .Case<stablehlo::BatchNormInferenceOp>(
+ [conservativePropagation](stablehlo::BatchNormInferenceOp bn) {
+ auto inTy = llvm::cast<mlir::RankedTensorType>(bn.getOperand().getType());
+ auto outTy = llvm::cast<mlir::RankedTensorType>(bn.getResult().getType());
+
+ OpShardingRuleBuilder builder(bn);
+
+ const int64_t numOperands = static_cast<int64_t>(bn->getNumOperands());
+ llvm::SmallVector<int64_t> opDims(numOperands, kNullDim);
+
+ for (auto [dU, dimSize] : llvm::enumerate(inTy.getShape())) {
+ const int64_t d = static_cast<int64_t>(dU);
+ std::fill(opDims.begin(), opDims.end(), kNullDim);
+ opDims[0] = d;
+ builder.addFactor(opDims, d, dimSize);
+ }
+
+ const int64_t featAxis = static_cast<int64_t>(bn.getFeatureIndex());
+ const int64_t C = outTy.getDimSize(featAxis);
+
+ for (int64_t paramIdx : {1LL, 2LL, 3LL, 4LL}) {
+ std::fill(opDims.begin(), opDims.end(), kNullDim);
+ opDims[paramIdx] = 0;
+ auto factorType = conservativePropagation ? FactorType::kNeedReplication
+ : FactorType::kPassThrough;
+ builder.addFactor(opDims, kNullDim, C,
+ factorType, true);
+ }
+
+ return builder.build();
+ })
Metadata
Metadata
Assignees
Labels
No labels