Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ bool isLocalSizeOne(IterDomain* id) {

} // namespace

std::optional<CommunicationInfo> getCommunicationInfo(Expr* expr) {
CommunicationInfo getCommunicationInfo(Expr* expr) {
NVF_ERROR(
isResharding(expr),
"getCommunicationInfo should only be called when `expr` is known to be a "
Expand All @@ -319,9 +319,11 @@ std::optional<CommunicationInfo> getCommunicationInfo(Expr* expr) {
auto* consumer = expr->outputs().at(0)->as<TensorView>();
std::optional<CommunicationInfo> communication_info = std::nullopt;

auto create_communication_info = [&](CommunicationType type,
IterDomain* p_sharded_id,
IterDomain* c_sharded_id) {
// Fill `communication_info` instead of returning the result, so we can catch
// errors when more than one DID has sharding changes.
auto fill_communication_info = [&](CommunicationType type,
IterDomain* p_sharded_id,
IterDomain* c_sharded_id) {
NVF_ERROR(
!communication_info.has_value(),
"Expected at most one sharding change");
Expand Down Expand Up @@ -358,18 +360,18 @@ std::optional<CommunicationInfo> getCommunicationInfo(Expr* expr) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
CommunicationType type = same_mesh ? CommunicationType::Allgather
: CommunicationType::Gather;
create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
}
if (!p_sharded && c_sharded) {
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
create_communication_info(
fill_communication_info(
CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id);
}
if (p_sharded && c_sharded) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did);
// TODO(#4604): This is problematic for 2D sharding.
create_communication_info(
fill_communication_info(
CommunicationType::SendRecv, p_logical_id, c_logical_id);
}
} else {
Expand All @@ -383,7 +385,7 @@ std::optional<CommunicationInfo> getCommunicationInfo(Expr* expr) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
CommunicationType type = same_mesh ? CommunicationType::Allreduce
: CommunicationType::Reduce;
create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
continue;
}

Expand All @@ -400,17 +402,17 @@ std::optional<CommunicationInfo> getCommunicationInfo(Expr* expr) {
if (!c_it->second->isReduction()) {
continue;
}
create_communication_info(
fill_communication_info(
CommunicationType::ReduceScatter,
c2p_map.at(c_logical_id),
c_logical_id);
}
}

if (!communication_info.has_value()) {
create_communication_info(CommunicationType::Broadcast, nullptr, nullptr);
fill_communication_info(CommunicationType::Broadcast, nullptr, nullptr);
}
return communication_info;
return *communication_info;
}

namespace {
Expand Down Expand Up @@ -472,22 +474,16 @@ bool isCommunicationLayoutCompliant(Expr* expr) {
auto* producer = expr->inputs().at(0)->as<TensorView>();
auto* consumer = expr->outputs().at(0)->as<TensorView>();

// TODO(#4552): the caller should make sure Expr is a communication so
// getCommunicationInfo always returns a valid CommunicationInfo. Retry after
// #4552 is merged.
auto communication_info = getCommunicationInfo(expr);
if (!communication_info.has_value()) {
return true;
}
CommunicationInfo communication_info = getCommunicationInfo(expr);

Layout p_layout = getCommunicationLayout(
producer, communication_info->type, communication_info->p_sharded_id);
producer, communication_info.type, communication_info.p_sharded_id);
if (!isCompliantWith(*canonicalizeLayout(producer), p_layout)) {
return false;
}

Layout c_layout = getCommunicationLayout(
consumer, communication_info->type, communication_info->c_sharded_id);
consumer, communication_info.type, communication_info.c_sharded_id);
if (!isCompliantWith(*canonicalizeLayout(consumer), c_layout)) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/host_ir/lower_to_communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ bool isCommunicationLayoutCompliant(Expr* expr);
// info: type and sharded IDs. We assume that the expr has been decomposed and
// represented a single communication. If multiple communications are present or
// 2D sharding, this function will raise an error.
std::optional<CommunicationInfo> getCommunicationInfo(Expr* expr);
CommunicationInfo getCommunicationInfo(Expr* expr);

// Given the input/output TensorView of a communication, returns its layout
// required by the communication backend (e.g. NCCL or UCC). `sharded_id` is the
Expand Down
6 changes: 1 addition & 5 deletions csrc/preseg_passes/insert_reshardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,7 @@ void InsertReshardingsPass::runPass(Fusion* fusion) {
// Validate
for (Expr* e : fusion->exprs()) {
if (isResharding(e)) {
NVF_ERROR(
getCommunicationInfo(e).has_value(),
"After decomposition, any resharding expression is expected to be a "
"lowerable communication: ",
e);
getCommunicationInfo(e);
}
}
}
Expand Down
17 changes: 3 additions & 14 deletions csrc/preseg_passes/reorder_sharded_axis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ namespace nvfuser::preseg_passes {

namespace {

void makeCommunicationLayoutCompliant(
Expr* expr,
CommunicationInfo communication_info) {
void makeCommunicationLayoutCompliant(Expr* expr) {
auto* input = expr->inputs().at(0)->as<TensorView>();
auto* output = expr->outputs().at(0)->as<TensorView>();

CommunicationInfo communication_info = getCommunicationInfo(expr);
IterDomain* p_sharded_id = communication_info.p_sharded_id;
IterDomain* c_sharded_id = communication_info.c_sharded_id;

Expand Down Expand Up @@ -81,17 +80,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
continue;
}

auto communication_info = getCommunicationInfo(expr);
// Should really be simply NVF_ERROR(communication_info.has_value());
//
// I'll try to do that after #4552 is merged. Some of the `mesh.size() > 1`
// check in getCommunicationInfo and convertSingleOpToCommuniation will also
// need to go away for this simplification.
if (!communication_info.has_value()) {
continue;
}

makeCommunicationLayoutCompliant(expr, *communication_info);
makeCommunicationLayoutCompliant(expr);
}

if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
Expand Down