diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 2244cfd3bd0..6306e1d65d8 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -130,6 +130,12 @@ int64_t getRootRelativeIndex(const CommParams& params) { Communication::Communication(IrBuilderPasskey passkey, CommParams params) : Expr(passkey), params_(std::move(params)) { NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0."); + NVF_ERROR( + hasRoot(params_.type) == (params_.root >= 0), + "Root ", + params_.root, + " is not expected by CommunicationType ", + params_.type); } Communication::Communication(const Communication* src, IrCloner* ir_cloner) diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index b0d0bee764e..680f30bfb69 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -56,31 +56,6 @@ inline bool isDeviceInvolved( return sender_mesh.has(my_device_index) || receiver_mesh.has(my_device_index); } -// Utility function used for setting up a scatter or gather communication -// params. Since most of the steps are somewhat similar/opposite in those -// cases, we gathered the two implementations into one function. The argument -// "is_scatter" allows to discriminate between scatter and gather -// -// TODO: remove this helper. root_tv and is_scatter are not used and the -// definition is trivial to inline. -CommParams createParamsForGatherScatter( - DeviceIdxType my_device_index, - DeviceIdxType root, - TensorView* root_tv, // is_scatter ? input_tv : output_tv - bool is_scatter) { - const DeviceMesh& mesh = root_tv->getDeviceMesh(); - CommParams params; - params.type = - is_scatter ? CommunicationType::Scatter : CommunicationType::Gather; - params.root = root; - params.mesh = mesh; - params.team = mesh.vector(); - if (!mesh.has(root)) { - params.team.push_back(root); - } - return params; -} - // Adds one or zero Scatter communication to the vector 'comms' void lowerToScatter( DeviceIdxType my_device_index, @@ -93,9 +68,15 @@ void lowerToScatter( if (!isDeviceInvolved(my_device_index, root, receiver_mesh)) { return; } - auto params = - createParamsForGatherScatter(my_device_index, root, output_tv, true); - comms.push_back(IrBuilder::create(std::move(params))); + Team team = receiver_mesh.vector(); + if (!receiver_mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::Scatter, + .root = root, + .mesh = receiver_mesh, + .team = team})); } /* @@ -115,9 +96,15 @@ void lowerToGather( if (!isDeviceInvolved(my_device_index, root, sender_mesh)) { continue; } - auto params = - createParamsForGatherScatter(my_device_index, root, input_tv, false); - comms.push_back(IrBuilder::create(std::move(params))); + Team team = sender_mesh.vector(); + if (!sender_mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::Gather, + .root = root, + .mesh = sender_mesh, + .team = team})); } } @@ -132,28 +119,10 @@ void lowerToAllgather( return; } - CommParams params; - params.type = CommunicationType::Allgather; - params.mesh = mesh; - params.team = mesh.vector(); - comms.push_back(IrBuilder::create(std::move(params))); -} - -// Creates and set the CommParams for a Broadcast or Send/Recv communication -CommParams createParamsForBroadcastOrP2P( - DeviceIdxType my_device_index, - DeviceIdxType root, - // receiver devices - const DeviceMesh& mesh) { - CommParams params; - params.type = CommunicationType::Broadcast; - params.root = root; - params.mesh = mesh; - params.team = mesh.vector(); - if (!mesh.has(root)) { - params.team.push_back(root); - } - return params; + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::Allgather, + .mesh = mesh, + .team = mesh.vector()})); } // Adds one or zero Broadcast or Send/Recv communication to the vector 'comms' @@ -165,8 +134,15 @@ void lowerToBroadcastOrP2P( if (!isDeviceInvolved(my_device_index, root, mesh)) { return; } - auto params = createParamsForBroadcastOrP2P(my_device_index, root, mesh); - comms.push_back(IrBuilder::create(std::move(params))); + Team team = mesh.vector(); + if (!mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::Broadcast, + .root = root, + .mesh = mesh, + .team = team})); } // Adds several Broadcast or Send/Recv communications to the vector 'comms' @@ -201,25 +177,6 @@ void lowerToBroadcastOrP2P( } } -CommParams createParamsForReduce( - DeviceIdxType my_device_index, - DeviceIdxType root, - TensorView* input_tv, - TensorView* output_tv, - BinaryOpType op_type) { - const DeviceMesh& mesh = input_tv->getDeviceMesh(); - CommParams params; - params.type = CommunicationType::Reduce; - params.root = root; - params.redOp = getC10dReduceOpType(op_type); - params.mesh = mesh; - params.team = mesh.vector(); - if (!mesh.has(root)) { - params.team.push_back(root); - } - return params; -} - void lowerToReduce( DeviceIdxType my_device_index, TensorView* input_tv, @@ -228,14 +185,22 @@ void lowerToReduce( std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const auto reduce_op_type = getC10dReduceOpType(op_type); // we create as many Reduces as there are devices in the receiver mesh for (auto root : receiver_mesh.vector()) { if (!isDeviceInvolved(my_device_index, root, sender_mesh)) { continue; } - auto params = createParamsForReduce( - my_device_index, root, input_tv, output_tv, op_type); - comms.push_back(IrBuilder::create(std::move(params))); + Team team = sender_mesh.vector(); + if (!sender_mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::Reduce, + .root = root, + .mesh = sender_mesh, + .team = team, + .redOp = reduce_op_type})); } } @@ -250,12 +215,11 @@ void lowerToAllreduce( return; } - CommParams params; - params.type = CommunicationType::Allreduce; - params.redOp = getC10dReduceOpType(op_type); - params.mesh = mesh; - params.team = mesh.vector(); - comms.push_back(IrBuilder::create(params)); + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::Allreduce, + .mesh = mesh, + .team = mesh.vector(), + .redOp = getC10dReduceOpType(op_type)})); } void lowerToReduceScatter( @@ -279,13 +243,12 @@ void lowerToReduceScatter( scattered_axis++; } - CommParams params; - params.type = CommunicationType::ReduceScatter; - params.redOp = getC10dReduceOpType(op_type); - params.mesh = mesh; - params.team = mesh.vector(); - params.scattered_axis = scattered_axis; - comms.push_back(IrBuilder::create(params)); + comms.push_back(IrBuilder::create(CommParams{ + .type = CommunicationType::ReduceScatter, + .mesh = mesh, + .team = mesh.vector(), + .redOp = getC10dReduceOpType(op_type), + .scattered_axis = scattered_axis})); } } // namespace diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 05d9a1e7f06..e59079e93fa 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -35,7 +35,6 @@ class CommunicationTest // TODO: test other reduction op types. static constexpr c10d::ReduceOp::RedOpType red_op = c10d::ReduceOp::RedOpType::SUM; - CommParams params; const DeviceMesh full_mesh; const Team all_ranks; c10::intrusive_ptr backend; @@ -63,11 +62,13 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { } TEST_P(CommunicationTest, Gather) { - params.type = CommunicationType::Gather; - params.root = root; - params.mesh = full_mesh; - params.team = all_ranks; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::Gather, + .root = root, + .mesh = full_mesh, + .team = all_ranks}); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = @@ -94,10 +95,12 @@ TEST_P(CommunicationTest, Gather) { } TEST_P(CommunicationTest, Allgather) { - params.type = CommunicationType::Allgather; - params.mesh = full_mesh; - params.team = all_ranks; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::Allgather, + .mesh = full_mesh, + .team = all_ranks}); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = @@ -123,11 +126,13 @@ TEST_P(CommunicationTest, Allgather) { } TEST_P(CommunicationTest, Scatter) { - params.type = CommunicationType::Scatter; - params.root = root; - params.mesh = full_mesh; - params.team = all_ranks; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::Scatter, + .root = root, + .mesh = full_mesh, + .team = all_ranks}); at::Tensor input_tensor; if (communicator->deviceId() == root) { @@ -159,11 +164,13 @@ TEST_P(CommunicationTest, Scatter) { } TEST_P(CommunicationTest, Broadcast) { - params.type = CommunicationType::Broadcast; - params.root = root; - params.mesh = full_mesh; - params.team = all_ranks; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::Broadcast, + .root = root, + .mesh = full_mesh, + .team = all_ranks}); at::Tensor input_tensor; if (communicator->deviceId() == root) { @@ -205,11 +212,13 @@ TEST_P(CommunicationTest, SendRecv) { return; } - params.type = CommunicationType::SendRecv; - params.root = sender; - params.mesh = {receiver}; - params.team = {sender, receiver}; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::SendRecv, + .root = sender, + .mesh = {receiver}, + .team = {sender, receiver}}); at::Tensor input_tensor; at::Tensor output_tensor; @@ -247,11 +256,13 @@ TEST_P(CommunicationTest, SendRecvToSelf) { return; } - params.type = CommunicationType::SendRecv; - params.root = sender; - params.mesh = {sender}; - params.team = {sender}; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::SendRecv, + .root = sender, + .mesh = {sender}, + .team = {sender}}); at::Tensor input_tensor = at::empty({tensor_size}, tensor_options); at::Tensor output_tensor = at::empty_like(input_tensor); @@ -272,12 +283,14 @@ TEST_P(CommunicationTest, SendRecvToSelf) { } TEST_P(CommunicationTest, Reduce) { - params.type = CommunicationType::Reduce; - params.redOp = red_op; - params.root = root; - params.mesh = full_mesh; - params.team = all_ranks; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::Reduce, + .root = root, + .mesh = full_mesh, + .team = all_ranks, + .redOp = red_op}); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); @@ -305,11 +318,13 @@ TEST_P(CommunicationTest, Reduce) { } TEST_P(CommunicationTest, Allreduce) { - params.type = CommunicationType::Allreduce; - params.redOp = red_op; - params.mesh = full_mesh; - params.team = all_ranks; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::Allreduce, + .mesh = full_mesh, + .team = all_ranks, + .redOp = red_op}); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); @@ -334,13 +349,14 @@ TEST_P(CommunicationTest, Allreduce) { } TEST_P(CommunicationTest, ReduceScatter) { - params.type = CommunicationType::ReduceScatter; - params.redOp = red_op; - params.root = root; - params.mesh = full_mesh; - params.team = all_ranks; - params.scattered_axis = 1; - auto communication = IrBuilder::create(&container, params); + auto communication = IrBuilder::create( + &container, + CommParams{ + .type = CommunicationType::ReduceScatter, + .mesh = full_mesh, + .team = all_ranks, + .redOp = red_op, + .scattered_axis = 1}); const int num_devices = communicator->size(); const int device_id = communicator->deviceId(); diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 22132977d50..819d57cfba9 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -76,7 +76,6 @@ TEST_P(MultiDeviceHostIrTest, SingleFusionSingleComm) { // [Step 3b)] Create a Communication Ir CommParams comm_params{ .type = CommunicationType::Allgather, - .root = 0, .mesh = mesh, .team = mesh.vector()}; auto communication = IrBuilder::create(