Skip to content
6 changes: 6 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ int64_t getRootRelativeIndex(const CommParams& params) {
// Builds a Communication expression from `params`, taking ownership of them.
// Two invariants are checked on the stored parameters:
//   - the mesh must contain at least one device;
//   - a non-negative root must be provided exactly when hasRoot() reports
//     that the communication type has a root.
Communication::Communication(IrBuilderPasskey passkey, CommParams params)
    : Expr(passkey), params_(std::move(params)) {
  NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0.");

  // Both sides must agree: either the type has a root and one was given, or
  // the type has no root and none was given.
  const bool type_has_root = hasRoot(params_.type);
  const bool root_provided = params_.root >= 0;
  NVF_ERROR(
      type_has_root == root_provided,
      "Root ",
      params_.root,
      " is not expected by CommunicationType ",
      params_.type);
}

Communication::Communication(const Communication* src, IrCloner* ir_cloner)
Expand Down
143 changes: 53 additions & 90 deletions csrc/multidevice/lower_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,31 +56,6 @@ inline bool isDeviceInvolved(
return sender_mesh.has(my_device_index) || receiver_mesh.has(my_device_index);
}

// Builds the CommParams shared by Scatter and Gather communications. The two
// cases only differ by the communication type, which is selected with
// "is_scatter".
//
// The mesh is read from root_tv (the caller passes
// is_scatter ? input_tv : output_tv). The team is made of every device of
// that mesh, plus the root when the root is not already part of the mesh.
// my_device_index is not read by the body.
//
// TODO: remove this helper; the definition is trivial to inline.
CommParams createParamsForGatherScatter(
    DeviceIdxType my_device_index,
    DeviceIdxType root,
    TensorView* root_tv, // is_scatter ? input_tv : output_tv
    bool is_scatter) {
  const DeviceMesh& device_mesh = root_tv->getDeviceMesh();

  CommParams result;
  if (is_scatter) {
    result.type = CommunicationType::Scatter;
  } else {
    result.type = CommunicationType::Gather;
  }
  result.root = root;
  result.mesh = device_mesh;
  result.team = device_mesh.vector();

  // The root always takes part in the communication, even when it lies
  // outside the mesh.
  const bool root_in_mesh = device_mesh.has(root);
  if (!root_in_mesh) {
    result.team.push_back(root);
  }
  return result;
}

// Adds at most one Scatter communication to the vector 'comms'
void lowerToScatter(
DeviceIdxType my_device_index,
Expand All @@ -93,9 +68,15 @@ void lowerToScatter(
if (!isDeviceInvolved(my_device_index, root, receiver_mesh)) {
return;
}
auto params =
createParamsForGatherScatter(my_device_index, root, output_tv, true);
comms.push_back(IrBuilder::create<Communication>(std::move(params)));
Team team = receiver_mesh.vector();
if (!receiver_mesh.has(root)) {
team.push_back(root);
}
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::Scatter,
.root = root,
.mesh = receiver_mesh,
.team = team}));
}

/*
Expand All @@ -115,9 +96,15 @@ void lowerToGather(
if (!isDeviceInvolved(my_device_index, root, sender_mesh)) {
continue;
}
auto params =
createParamsForGatherScatter(my_device_index, root, input_tv, false);
comms.push_back(IrBuilder::create<Communication>(std::move(params)));
Team team = sender_mesh.vector();
if (!sender_mesh.has(root)) {
team.push_back(root);
}
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::Gather,
.root = root,
.mesh = sender_mesh,
.team = team}));
}
}

Expand All @@ -132,28 +119,10 @@ void lowerToAllgather(
return;
}

CommParams params;
params.type = CommunicationType::Allgather;
params.mesh = mesh;
params.team = mesh.vector();
comms.push_back(IrBuilder::create<Communication>(std::move(params)));
}

// Creates and set the CommParams for a Broadcast or Send/Recv communication
CommParams createParamsForBroadcastOrP2P(
DeviceIdxType my_device_index,
DeviceIdxType root,
// receiver devices
const DeviceMesh& mesh) {
CommParams params;
params.type = CommunicationType::Broadcast;
params.root = root;
params.mesh = mesh;
params.team = mesh.vector();
if (!mesh.has(root)) {
params.team.push_back(root);
}
return params;
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::Allgather,
.mesh = mesh,
.team = mesh.vector()}));
}

// Adds at most one Broadcast or Send/Recv communication to the vector 'comms'
Expand All @@ -165,8 +134,15 @@ void lowerToBroadcastOrP2P(
if (!isDeviceInvolved(my_device_index, root, mesh)) {
return;
}
auto params = createParamsForBroadcastOrP2P(my_device_index, root, mesh);
comms.push_back(IrBuilder::create<Communication>(std::move(params)));
Team team = mesh.vector();
if (!mesh.has(root)) {
team.push_back(root);
}
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::Broadcast,
.root = root,
.mesh = mesh,
.team = team}));
}

// Adds several Broadcast or Send/Recv communications to the vector 'comms'
Expand Down Expand Up @@ -201,25 +177,6 @@ void lowerToBroadcastOrP2P(
}
}

// Builds the CommParams of a Reduce communication rooted at "root".
//
// The mesh is the device mesh of input_tv. The team is made of every device
// of that mesh, plus the root when the root is not already part of the mesh.
// The reduction operator is obtained from op_type through
// getC10dReduceOpType. my_device_index and output_tv are not read by the
// body.
CommParams createParamsForReduce(
    DeviceIdxType my_device_index,
    DeviceIdxType root,
    TensorView* input_tv,
    TensorView* output_tv,
    BinaryOpType op_type) {
  const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();

  CommParams reduce_params;
  reduce_params.type = CommunicationType::Reduce;
  reduce_params.root = root;
  reduce_params.redOp = getC10dReduceOpType(op_type);
  reduce_params.mesh = sender_mesh;

  // Assemble the team locally, then hand it over to the params.
  auto participants = sender_mesh.vector();
  const bool root_in_mesh = sender_mesh.has(root);
  if (!root_in_mesh) {
    participants.push_back(root);
  }
  reduce_params.team = std::move(participants);
  return reduce_params;
}

void lowerToReduce(
DeviceIdxType my_device_index,
TensorView* input_tv,
Expand All @@ -228,14 +185,22 @@ void lowerToReduce(
std::vector<Communication*>& comms) {
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const auto reduce_op_type = getC10dReduceOpType(op_type);
// we create as many Reduces as there are devices in the receiver mesh
for (auto root : receiver_mesh.vector()) {
if (!isDeviceInvolved(my_device_index, root, sender_mesh)) {
continue;
}
auto params = createParamsForReduce(
my_device_index, root, input_tv, output_tv, op_type);
comms.push_back(IrBuilder::create<Communication>(std::move(params)));
Team team = sender_mesh.vector();
if (!sender_mesh.has(root)) {
team.push_back(root);
}
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::Reduce,
.root = root,
.mesh = sender_mesh,
.team = team,
.redOp = reduce_op_type}));
}
}

Expand All @@ -250,12 +215,11 @@ void lowerToAllreduce(
return;
}

CommParams params;
params.type = CommunicationType::Allreduce;
params.redOp = getC10dReduceOpType(op_type);
params.mesh = mesh;
params.team = mesh.vector();
comms.push_back(IrBuilder::create<Communication>(params));
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::Allreduce,
.mesh = mesh,
.team = mesh.vector(),
.redOp = getC10dReduceOpType(op_type)}));
}

void lowerToReduceScatter(
Expand All @@ -279,13 +243,12 @@ void lowerToReduceScatter(
scattered_axis++;
}

CommParams params;
params.type = CommunicationType::ReduceScatter;
params.redOp = getC10dReduceOpType(op_type);
params.mesh = mesh;
params.team = mesh.vector();
params.scattered_axis = scattered_axis;
comms.push_back(IrBuilder::create<Communication>(params));
comms.push_back(IrBuilder::create<Communication>(CommParams{
.type = CommunicationType::ReduceScatter,
.mesh = mesh,
.team = mesh.vector(),
.redOp = getC10dReduceOpType(op_type),
.scattered_axis = scattered_axis}));
}

} // namespace
Expand Down
Loading