51 changes: 36 additions & 15 deletions shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
@@ -327,17 +327,13 @@ class CollectiveInserter {
mesh(inSharding.getMesh(op)),
curMeshName(inSharding.getMeshSymName()),
outMeshName(outSharding.getMeshSymName()),
unreducedAxes(outSharding.getUnreducedAxes()),
inUnreducedAxes(inSharding.getUnreducedAxes()),
outUnreducedAxes(outSharding.getUnreducedAxes()),
inAxesPerDim(getAxesPerDim<AxisList>(inSharding)),
outAxesPerDim(getAxesPerDim<AxisList>(outSharding)),
currentAxesPerDim(getAxesPerDim<SmallVector<AxisRefAttr>>(inSharding)),
capacityPerDim(inSharding.getRank(), 1),
collectiveAxesPerDim(inSharding.getRank()) {
// Unreduced axes in the input and output sharding must match, given we
// insert an all-reduce if an unreduced axis becomes replicated/sharded, and
// never insert a reshard that goes from replicated/sharded to unreduced.
assert(inSharding.getUnreducedAxes() == outSharding.getUnreducedAxes());

// We align sub-axes between the input and output axes, so that we can treat
// sub-axes like full axes and assume any two sub-axes that overlap are also
// equal, which allows using them as keys in a hash map.
@@ -358,6 +354,15 @@ class CollectiveInserter {
// If the input and output sharding are the same, returns the input value
// without inserting any collective.
Value insert() {
if (inUnreducedAxes != outUnreducedAxes) {
assert(getAxisSetDiff(inUnreducedAxes, outUnreducedAxes, mesh).empty());
SmallVector<AxisRefAttr> newUnreducedAxes =
getAxisSetDiff(outUnreducedAxes, inUnreducedAxes, mesh);
tryShardedToUnreduced(newUnreducedAxes);
assert(isDone());
return result;
}

// In the common case where all axes are a power of 2, in which case a
// bigger axis is always divisible by a smaller axis, we are guaranteed to
// be done after trying all-slice -> collective-permute -> all-to-alls ->
@@ -392,7 +397,6 @@
tryAllGather();
}
assert(isDone());

return result;
}

@@ -412,7 +416,7 @@

TensorShardingAttr getCurrentSharding() const {
return TensorShardingAttr::getClosed(getContext(), curMeshName,
currentAxesPerDim, unreducedAxes);
currentAxesPerDim, outUnreducedAxes);
}

// If an all-gather can be performed on `dim`, returns the axes to gather for
@@ -474,6 +478,23 @@
}
}

// Tries to insert an `sdy.sharded_to_unreduced`.
void tryShardedToUnreduced(SmallVector<AxisRefAttr>& newUnreducedAxes) {
SmallVector<AxisRefAttr> allAxes;
for (auto [dim, collectiveAxes] : llvm::enumerate(collectiveAxesPerDim)) {
SmallVector<AxisRefAttr> axes = getGatheringAxes(dim);
collectiveAxes = AxisRefListAttr::get(getContext(), axes);
allAxes.append(axes);
}

sortAndMergeAxes(newUnreducedAxes, mesh);
sortAndMergeAxes(allAxes, mesh);
assert(newUnreducedAxes == allAxes);

result = ShardedToUnreducedOp::create(
rewriter, loc, result, collectiveAxesPerDim, getCurrentSharding());
}

// For each dimension d, distribute axes from `getAvailableAxes(d)` in
// `inAxesPerDim[d]` based on the capacity for that dimension
// (`capacityPerDim[d]`).
@@ -1305,7 +1326,7 @@ class CollectiveInserter {
Value result;
MeshAttr mesh;
FlatSymbolRefAttr curMeshName, outMeshName;
ArrayRef<AxisRefAttr> unreducedAxes;
ArrayRef<AxisRefAttr> inUnreducedAxes, outUnreducedAxes;
SmallVector<AxisList> inAxesPerDim, outAxesPerDim;
AxesPerDim currentAxesPerDim;
SmallVector<int64_t> capacityPerDim;
@@ -1342,15 +1363,15 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
op, [](Diagnostic& diag) { diag << "Incompatible shardings"; });
}
if (outSharding.isFullyReplicated()) {
if (inSharding.isFullyReplicated()) {
if (inSharding.isFullyReplicated() &&
inSharding.getUnreducedAxes() == outSharding.getUnreducedAxes()) {
rewriter.replaceOp(op, adaptor.getInput());
return success();
}
// TODO(enver): Hard fail if output sharding has different unreduced
// axes than the input sharding. Note that the out sharding may be fully
// replicated and still have different unreduced axes than the input
// sharding.
SmallVector<AxisRefAttr> oldUnreducedAxes =
llvm::to_vector(outSharding.getUnreducedAxes());
outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
outSharding = outSharding.replaceUnreducedAxes(oldUnreducedAxes);
}
// TODO(enver): Set input mesh to output mesh if input sharding is fully
// replicated. It requires sdy.all_slice can handle that input and output
@@ -1382,7 +1403,7 @@ struct ReshardToCollectivesPass
LogicalResult initialize(MLIRContext* context) final {
target = std::make_shared<ConversionTarget>(*context);
target->addLegalOp<AllGatherOp, AllSliceOp, AllToAllOp,
CollectivePermuteOp>();
CollectivePermuteOp, ShardedToUnreducedOp>();
target->addDynamicallyLegalOp<ReshardOp>([&](ReshardOp op) {
TensorShardingAttr inSharding = getSharding(op.getInput());
TensorShardingAttr outSharding = op.getSharding();
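For reference, a minimal before/after sketch (not part of the diff) of the rewrite the new path performs. It mirrors the first new test case in the test file below; the mesh @mesh1d_6, the axis "x", and the printed form of the op are taken directly from that test's CHECK lines:

// Before reshard-to-collectives: the target sharding only adds an unreduced axis.
%0 = sdy.reshard %arg0 <@mesh1d_6, [{}, {}], unreduced={"x"}> : tensor<24x8xf32>
// After the pass: the reshard is replaced by a single collective.
%0 = sdy.sharded_to_unreduced [{"x"}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}], unreduced={"x"}>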
@@ -890,6 +890,46 @@ func.func @out_unreduced_axes_preserved(%arg0 : tensor<16x8xf32> {sdy.sharding=#
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @sharded_to_unreduced_1
func.func @sharded_to_unreduced_1(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> tensor<24x8xf32> {
// CHECK-NEXT: %0 = sdy.sharded_to_unreduced [{"x"}, {}] %arg0 out_sharding=<@mesh1d_6, [{}, {}], unreduced={"x"}>
// CHECK-NEXT: return %0
%0 = sdy.reshard %arg0 <@mesh1d_6, [{}, {}], unreduced={"x"}> : tensor<24x8xf32>
return %0 : tensor<24x8xf32>
}

// CHECK-LABEL: func @sharded_to_unreduced_single_axis
func.func @sharded_to_unreduced_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: %0 = sdy.sharded_to_unreduced [{}, {"x"}] %arg0 out_sharding=<@mesh2d, [{"y"}, {}], unreduced={"x"}>
// CHECK-NEXT: return %0
%0 = sdy.reshard %arg0 <@mesh2d, [{"y"}, {}], unreduced={"x"}> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @sharded_to_unreduced_multiple_axes
func.func @sharded_to_unreduced_multiple_axes(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"x", "z", "y"}, {}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: %0 = sdy.sharded_to_unreduced [{"z", "y"}, {}] %arg0 out_sharding=<@mesh3d, [{"x"}, {}], unreduced={"y", "z"}>
// CHECK-NEXT: return %0
%0 = sdy.reshard %arg0 <@mesh3d, [{"x"}, {}], unreduced={"y", "z"}> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @sharded_to_unreduced_multiple_dims
func.func @sharded_to_unreduced_multiple_dims(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh3d, [{"y", "z"}, {"x"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: %0 = sdy.sharded_to_unreduced [{"z"}, {"x"}] %arg0 out_sharding=<@mesh3d, [{"y"}, {}], unreduced={"x", "z"}>
// CHECK-NEXT: return %0
%0 = sdy.reshard %arg0 <@mesh3d, [{"y"}, {}], unreduced={"x", "z"}> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @sharded_to_unreduced_with_subaxis
func.func @sharded_to_unreduced_with_subaxis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x8, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: %0 = sdy.sharded_to_unreduced [{"y":(4)2}, {}] %arg0 out_sharding=<@mesh2d_2x8, [{"y":(1)4}, {"x"}], unreduced={"y":(4)2}>
// CHECK-NEXT: return %0
%0 = sdy.reshard %arg0 <@mesh2d_2x8, [{"y":(1)4}, {"x"}], unreduced={"y":(4)2}> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}
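
// NOTE (illustrative, not part of the diff): assuming, as the mesh name suggests, @mesh2d_2x8
// defines "x"=2 and "y"=8, a sub-axis "a":(k)m denotes the slice of axis "a" with pre-size k and
// size m. In the test above, "y" (size 8) decomposes into "y":(1)4 and "y":(4)2; dim 0 stays
// sharded on "y":(1)4, so only "y":(4)2 is gathered into the unreduced set, which is why the
// CHECK line shows sdy.sharded_to_unreduced [{"y":(4)2}, {}].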

// TODO(b/391138813): Add proper support for axes that can't co-exist

// LABEL: func @reshard_with_non_divisible_subaxes_same_pre_size