Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 45 additions & 63 deletions csrc/preseg_passes/propagate_shardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,63 +41,52 @@ bool isSplitDivisible(IterDomain* id, Split* ref_split) {
return id_extent % split_factor == 0;
}

// Sort the given tvs by the number of non-reduction device/stream dimensions
// on the loop domain, in descending order. Break ties by rank of device mesh
// (higher rank first).
template <typename Range>
std::vector<TensorView*> filterTvsWithMesh(const Range& tvs) {
std::vector<TensorView*> tvs_with_mesh;
std::copy_if(
tvs.begin(),
tvs.end(),
std::back_inserter(tvs_with_mesh),
[](TensorView* tv) { return tv != nullptr && tv->hasDeviceMesh(); });
return tvs_with_mesh;
}
std::vector<TensorView*> sortTvsByParallelDims(const Range& tvs) {
auto num_parallel_dims = [](TensorView* tv) {
return std::count_if(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[](IterDomain* id) {
return !id->isReduction() && (id->isStream() || id->isDeviceDim());
});
};

// Sort the given tvs by the number of device dimensions in descending order.
// Break ties by the total number of dimensions.
// Only includes TensorViews that have a device mesh.
template <typename Range>
std::vector<TensorView*> sortTvsByDeviceDims(const Range& tvs) {
// Filter out TVs without a device mesh
std::vector<TensorView*> tvs_with_mesh = filterTvsWithMesh(tvs);

// Then sort the filtered TVs
std::stable_sort(
tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) {
int64_t a_device_dims = numDeviceDims(a);
int64_t b_device_dims = numDeviceDims(b);
if (a_device_dims != b_device_dims) {
return a_device_dims > b_device_dims;
}
// Break ties by rank of device mesh.
return a->getDeviceMesh().rank() > b->getDeviceMesh().rank();
});
std::vector<TensorView*> tvs_vec(tvs.begin(), tvs.end());

return tvs_with_mesh;
std::ranges::stable_sort(tvs_vec, [&num_parallel_dims](auto a, auto b) {
return std::make_pair(num_parallel_dims(a), a->getDeviceMesh().rank()) >
std::make_pair(num_parallel_dims(b), b->getDeviceMesh().rank());
});

return tvs_vec;
}

// Order the inputs of the expression based on their priority.
// For linear op, we use weights and bias before input.
// For matmul op, we use weights before input.
// For other ops, we sort the inputs by the number of device dimensions in
// descending order.
// For other ops, we sort the inputs by the number of device/stream dimensions
// in descending order.
std::vector<TensorView*> getOrderedReferenceInputs(Expr* expr) {
const auto& inputs = ir_utils::filterByType<TensorView>(expr->inputs());
if (LinearOp* linear_op = dynamic_cast<LinearOp*>(expr)) {
if (auto* linear_op = dynamic_cast<LinearOp*>(expr)) {
// Use weights and bias before input.
return filterTvsWithMesh(std::vector<TensorView*>(
{linear_op->inB(), linear_op->bias(), linear_op->inA()}));
if (linear_op->hasBias()) {
return {linear_op->inB(), linear_op->bias(), linear_op->inA()};
} else {
return {linear_op->inB(), linear_op->inA()};
}
}

if (MatmulOp* matmul_op = dynamic_cast<MatmulOp*>(expr)) {
if (auto* matmul_op = dynamic_cast<MatmulOp*>(expr)) {
// Use weights before input.
return filterTvsWithMesh(
std::vector<TensorView*>({matmul_op->inB(), matmul_op->inA()}));
return {matmul_op->inB(), matmul_op->inA()};
}

// Sort inputs by number of device dimensions in descending order
std::vector<TensorView*> sorted_inputs = sortTvsByDeviceDims(inputs);

return sorted_inputs;
// Sort inputs by number of device/stream dimensions in descending order
return sortTvsByParallelDims(inputs);
}

// Returns the set of parallel types not seen on the loop domain of the given
Expand Down Expand Up @@ -186,6 +175,16 @@ void transformLoopDomain(
target->setDeviceMesh(ref->getDeviceMesh());
}

// If the target is a scatter op output, skip propagation.
if (target->definition() != nullptr &&
target->definition()->isA<ScatterOp>()) {
// Scatter op output has a disjoint logical-to-loop domain.
// So we skip propagation to avoid errors in the following code such as when
// setting the loop domain. It is not clear to me what device / stream
// parallelization would mean on scatter output.
return;
}

std::unordered_map<IterDomain*, IterDomain*> ref2target =
getRef2TargetMap(ref, target, direction);

Expand Down Expand Up @@ -416,11 +415,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
// Propagate shardings from reference inputs in order.
for (auto* ref_input : reference_inputs) {
NVF_ERROR(ref_input != nullptr);
NVF_ERROR(
ref_input->hasDeviceMesh(),
"Reference input ",
ref_input,
" has no device mesh.");

// Consider out [M, N] = linear(inp [M, K], weight [N,
// K]) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N
Expand Down Expand Up @@ -452,27 +446,15 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
// inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example.
for (Expr* expr : exprs | std::views::reverse) {
const auto& outputs = ir_utils::filterByType<TensorView>(expr->outputs());
if (outputs.empty()) {
continue;
}
// All outputs of an expression (Welford, SDPA) should be uniformly sharded.
// We pick the most parallel output as the reference.
// This is to avoid picking seed/offset tvs in SDPA.
std::vector<TensorView*> sorted_outputs = sortTvsByDeviceDims(outputs);

if (sorted_outputs.empty()) {
// No output with a device mesh.
continue;
}

std::vector<TensorView*> sorted_outputs = sortTvsByParallelDims(outputs);
TensorView* ref_output = sorted_outputs.front();
NVF_ERROR(
ref_output != nullptr && ref_output->hasDeviceMesh(),
"Reference output ",
ref_output,
" has no device mesh.");

// For fusion inputs, only check if they have a device mesh. We do not
// modify their sharding. For non-fusion inputs, we try to propagate
// shardings from the reference output for parallel types that are not
// already present.

for (auto* target : ir_utils::filterByType<TensorView>(expr->inputs())) {
// Allow inputs to be stream parallelized for easier analysis.
if (user_sharded_tvs.count(target) && !target->isFusionInput()) {
Expand Down
46 changes: 45 additions & 1 deletion tests/cpp/test_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
#include <ir/interface_nodes.h>
#include <multidevice/utils.h>
#include <ops/alias.h>
#include <ops/arith.h>
#include <ops/all_ops.h>
#include <preseg_passes/propagate_shardings.h>
#include <tests/cpp/utils.h>

namespace nvfuser {
Expand Down Expand Up @@ -86,4 +87,47 @@ TEST_F(StreamTest, haveDifferentShardings) {
EXPECT_TRUE(haveDifferentShardings(tv2, tv3, {ParallelType::Stream}));
}

// Checks that a Stream-parallelized outer split on the second matmul operand
// shows up on the matmul output after running PropagateShardingsPass.
TEST_F(StreamTest, ForwardPropagation) {
  using PropagateShardings =
      preseg_passes::OptimizationPass<preseg_passes::PropagateShardingsPass>;

  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard guard(fusion_ptr.get());

  constexpr int64_t kNumStreams = 2;

  // result = matmul(input, weight); both operands are fusion inputs.
  TensorView* input = makeContigTensor(2);
  TensorView* weight = makeContigTensor(2);
  TensorView* result = matmul(input, weight);
  fusion_ptr->addInput(input);
  fusion_ptr->addInput(weight);
  fusion_ptr->addOutput(result);

  // Outer-split axis 1 of the weight and mark the resulting axis 1 as Stream.
  weight->outer_split(1, kNumStreams);
  weight->axis(1)->parallelize(ParallelType::Stream);

  PropagateShardings::runPass(fusion_ptr.get());

  // The output is expected to carry a Stream axis at position 1.
  EXPECT_TRUE(result->axis(1)->isStream()) << result;
}

// Checks that a Stream-parallelized outer split on the fusion output shows up
// on every upstream tensor (including the fusion input) after running
// PropagateShardingsPass.
TEST_F(StreamTest, BackwardPropagation) {
  using PropagateShardings =
      preseg_passes::OptimizationPass<preseg_passes::PropagateShardingsPass>;

  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard guard(fusion_ptr.get());

  constexpr int64_t kNumStreams = 2;

  // input -> plus_one = input + 1.0 -> doubled = plus_one + plus_one
  TensorView* input = makeContigTensor(2);
  TensorView* plus_one = add(input, IrBuilder::create<Val>(1.0));
  TensorView* doubled = add(plus_one, plus_one);
  fusion_ptr->addInput(input);
  fusion_ptr->addOutput(doubled);

  // Only the output is parallelized by hand: outer-split axis 0, then mark
  // the resulting axis 0 as Stream.
  doubled->outer_split(0, kNumStreams);
  doubled->axis(0)->parallelize(ParallelType::Stream);

  PropagateShardings::runPass(fusion_ptr.get());

  // Every tensor in the chain is expected to have a Stream axis at position 0.
  for (TensorView* checked : {input, plus_one, doubled}) {
    EXPECT_TRUE(checked->axis(0)->isStream()) << checked;
  }
}

} // namespace nvfuser