From d8a2bfc46389793e690d78485a65e43c2e95a242 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 13 May 2024 23:07:52 +0000 Subject: [PATCH 1/7] const ref --- csrc/multidevice/communication.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index bf8a74f194a..83af17a4b1c 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -141,8 +141,7 @@ class Communication : public Expr { // the constructor that takes IrCloner aren't needed. bool sameAs(const Statement* other) const override; - // TODO: const CommParams&. - auto params() const { + const CommParams& params() const { return params_; } From 681e7623b2c21b44d18b7ff6102c353cd3e0a261 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 15 May 2024 00:04:55 +0000 Subject: [PATCH 2/7] Attempt to remove some fields from CommsParams. --- csrc/multidevice/communication.cpp | 89 +++++++++---------- csrc/multidevice/communication.h | 14 ++- csrc/multidevice/executor.cpp | 2 +- csrc/multidevice/lower_communication.cpp | 31 +++---- tests/cpp/test_multidevice_communications.cpp | 27 +++--- 5 files changed, 76 insertions(+), 87 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b7ffc3d0bfd..9cef4ab01c0 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -100,7 +100,7 @@ T getInitialValue(c10d::ReduceOp::RedOpType op) { bool hasRoot(CommunicationType type) { return type == CommunicationType::Gather || - type == CommunicationType::Scatter || + type == CommunicationType::Scatter || type == CommunicationType::Reduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv; } @@ -111,36 +111,27 @@ bool isReduction(CommunicationType type) { type == CommunicationType::ReduceScatter; } -void assertValid(const CommParams& params) { - std::unordered_set team_without_duplicates( - params.team.begin(), params.team.end()); - NVF_ERROR( - team_without_duplicates.size() == params.team.size(), - "the communication must not involve the same device more than once"); - NVF_ERROR(!params.team.empty(), "the team size must be greater than 0"); - if (hasRoot(params.type)) { - auto it = std::find(params.team.begin(), params.team.end(), params.root); - NVF_ERROR( - it != params.team.end(), - "root (device ", - params.root, - ") must be present in the communication's team"); - } -} - // pytorch's process group expects the root to be specified // as an integer between 0 and world_size-1. We choose it to be // the device's relative index within the team int64_t getRootRelativeIndex(const CommParams& params) { - auto it = std::find(params.team.begin(), params.team.end(), params.root); - return std::distance(params.team.begin(), it); + int64_t relative_index = params.mesh.idxOf(params.root); + // This assumes the root, if not in mesh, is added to the end of the team + // vector. Given the assumption is distantly enforced at + // Communication::team(), I slightly prefer making this function a method of + // Communication so relative_index can be computed from Communication::team() + // directly. Wdyt? 
+ if (relative_index == -1) { + relative_index = params.mesh.size(); + } + return relative_index; } } // namespace Communication::Communication(IrBuilderPasskey passkey, CommParams params) : Expr(passkey), params_(std::move(params)) { - assertValid(params_); + NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0."); } Communication::Communication(const Communication* src, IrCloner* ir_cloner) @@ -148,16 +139,25 @@ Communication::Communication(const Communication* src, IrCloner* ir_cloner) NVFUSER_DEFINE_CLONE_AND_CREATE(Communication) +const Team& Communication::team() { + if (team_.empty()) { + team_ = params_.mesh.vector(); + if (hasRoot(params_.type) && !isRootInMesh()) { + team_.push_back(params_.root); + } + } + return team_; +} + std::string Communication::toString(const int indent_size) const { std::stringstream ss; indent(ss, indent_size) << "Communication " << params_.type << ": {" << std::endl; - if (hasRoot(params_.type)) { indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl; } - indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl; + indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl; indent(ss, indent_size) << "}"; return ss.str(); @@ -180,7 +180,7 @@ bool Communication::sameAs(const Statement* other) const { return ( p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) && - p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp)); + p1.mesh == p2.mesh && (!isReduction(p1.type) || p1.redOp == p2.redOp)); } namespace { @@ -192,7 +192,7 @@ c10::intrusive_ptr postBroadcast( at::Tensor output_tensor) { const CommParams& params = communication->params(); if (my_device_index == params.root) { - if (params.is_root_in_mesh) { + if (communication->isRootInMesh()) { // Do a local copy and the subsequent broadcast will be in place. Consider // ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the // local copy to complete. @@ -203,7 +203,7 @@ c10::intrusive_ptr postBroadcast( } } - if (params.team.size() == 1) { + if (communication->team().size() == 1) { return nullptr; } @@ -219,7 +219,7 @@ c10::intrusive_ptr postGather( at::Tensor input_tensor, at::Tensor output_tensor) { const CommParams& params = communication->params(); - if (my_device_index == params.root && !params.is_root_in_mesh) { + if (my_device_index == params.root && !communication->isRootInMesh()) { // This is likely a suboptimal way to allocate tensors for nccl. 
To benefit // from zero copy // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html), @@ -235,9 +235,9 @@ c10::intrusive_ptr postGather( if (my_device_index == params.root) { output_tensors.resize(1); int64_t j = 0; - for (auto i : c10::irange(params.team.size())) { + for (auto i : c10::irange(communication->team().size())) { if (root_relative_index == static_cast(i) && - !params.is_root_in_mesh) { + !communication->isRootInMesh()) { output_tensors[0].push_back(input_tensor); continue; } @@ -245,7 +245,7 @@ c10::intrusive_ptr postGather( j++; } - assertBufferCount(output_tensors[0], params.team.size()); + assertBufferCount(output_tensors[0], communication->team().size()); assertBuffersHaveSameSize(input_tensors, output_tensors[0]); } @@ -259,13 +259,11 @@ c10::intrusive_ptr postAllgather( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - std::vector input_tensors({input_tensor}); std::vector> output_tensors(1); output_tensors[0] = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(output_tensors[0], params.team.size()); + assertBufferCount(output_tensors[0], communication->team().size()); assertBuffersHaveSameSize(input_tensors, output_tensors[0]); return backend->allgather(output_tensors, input_tensors, {}); } @@ -278,7 +276,7 @@ c10::intrusive_ptr postScatter( at::Tensor output_tensor) { const CommParams& params = communication->params(); - if (my_device_index == params.root && !params.is_root_in_mesh) { + if (my_device_index == params.root && !communication->isRootInMesh()) { output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); } std::vector output_tensors({output_tensor}); @@ -288,9 +286,9 @@ c10::intrusive_ptr postScatter( if (my_device_index == params.root) { input_tensors.resize(1); int64_t j = 0; - for (auto i : c10::irange(params.team.size())) { + for (auto i : c10::irange(communication->team().size())) { if (root_relative_index == static_cast(i) && - !params.is_root_in_mesh) { + !communication->isRootInMesh()) { input_tensors.front().push_back(output_tensor); continue; } @@ -298,7 +296,7 @@ c10::intrusive_ptr postScatter( j++; } - assertBufferCount(input_tensors[0], params.team.size()); + assertBufferCount(input_tensors[0], communication->team().size()); assertBuffersHaveSameSize(input_tensors[0], output_tensors); } @@ -316,7 +314,7 @@ c10::intrusive_ptr postReduce( at::Tensor tensor; if (my_device_index == params.root) { - if (params.is_root_in_mesh) { + if (communication->isRootInMesh()) { doLocalCopy(output_tensor, input_tensor); tensor = output_tensor; } else { @@ -369,7 +367,7 @@ c10::intrusive_ptr postReduceScatter( std::vector output_tensors({output_tensor}); - assertBufferCount(input_tensors[0], params.team.size()); + assertBufferCount(input_tensors[0], communication->team().size()); return backend->reduce_scatter( output_tensors, input_tensors, {.reduceOp = params.redOp}); } @@ -382,18 +380,15 @@ c10::intrusive_ptr postSendRecv( at::Tensor output_tensor) { const CommParams& params = communication->params(); - NVF_ERROR( - params.team.size() == 1 || params.team.size() == 2, - "the team size should be 1 or 2"); + NVF_ERROR(params.mesh.size() == 1, "The mesh size should be 1."); - if (params.team.size() == 1) { + if (communication->isRootInMesh()) { doLocalCopy(output_tensor, input_tensor); return nullptr; } const DeviceIdxType sender = params.root; - const DeviceIdxType receiver = - params.team.at(0) == sender ? 
params.team.at(1) : params.team.at(0); + const DeviceIdxType receiver = params.mesh.at(0); std::vector tensors; if (my_device_index == sender) { @@ -414,9 +409,9 @@ c10::intrusive_ptr postSingleCommunication( at::Tensor input_tensor, at::Tensor output_tensor) { const CommParams& params = communication->params(); + const Team& team = communication->team(); NVF_ERROR( - std::find(params.team.begin(), params.team.end(), my_device_index) != - params.team.end(), + std::find(team.begin(), team.end(), my_device_index) != team.end(), "current device index ", my_device_index, " must be present in the communication's team"); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 83af17a4b1c..e935953f8ea 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -38,10 +38,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type); struct CommParams { CommunicationType type; DeviceIdxType root = -1; - bool is_root_in_mesh = true; - Team team; // should not have duplicates and should contain both the root and - // the mesh + DeviceMesh mesh; // Might not contain `root`. c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED; + // reduced_axis is always outermost. int64_t scattered_axis = -1; }; @@ -145,9 +144,16 @@ class Communication : public Expr { return params_; } + bool isRootInMesh() const { + return params_.mesh.has(params_.root); + } + + const Team& team(); + private: - // store the arguments of the communication + // Stores the arguments used to construct the communication. CommParams params_; + Team team_; }; // Triggers the execution of the communication. This is a non-blocking call. diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index eecc6b4873c..ba5eb9b2110 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -181,7 +181,7 @@ void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { // post and wait communications for (Communication* communication : communications) { c10::intrusive_ptr backend = - comm_.getBackendForTeam(communication->params().team, std::nullopt); + comm_.getBackendForTeam(communication->team(), std::nullopt); c10::intrusive_ptr work = postSingleCommunication( communication, comm_.deviceId(), backend, input_tensor, output_tensor); if (work != nullptr) { diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index df5f41bfd74..9fd9a5d89f2 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -60,6 +60,9 @@ inline bool isDeviceInvolved( // params. Since most of the steps are somewhat similar/opposite in those // cases, we gathered the two implementations into one function. The argument // "is_scatter" allows to discriminate between scatter and gather +// +// TODO: remove this helper. root_tv and is_scatter are not used and the +// definition is trivial to inline. CommParams createParamsForGatherScatter( DeviceIdxType my_device_index, DeviceIdxType root, @@ -70,11 +73,7 @@ CommParams createParamsForGatherScatter( params.type = is_scatter ? 
CommunicationType::Scatter : CommunicationType::Gather; params.root = root; - params.team = mesh.vector(); - if (!mesh.has(root)) { - params.is_root_in_mesh = false; - params.team.push_back(root); - } + params.mesh = mesh; return params; } @@ -131,7 +130,7 @@ void lowerToAllgather( CommParams params; params.type = CommunicationType::Allgather; - params.team = mesh.vector(); + params.mesh = mesh; comms.push_back(IrBuilder::create(std::move(params))); } @@ -142,14 +141,9 @@ CommParams createParamsForBroadcastOrP2P( // receiver devices const DeviceMesh& mesh) { CommParams params; - params.root = root; - params.team = mesh.vector(); - if (!mesh.has(root)) { - params.is_root_in_mesh = false; - params.team.push_back(root); - } params.type = CommunicationType::Broadcast; - + params.root = root; + params.mesh = mesh; return params; } @@ -209,12 +203,7 @@ CommParams createParamsForReduce( params.type = CommunicationType::Reduce; params.root = root; params.redOp = getC10dReduceOpType(op_type); - params.team = mesh.vector(); - if (!mesh.has(root)) { - params.is_root_in_mesh = false; - params.team.push_back(root); - } - // FIXME: we may want to store sharded_dim to params for speed. + params.mesh = mesh; return params; } @@ -251,7 +240,7 @@ void lowerToAllreduce( CommParams params; params.type = CommunicationType::Allreduce; params.redOp = getC10dReduceOpType(op_type); - params.team = mesh.vector(); + params.mesh = mesh; comms.push_back(IrBuilder::create(params)); } @@ -269,7 +258,7 @@ void lowerToReduceScatter( CommParams params; params.type = CommunicationType::ReduceScatter; params.redOp = getC10dReduceOpType(op_type); - params.team = mesh.vector(); + params.mesh = mesh; auto reduction_axis = output_tv->getReductionAxis().value(); auto scattered_axis = getShardedAxis(output_tv); // The output tensor is sharded on scattered_axis and needs to be mapped diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index eaa0b3a3e93..b19bf39915e 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -36,15 +36,14 @@ class CommunicationTest static constexpr c10d::ReduceOp::RedOpType red_op = c10d::ReduceOp::RedOpType::SUM; CommParams params; - std::vector all_ranks; + const DeviceMesh all_ranks; c10::intrusive_ptr backend; IrContainer container; }; -CommunicationTest::CommunicationTest() { - all_ranks = std::vector(communicator->size()); - std::iota(all_ranks.begin(), all_ranks.end(), 0); - backend = communicator->getBackendForTeam(all_ranks, GetParam()); +CommunicationTest::CommunicationTest() + : all_ranks(DeviceMesh::createForNumDevices(communicator->size())), + backend(communicator->getBackendForTeam(all_ranks.vector(), GetParam())) { } void CommunicationTest::SetUp() { @@ -65,7 +64,7 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { TEST_P(CommunicationTest, Gather) { params.type = CommunicationType::Gather; params.root = root; - params.team = all_ranks; + params.mesh = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -94,7 +93,7 @@ TEST_P(CommunicationTest, Gather) { TEST_P(CommunicationTest, Allgather) { params.type = CommunicationType::Allgather; - params.team = all_ranks; + params.mesh = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -123,7 +122,7 @@ 
TEST_P(CommunicationTest, Allgather) { TEST_P(CommunicationTest, Scatter) { params.type = CommunicationType::Scatter; params.root = root; - params.team = all_ranks; + params.mesh = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -158,7 +157,7 @@ TEST_P(CommunicationTest, Scatter) { TEST_P(CommunicationTest, Broadcast) { params.type = CommunicationType::Broadcast; params.root = root; - params.team = all_ranks; + params.mesh = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -203,7 +202,7 @@ TEST_P(CommunicationTest, SendRecv) { params.type = CommunicationType::SendRecv; params.root = sender; - params.team = {0, 1}; + params.mesh = {receiver}; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -244,7 +243,7 @@ TEST_P(CommunicationTest, SendRecvToSelf) { params.type = CommunicationType::SendRecv; params.root = sender; - params.team = {0}; + params.mesh = {sender}; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({tensor_size}, tensor_options); @@ -269,7 +268,7 @@ TEST_P(CommunicationTest, Reduce) { params.type = CommunicationType::Reduce; params.redOp = red_op; params.root = root; - params.team = all_ranks; + params.mesh = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -300,7 +299,7 @@ TEST_P(CommunicationTest, Reduce) { TEST_P(CommunicationTest, Allreduce) { params.type = CommunicationType::Allreduce; params.redOp = red_op; - params.team = all_ranks; + params.mesh = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -329,7 +328,7 @@ TEST_P(CommunicationTest, ReduceScatter) { params.type = CommunicationType::ReduceScatter; params.redOp = red_op; params.root = root; - params.team = all_ranks; + params.mesh = all_ranks; params.scattered_axis = 1; auto communication = IrBuilder::create(&container, params); From a1fb079baf3b5b92930bbeda51a6de0f281e1f7d Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 15 May 2024 19:37:35 +0000 Subject: [PATCH 3/7] Remove inlines. --- csrc/multidevice/communication.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 9cef4ab01c0..8d432148391 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -48,9 +48,7 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { namespace { -inline void assertBufferCount( - const std::vector& bufs, - size_t count) { +void assertBufferCount(const std::vector& bufs, size_t count) { NVF_ERROR( bufs.size() == count, "there must be ", @@ -60,7 +58,7 @@ inline void assertBufferCount( " were given"); } -inline void assertBuffersHaveSameSize( +void assertBuffersHaveSameSize( const std::vector& bufs1, const std::vector& bufs2) { if (bufs1.empty() && bufs2.empty()) { @@ -76,7 +74,7 @@ inline void assertBuffersHaveSameSize( } } -inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { +void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { dst.view_as(src).copy_(src, /*non_blocking=*/true); } From a1770bcff3080d59153bd5db7855408b238df971 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 15 May 2024 19:48:49 +0000 Subject: [PATCH 4/7] More comments. 
--- csrc/multidevice/communication.cpp | 4 ++++ csrc/multidevice/communication.h | 3 +++ 2 files changed, 7 insertions(+) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 8d432148391..8cf9ad681fb 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -139,6 +139,10 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(Communication) const Team& Communication::team() { if (team_.empty()) { + // I could instead compute `team_` in the constructor. But I chose to + // compute and cache it in team() so I can remove the customized cloning + // constructor (the one that takes IrCloner*) in favor of `Expr::Expr(const + // Expr*, IrCloner*)` in an upcoming PR. team_ = params_.mesh.vector(); if (hasRoot(params_.type) && !isRootInMesh()) { team_.push_back(params_.root); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index e935953f8ea..a9a3953ac38 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -153,6 +153,9 @@ class Communication : public Expr { private: // Stores the arguments used to construct the communication. CommParams params_; + // This can be computed from params_, but given how frequently this is used + // in the hot path, I'm currently storing it as a field that'll be computed by + // Communication::team(). Team team_; }; From 1f0bac08fa11f02acf2e569a30cb8dfb234760a4 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 31 May 2024 18:30:58 +0000 Subject: [PATCH 5/7] Add team back. --- csrc/multidevice/communication.cpp | 15 +-------- csrc/multidevice/communication.h | 11 ++++--- csrc/multidevice/lower_communication.cpp | 25 ++++++++++++--- tests/cpp/test_multidevice_communications.cpp | 32 ++++++++++++------- 4 files changed, 48 insertions(+), 35 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 8cf9ad681fb..ad5efa5b4e3 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -137,20 +137,6 @@ Communication::Communication(const Communication* src, IrCloner* ir_cloner) NVFUSER_DEFINE_CLONE_AND_CREATE(Communication) -const Team& Communication::team() { - if (team_.empty()) { - // I could instead compute `team_` in the constructor. But I chose to - // compute and cache it in team() so I can remove the customized cloning - // constructor (the one that takes IrCloner*) in favor of `Expr::Expr(const - // Expr*, IrCloner*)` in an upcoming PR. - team_ = params_.mesh.vector(); - if (hasRoot(params_.type) && !isRootInMesh()) { - team_.push_back(params_.root); - } - } - return team_; -} - std::string Communication::toString(const int indent_size) const { std::stringstream ss; @@ -160,6 +146,7 @@ std::string Communication::toString(const int indent_size) const { indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl; } indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl; + indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl; indent(ss, indent_size) << "}"; return ss.str(); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index a9a3953ac38..a817651eae0 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -39,6 +39,9 @@ struct CommParams { CommunicationType type; DeviceIdxType root = -1; DeviceMesh mesh; // Might not contain `root`. + Team team; // All devices involved in this communication. It must include + // `root`. 
It can be a subset of `root`+`mesh` in case of 2D + // sharding. c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED; // reduced_axis is always outermost. int64_t scattered_axis = -1; @@ -148,15 +151,13 @@ class Communication : public Expr { return params_.mesh.has(params_.root); } - const Team& team(); + const Team& team() const { + return params_.team; + } private: // Stores the arguments used to construct the communication. CommParams params_; - // This can be computed from params_, but given how frequently this is used - // in the hot path, I'm currently storing it as a field that'll be computed by - // Communication::team(). - Team team_; }; // Triggers the execution of the communication. This is a non-blocking call. diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 9fd9a5d89f2..b0d0bee764e 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -74,6 +74,10 @@ CommParams createParamsForGatherScatter( is_scatter ? CommunicationType::Scatter : CommunicationType::Gather; params.root = root; params.mesh = mesh; + params.team = mesh.vector(); + if (!mesh.has(root)) { + params.team.push_back(root); + } return params; } @@ -131,6 +135,7 @@ void lowerToAllgather( CommParams params; params.type = CommunicationType::Allgather; params.mesh = mesh; + params.team = mesh.vector(); comms.push_back(IrBuilder::create(std::move(params))); } @@ -144,6 +149,10 @@ CommParams createParamsForBroadcastOrP2P( params.type = CommunicationType::Broadcast; params.root = root; params.mesh = mesh; + params.team = mesh.vector(); + if (!mesh.has(root)) { + params.team.push_back(root); + } return params; } @@ -204,6 +213,10 @@ CommParams createParamsForReduce( params.root = root; params.redOp = getC10dReduceOpType(op_type); params.mesh = mesh; + params.team = mesh.vector(); + if (!mesh.has(root)) { + params.team.push_back(root); + } return params; } @@ -241,6 +254,7 @@ void lowerToAllreduce( params.type = CommunicationType::Allreduce; params.redOp = getC10dReduceOpType(op_type); params.mesh = mesh; + params.team = mesh.vector(); comms.push_back(IrBuilder::create(params)); } @@ -255,10 +269,6 @@ void lowerToReduceScatter( return; } - CommParams params; - params.type = CommunicationType::ReduceScatter; - params.redOp = getC10dReduceOpType(op_type); - params.mesh = mesh; auto reduction_axis = output_tv->getReductionAxis().value(); auto scattered_axis = getShardedAxis(output_tv); // The output tensor is sharded on scattered_axis and needs to be mapped @@ -268,8 +278,13 @@ void lowerToReduceScatter( if (reduction_axis <= scattered_axis) { scattered_axis++; } - params.scattered_axis = scattered_axis; + CommParams params; + params.type = CommunicationType::ReduceScatter; + params.redOp = getC10dReduceOpType(op_type); + params.mesh = mesh; + params.team = mesh.vector(); + params.scattered_axis = scattered_axis; comms.push_back(IrBuilder::create(params)); } diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index b19bf39915e..05d9a1e7f06 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -36,15 +36,16 @@ class CommunicationTest static constexpr c10d::ReduceOp::RedOpType red_op = c10d::ReduceOp::RedOpType::SUM; CommParams params; - const DeviceMesh all_ranks; + const DeviceMesh full_mesh; + const Team all_ranks; c10::intrusive_ptr backend; IrContainer container; }; 
CommunicationTest::CommunicationTest() - : all_ranks(DeviceMesh::createForNumDevices(communicator->size())), - backend(communicator->getBackendForTeam(all_ranks.vector(), GetParam())) { -} + : full_mesh(DeviceMesh::createForNumDevices(communicator->size())), + all_ranks(full_mesh.vector()), + backend(communicator->getBackendForTeam(all_ranks, GetParam())) {} void CommunicationTest::SetUp() { MultiDeviceTest::SetUp(); @@ -64,7 +65,8 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { TEST_P(CommunicationTest, Gather) { params.type = CommunicationType::Gather; params.root = root; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -93,7 +95,8 @@ TEST_P(CommunicationTest, Gather) { TEST_P(CommunicationTest, Allgather) { params.type = CommunicationType::Allgather; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -122,7 +125,8 @@ TEST_P(CommunicationTest, Allgather) { TEST_P(CommunicationTest, Scatter) { params.type = CommunicationType::Scatter; params.root = root; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -157,7 +161,8 @@ TEST_P(CommunicationTest, Scatter) { TEST_P(CommunicationTest, Broadcast) { params.type = CommunicationType::Broadcast; params.root = root; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -203,6 +208,7 @@ TEST_P(CommunicationTest, SendRecv) { params.type = CommunicationType::SendRecv; params.root = sender; params.mesh = {receiver}; + params.team = {sender, receiver}; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; @@ -244,6 +250,7 @@ TEST_P(CommunicationTest, SendRecvToSelf) { params.type = CommunicationType::SendRecv; params.root = sender; params.mesh = {sender}; + params.team = {sender}; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({tensor_size}, tensor_options); @@ -268,7 +275,8 @@ TEST_P(CommunicationTest, Reduce) { params.type = CommunicationType::Reduce; params.redOp = red_op; params.root = root; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -299,7 +307,8 @@ TEST_P(CommunicationTest, Reduce) { TEST_P(CommunicationTest, Allreduce) { params.type = CommunicationType::Allreduce; params.redOp = red_op; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); @@ -328,7 +337,8 @@ TEST_P(CommunicationTest, ReduceScatter) { params.type = CommunicationType::ReduceScatter; params.redOp = red_op; params.root = root; - params.mesh = all_ranks; + params.mesh = full_mesh; + params.team = all_ranks; params.scattered_axis = 1; auto communication = IrBuilder::create(&container, params); From 57d30c20f3f347907d813bc5a27a77a96e9d0ad4 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 31 May 
2024 20:12:02 +0000 Subject: [PATCH 6/7] Fix sameAs. --- csrc/multidevice/communication.cpp | 7 ++++--- csrc/multidevice/communication.h | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index ad5efa5b4e3..2244cfd3bd0 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -167,9 +167,10 @@ bool Communication::sameAs(const Statement* other) const { const auto& p1 = this->params(); const auto& p2 = other->as()->params(); - return ( - p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) && - p1.mesh == p2.mesh && (!isReduction(p1.type) || p1.redOp == p2.redOp)); + return (p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) && + p1.mesh == p2.mesh && + (!isReduction(p1.type) || p1.redOp == p2.redOp)) && + p1.team == p2.team; } namespace { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index a817651eae0..7457679af58 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #ifdef NVFUSER_DISTRIBUTED #include From 969ecbc5c37c9190ba43b806fe885877ddd428a8 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 31 May 2024 22:53:12 +0000 Subject: [PATCH 7/7] Fix host IR tests. --- tests/cpp/test_multidevice_host_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index ca3c098cad9..22132977d50 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -77,7 +77,7 @@ TEST_P(MultiDeviceHostIrTest, SingleFusionSingleComm) { CommParams comm_params{ .type = CommunicationType::Allgather, .root = 0, - .is_root_in_mesh = true, + .mesh = mesh, .team = mesh.vector()}; auto communication = IrBuilder::create( static_cast(hic.get()), comm_params);
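
For reference, a minimal usage sketch (not part of the patches above) of how a communication is constructed once this series lands, mirroring the updated gather test and the lowering code; `container`, `mesh`, and `root` are assumed to be provided by the caller:

    // Hypothetical sketch: build a Gather communication with the reworked
    // CommParams. `team` lists every device involved and must include the
    // root, which may live outside the mesh.
    CommParams params;
    params.type = CommunicationType::Gather;
    params.root = root;
    params.mesh = mesh;
    params.team = mesh.vector();
    if (!mesh.has(root)) {
      params.team.push_back(root);
    }
    auto* communication =
        IrBuilder::create<Communication>(&container, params);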