diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b7ffc3d0bfd..2244cfd3bd0 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -48,9 +48,7 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { namespace { -inline void assertBufferCount( - const std::vector& bufs, - size_t count) { +void assertBufferCount(const std::vector& bufs, size_t count) { NVF_ERROR( bufs.size() == count, "there must be ", @@ -60,7 +58,7 @@ inline void assertBufferCount( " were given"); } -inline void assertBuffersHaveSameSize( +void assertBuffersHaveSameSize( const std::vector& bufs1, const std::vector& bufs2) { if (bufs1.empty() && bufs2.empty()) { @@ -76,7 +74,7 @@ inline void assertBuffersHaveSameSize( } } -inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { +void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { dst.view_as(src).copy_(src, /*non_blocking=*/true); } @@ -100,7 +98,7 @@ T getInitialValue(c10d::ReduceOp::RedOpType op) { bool hasRoot(CommunicationType type) { return type == CommunicationType::Gather || - type == CommunicationType::Scatter || + type == CommunicationType::Scatter || type == CommunicationType::Reduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv; } @@ -111,36 +109,27 @@ bool isReduction(CommunicationType type) { type == CommunicationType::ReduceScatter; } -void assertValid(const CommParams& params) { - std::unordered_set team_without_duplicates( - params.team.begin(), params.team.end()); - NVF_ERROR( - team_without_duplicates.size() == params.team.size(), - "the communication must not involve the same device more than once"); - NVF_ERROR(!params.team.empty(), "the team size must be greater than 0"); - if (hasRoot(params.type)) { - auto it = std::find(params.team.begin(), params.team.end(), params.root); - NVF_ERROR( - it != params.team.end(), - "root (device ", - params.root, - ") must be present 
in the communication's team");
-  }
-}
-
 // pytorch's process group expects the root to be specified
 // as an integer between 0 and world_size-1. We choose it to be
 // the device's relative index within the team
 int64_t getRootRelativeIndex(const CommParams& params) {
-  auto it = std::find(params.team.begin(), params.team.end(), params.root);
-  return std::distance(params.team.begin(), it);
+  int64_t relative_index = params.mesh.idxOf(params.root);
+  // NOTE(review): this assumes that a root absent from the mesh is appended to
+  // the end of the team vector, an invariant enforced far away (where the
+  // team is built during lowering). TODO: consider making this function a
+  // method of Communication so relative_index can be computed from
+  // Communication::team() directly.
+  if (relative_index == -1) {
+    relative_index = params.mesh.size();
+  }
+  return relative_index;
 }
 
 } // namespace
 
 Communication::Communication(IrBuilderPasskey passkey, CommParams params)
     : Expr(passkey), params_(std::move(params)) {
-  assertValid(params_);
+  NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0.");
 }
 
 Communication::Communication(const Communication* src, IrCloner* ir_cloner)
@@ -153,10 +142,10 @@ std::string Communication::toString(const int indent_size) const {
   indent(ss, indent_size) << "Communication " << params_.type << ": {"
                           << std::endl;
-
   if (hasRoot(params_.type)) {
     indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl;
   }
+  indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl;
   indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl;
   indent(ss, indent_size) << "}";
@@ -178,9 +167,10 @@ bool Communication::sameAs(const Statement* other) const {
   const auto& p1 = this->params();
   const auto& p2 = other->as<Communication>()->params();
-  return (
-      p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) &&
-      p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp));
+  return (p1.type == p2.type &&
(!hasRoot(p1.type) || p1.root == p2.root) && + p1.mesh == p2.mesh && + (!isReduction(p1.type) || p1.redOp == p2.redOp)) && + p1.team == p2.team; } namespace { @@ -192,7 +182,7 @@ c10::intrusive_ptr postBroadcast( at::Tensor output_tensor) { const CommParams& params = communication->params(); if (my_device_index == params.root) { - if (params.is_root_in_mesh) { + if (communication->isRootInMesh()) { // Do a local copy and the subsequent broadcast will be in place. Consider // ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the // local copy to complete. @@ -203,7 +193,7 @@ c10::intrusive_ptr postBroadcast( } } - if (params.team.size() == 1) { + if (communication->team().size() == 1) { return nullptr; } @@ -219,7 +209,7 @@ c10::intrusive_ptr postGather( at::Tensor input_tensor, at::Tensor output_tensor) { const CommParams& params = communication->params(); - if (my_device_index == params.root && !params.is_root_in_mesh) { + if (my_device_index == params.root && !communication->isRootInMesh()) { // This is likely a suboptimal way to allocate tensors for nccl. 
To benefit // from zero copy // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html), @@ -235,9 +225,9 @@ c10::intrusive_ptr postGather( if (my_device_index == params.root) { output_tensors.resize(1); int64_t j = 0; - for (auto i : c10::irange(params.team.size())) { + for (auto i : c10::irange(communication->team().size())) { if (root_relative_index == static_cast(i) && - !params.is_root_in_mesh) { + !communication->isRootInMesh()) { output_tensors[0].push_back(input_tensor); continue; } @@ -245,7 +235,7 @@ c10::intrusive_ptr postGather( j++; } - assertBufferCount(output_tensors[0], params.team.size()); + assertBufferCount(output_tensors[0], communication->team().size()); assertBuffersHaveSameSize(input_tensors, output_tensors[0]); } @@ -259,13 +249,11 @@ c10::intrusive_ptr postAllgather( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - std::vector input_tensors({input_tensor}); std::vector> output_tensors(1); output_tensors[0] = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(output_tensors[0], params.team.size()); + assertBufferCount(output_tensors[0], communication->team().size()); assertBuffersHaveSameSize(input_tensors, output_tensors[0]); return backend->allgather(output_tensors, input_tensors, {}); } @@ -278,7 +266,7 @@ c10::intrusive_ptr postScatter( at::Tensor output_tensor) { const CommParams& params = communication->params(); - if (my_device_index == params.root && !params.is_root_in_mesh) { + if (my_device_index == params.root && !communication->isRootInMesh()) { output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); } std::vector output_tensors({output_tensor}); @@ -288,9 +276,9 @@ c10::intrusive_ptr postScatter( if (my_device_index == params.root) { input_tensors.resize(1); int64_t j = 0; - for (auto i : c10::irange(params.team.size())) { + for (auto i : c10::irange(communication->team().size())) { if 
(root_relative_index == static_cast(i) && - !params.is_root_in_mesh) { + !communication->isRootInMesh()) { input_tensors.front().push_back(output_tensor); continue; } @@ -298,7 +286,7 @@ c10::intrusive_ptr postScatter( j++; } - assertBufferCount(input_tensors[0], params.team.size()); + assertBufferCount(input_tensors[0], communication->team().size()); assertBuffersHaveSameSize(input_tensors[0], output_tensors); } @@ -316,7 +304,7 @@ c10::intrusive_ptr postReduce( at::Tensor tensor; if (my_device_index == params.root) { - if (params.is_root_in_mesh) { + if (communication->isRootInMesh()) { doLocalCopy(output_tensor, input_tensor); tensor = output_tensor; } else { @@ -369,7 +357,7 @@ c10::intrusive_ptr postReduceScatter( std::vector output_tensors({output_tensor}); - assertBufferCount(input_tensors[0], params.team.size()); + assertBufferCount(input_tensors[0], communication->team().size()); return backend->reduce_scatter( output_tensors, input_tensors, {.reduceOp = params.redOp}); } @@ -382,18 +370,15 @@ c10::intrusive_ptr postSendRecv( at::Tensor output_tensor) { const CommParams& params = communication->params(); - NVF_ERROR( - params.team.size() == 1 || params.team.size() == 2, - "the team size should be 1 or 2"); + NVF_ERROR(params.mesh.size() == 1, "The mesh size should be 1."); - if (params.team.size() == 1) { + if (communication->isRootInMesh()) { doLocalCopy(output_tensor, input_tensor); return nullptr; } const DeviceIdxType sender = params.root; - const DeviceIdxType receiver = - params.team.at(0) == sender ? 
params.team.at(1) : params.team.at(0); + const DeviceIdxType receiver = params.mesh.at(0); std::vector tensors; if (my_device_index == sender) { @@ -414,9 +399,9 @@ c10::intrusive_ptr postSingleCommunication( at::Tensor input_tensor, at::Tensor output_tensor) { const CommParams& params = communication->params(); + const Team& team = communication->team(); NVF_ERROR( - std::find(params.team.begin(), params.team.end(), my_device_index) != - params.team.end(), + std::find(team.begin(), team.end(), my_device_index) != team.end(), "current device index ", my_device_index, " must be present in the communication's team"); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index bf8a74f194a..7457679af58 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #ifdef NVFUSER_DISTRIBUTED #include @@ -38,10 +39,12 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type); struct CommParams { CommunicationType type; DeviceIdxType root = -1; - bool is_root_in_mesh = true; - Team team; // should not have duplicates and should contain both the root and - // the mesh + DeviceMesh mesh; // Might not contain `root`. + Team team; // All devices involved in this communication. It must include + // `root`. It can be a subset of `root`+`mesh` in case of 2D + // sharding. c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED; + // reduced_axis is always outermost. int64_t scattered_axis = -1; }; @@ -141,13 +144,20 @@ class Communication : public Expr { // the constructor that takes IrCloner aren't needed. bool sameAs(const Statement* other) const override; - // TODO: const CommParams&. 
- auto params() const { + const CommParams& params() const { return params_; } + bool isRootInMesh() const { + return params_.mesh.has(params_.root); + } + + const Team& team() const { + return params_.team; + } + private: - // store the arguments of the communication + // Stores the arguments used to construct the communication. CommParams params_; }; diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index eecc6b4873c..ba5eb9b2110 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -181,7 +181,7 @@ void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { // post and wait communications for (Communication* communication : communications) { c10::intrusive_ptr backend = - comm_.getBackendForTeam(communication->params().team, std::nullopt); + comm_.getBackendForTeam(communication->team(), std::nullopt); c10::intrusive_ptr work = postSingleCommunication( communication, comm_.deviceId(), backend, input_tensor, output_tensor); if (work != nullptr) { diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index df5f41bfd74..b0d0bee764e 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -60,6 +60,9 @@ inline bool isDeviceInvolved( // params. Since most of the steps are somewhat similar/opposite in those // cases, we gathered the two implementations into one function. The argument // "is_scatter" allows to discriminate between scatter and gather +// +// TODO: remove this helper. root_tv and is_scatter are not used and the +// definition is trivial to inline. CommParams createParamsForGatherScatter( DeviceIdxType my_device_index, DeviceIdxType root, @@ -70,9 +73,9 @@ CommParams createParamsForGatherScatter( params.type = is_scatter ? 
CommunicationType::Scatter : CommunicationType::Gather; params.root = root; + params.mesh = mesh; params.team = mesh.vector(); if (!mesh.has(root)) { - params.is_root_in_mesh = false; params.team.push_back(root); } return params; @@ -131,6 +134,7 @@ void lowerToAllgather( CommParams params; params.type = CommunicationType::Allgather; + params.mesh = mesh; params.team = mesh.vector(); comms.push_back(IrBuilder::create(std::move(params))); } @@ -142,14 +146,13 @@ CommParams createParamsForBroadcastOrP2P( // receiver devices const DeviceMesh& mesh) { CommParams params; + params.type = CommunicationType::Broadcast; params.root = root; + params.mesh = mesh; params.team = mesh.vector(); if (!mesh.has(root)) { - params.is_root_in_mesh = false; params.team.push_back(root); } - params.type = CommunicationType::Broadcast; - return params; } @@ -209,12 +212,11 @@ CommParams createParamsForReduce( params.type = CommunicationType::Reduce; params.root = root; params.redOp = getC10dReduceOpType(op_type); + params.mesh = mesh; params.team = mesh.vector(); if (!mesh.has(root)) { - params.is_root_in_mesh = false; params.team.push_back(root); } - // FIXME: we may want to store sharded_dim to params for speed. 
return params; } @@ -251,6 +253,7 @@ void lowerToAllreduce( CommParams params; params.type = CommunicationType::Allreduce; params.redOp = getC10dReduceOpType(op_type); + params.mesh = mesh; params.team = mesh.vector(); comms.push_back(IrBuilder::create(params)); } @@ -266,10 +269,6 @@ void lowerToReduceScatter( return; } - CommParams params; - params.type = CommunicationType::ReduceScatter; - params.redOp = getC10dReduceOpType(op_type); - params.team = mesh.vector(); auto reduction_axis = output_tv->getReductionAxis().value(); auto scattered_axis = getShardedAxis(output_tv); // The output tensor is sharded on scattered_axis and needs to be mapped @@ -279,8 +278,13 @@ void lowerToReduceScatter( if (reduction_axis <= scattered_axis) { scattered_axis++; } - params.scattered_axis = scattered_axis; + CommParams params; + params.type = CommunicationType::ReduceScatter; + params.redOp = getC10dReduceOpType(op_type); + params.mesh = mesh; + params.team = mesh.vector(); + params.scattered_axis = scattered_axis; comms.push_back(IrBuilder::create(params)); } diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index eaa0b3a3e93..05d9a1e7f06 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -36,16 +36,16 @@ class CommunicationTest static constexpr c10d::ReduceOp::RedOpType red_op = c10d::ReduceOp::RedOpType::SUM; CommParams params; - std::vector all_ranks; + const DeviceMesh full_mesh; + const Team all_ranks; c10::intrusive_ptr backend; IrContainer container; }; -CommunicationTest::CommunicationTest() { - all_ranks = std::vector(communicator->size()); - std::iota(all_ranks.begin(), all_ranks.end(), 0); - backend = communicator->getBackendForTeam(all_ranks, GetParam()); -} +CommunicationTest::CommunicationTest() + : full_mesh(DeviceMesh::createForNumDevices(communicator->size())), + all_ranks(full_mesh.vector()), + 
backend(communicator->getBackendForTeam(all_ranks, GetParam())) {} void CommunicationTest::SetUp() { MultiDeviceTest::SetUp(); @@ -65,6 +65,7 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { TEST_P(CommunicationTest, Gather) { params.type = CommunicationType::Gather; params.root = root; + params.mesh = full_mesh; params.team = all_ranks; auto communication = IrBuilder::create(&container, params); @@ -94,6 +95,7 @@ TEST_P(CommunicationTest, Gather) { TEST_P(CommunicationTest, Allgather) { params.type = CommunicationType::Allgather; + params.mesh = full_mesh; params.team = all_ranks; auto communication = IrBuilder::create(&container, params); @@ -123,6 +125,7 @@ TEST_P(CommunicationTest, Allgather) { TEST_P(CommunicationTest, Scatter) { params.type = CommunicationType::Scatter; params.root = root; + params.mesh = full_mesh; params.team = all_ranks; auto communication = IrBuilder::create(&container, params); @@ -158,6 +161,7 @@ TEST_P(CommunicationTest, Scatter) { TEST_P(CommunicationTest, Broadcast) { params.type = CommunicationType::Broadcast; params.root = root; + params.mesh = full_mesh; params.team = all_ranks; auto communication = IrBuilder::create(&container, params); @@ -203,7 +207,8 @@ TEST_P(CommunicationTest, SendRecv) { params.type = CommunicationType::SendRecv; params.root = sender; - params.team = {0, 1}; + params.mesh = {receiver}; + params.team = {sender, receiver}; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -244,7 +249,8 @@ TEST_P(CommunicationTest, SendRecvToSelf) { params.type = CommunicationType::SendRecv; params.root = sender; - params.team = {0}; + params.mesh = {sender}; + params.team = {sender}; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({tensor_size}, tensor_options); @@ -269,6 +275,7 @@ TEST_P(CommunicationTest, Reduce) { params.type = CommunicationType::Reduce; params.redOp = red_op; params.root = root; + 
params.mesh = full_mesh; params.team = all_ranks; auto communication = IrBuilder::create(&container, params); @@ -300,6 +307,7 @@ TEST_P(CommunicationTest, Reduce) { TEST_P(CommunicationTest, Allreduce) { params.type = CommunicationType::Allreduce; params.redOp = red_op; + params.mesh = full_mesh; params.team = all_ranks; auto communication = IrBuilder::create(&container, params); @@ -329,6 +337,7 @@ TEST_P(CommunicationTest, ReduceScatter) { params.type = CommunicationType::ReduceScatter; params.redOp = red_op; params.root = root; + params.mesh = full_mesh; params.team = all_ranks; params.scattered_axis = 1; auto communication = IrBuilder::create(&container, params); diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index ca3c098cad9..22132977d50 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -77,7 +77,7 @@ TEST_P(MultiDeviceHostIrTest, SingleFusionSingleComm) { CommParams comm_params{ .type = CommunicationType::Allgather, .root = 0, - .is_root_in_mesh = true, + .mesh = mesh, .team = mesh.vector()}; auto communication = IrBuilder::create( static_cast(hic.get()), comm_params);