Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions csrc/host_ir/lower_to_communication.cpp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor renaming

Original file line number Diff line number Diff line change
Expand Up @@ -301,22 +301,22 @@ bool isLocalSizeOne(IterDomain* id) {

} // namespace

CommunicationInfo getCommunicationInfo(Expr* expr) {
CommunicationInfo getCommunicationInfo(Expr* e) {
NVF_ERROR(
isResharding(expr),
"getCommunicationInfo should only be called when `expr` is known to be a "
"communication. So `expr` should be resharding. Given: ",
expr);
isResharding(e),
"getCommunicationInfo should only be called when `e` is known to be a "
"communication. So `e` should be resharding. Given: ",
e);

NVF_ERROR(
expr->isA<LoadStoreOp>() || expr->isA<ReductionOp>(),
"getCommunicationInfo should only be called when `expr` is known to be a "
"communication. So `expr` should be either a LoadStoreOp or a "
e->isA<LoadStoreOp>() || e->isA<ReductionOp>(),
"getCommunicationInfo should only be called when `e` is known to be a "
"communication. So `e` should be either a LoadStoreOp or a "
"ReductionOp. Given: ",
expr);
e);

auto* producer = expr->inputs().at(0)->as<TensorView>();
auto* consumer = expr->outputs().at(0)->as<TensorView>();
auto* producer = e->inputs().at(0)->as<TensorView>();
auto* consumer = e->outputs().at(0)->as<TensorView>();
std::optional<CommunicationInfo> communication_info = std::nullopt;

// Fill `communication_info` instead of returning the result, so we can catch
Expand Down Expand Up @@ -355,7 +355,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) {
const bool c_sharded = c_loop_did != nullptr && consumer_mesh.size() > 1;
const bool same_mesh = producer_mesh == consumer_mesh;

if (expr->isA<LoadStoreOp>()) {
if (e->isA<LoadStoreOp>()) {
if (p_sharded && !c_sharded) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
CommunicationType type = same_mesh ? CommunicationType::Allgather
Expand All @@ -375,7 +375,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) {
CommunicationType::SendRecv, p_logical_id, c_logical_id);
}
} else {
NVF_ERROR(expr->isA<ReductionOp>());
NVF_ERROR(e->isA<ReductionOp>());
if (!p_sharded) {
// Not a reduction based communication.
continue;
Expand Down
2 changes: 1 addition & 1 deletion csrc/runtime/allocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShape(
inferred_val.hasValue(),
"Could not launch kernel as program could not infer ",
symbolic_size->toInlineString(),
"(",
" (",
symbolic_size->toString(),
") for the buffer ",
tv->toString());
Expand Down
20 changes: 11 additions & 9 deletions csrc/runtime/fusion_kernel_runtime.cpp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor renaming

Original file line number Diff line number Diff line change
Expand Up @@ -483,24 +483,26 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
switch (group_to_run->schedulerType()) {
case SchedulerType::Communication: {
auto deviceid = Communicator::getInstance().deviceId();
NVF_ERROR(
group_to_run->exprs().size() == 1,
"Communication segments must contain only one Expr");
for (auto* expr : convertSingleOpToCommunication(
NVF_ERROR_EQ(
group_to_run->exprs().size(),
1,
"Communication segments must contain only one Expr.");
for (auto* e : convertSingleOpToCommunication(
ir_cloner.clone(group_to_run->exprs().at(0)), deviceid)) {
NVF_ERROR(
expr->isA<Communication>(),
"Exprs in a Communication group should be Communication");
e->isA<Communication>(),
"Exprs in a Communication group should be Communication: ",
e);
// Allocate the recv buffers of communications
auto* communication = expr->as<Communication>();
auto* communication = e->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(deviceid)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
}
hic->pushBackTopLevelExprs(expr);
auto wait = IrBuilder::create<hir::Wait>(expr->as<Communication>());
hic->pushBackTopLevelExprs(communication);
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
} break;
Expand Down
7 changes: 2 additions & 5 deletions tests/cpp/test_multidevice_lower_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class LowerGatherTest
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};

TEST_P(LowerGatherTest, ) {
EnableOptionsGuard opt_guard;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary because NVFuserTest owns an EnableOptionsGuard already.

const auto& [meshes, enable_host_ir_lowering] = GetParam();
const auto& [in_mesh, out_mesh] = meshes;

Expand Down Expand Up @@ -138,7 +137,6 @@ class LowerScatterTest
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};

TEST_P(LowerScatterTest, ) {
EnableOptionsGuard opt_guard;
const auto& [meshes, enable_host_ir_lowering] = GetParam();
const auto& [in_mesh, out_mesh] = meshes;

Expand Down Expand Up @@ -189,7 +187,6 @@ class LowerSendRecvTest
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};

TEST_P(LowerSendRecvTest, ) {
EnableOptionsGuard opt_guard;
const auto& [meshes, enable_host_ir_lowering] = GetParam();
const auto& [in_mesh, out_mesh] = meshes;

Expand Down Expand Up @@ -255,7 +252,6 @@ void LowerCollectiveTest::SetUp() {
// available. Therefore, we call it after the isBackendAvailable check.
communicator_->setDefaultBackend(backend_type);

EnableOptionsGuard enable_options_guard;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (apparently introduced in #4170) is what has been hiding #4230.

if (enable_host_ir_lowering) {
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
}
Expand Down Expand Up @@ -778,7 +774,8 @@ INSTANTIATE_TEST_SUITE_P(
LowerCollectiveTest,
::testing::Combine(
testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
testing::Bool()),
// Can't do testing::Bool() yet due to #4230
testing::Values(false)),
([](const testing::TestParamInfo<std::tuple<CommunicatorBackend, bool>>&
info) -> std::string {
const auto& [backend_type, enable_host_ir_lowering] = info.param;
Expand Down