From 4560e212eb01fc6dcd5659b9f3760539091bce4d Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 30 Apr 2024 09:27:29 -0700 Subject: [PATCH 1/4] make Communications IRs inheriting from Expr --- csrc/dispatch.h | 3 +- csrc/ir/all_nodes.h | 1 + csrc/ir/base_nodes.h | 2 +- csrc/multidevice/communication.cpp | 456 +++++++++--------- csrc/multidevice/communication.h | 168 +++---- csrc/multidevice/communicator.cpp | 18 - csrc/multidevice/communicator.h | 8 - csrc/multidevice/executor.cpp | 10 +- csrc/multidevice/executor.h | 2 +- csrc/multidevice/lower_communication.cpp | 48 +- csrc/multidevice/lower_communication.h | 2 +- tests/cpp/test_multidevice_communications.cpp | 61 ++- 12 files changed, 372 insertions(+), 407 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 8bd431bc15e..c1ab70cf88b 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -106,7 +106,8 @@ class Val; f(Merge); \ f(Swizzle); \ f(Swizzle2D); \ - f(Resize); + f(Resize); \ + f(Communication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ f(Allocate); \ f(Asm); \ diff --git a/csrc/ir/all_nodes.h b/csrc/ir/all_nodes.h index accbb8544f9..b59192a22e2 100644 --- a/csrc/ir/all_nodes.h +++ b/csrc/ir/all_nodes.h @@ -11,3 +11,4 @@ #include #include #include +#include diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 35eb1cd94bb..9ec8c86fb46 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -490,7 +490,7 @@ using newObjectFuncType = Expr*( //! - Constructors need to register with the Fusion after inputs/outputs //! are defined //! - Implementation of bool sameAs(...) -//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val +//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Expr //! 3) Default mutator function should be added to mutator.h/.cpp //! 4) Printing functions should be added to ir/iostream.h/.cpp //! 5) Lower case convenience functions should be added to arith.h/.cpp (If diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index d512cd4dcb3..8eedacd963d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -5,6 +5,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include +#include #include #if defined(USE_C10D_NCCL) #include @@ -40,56 +42,148 @@ inline void assertBuffersHaveSameSize( } } -inline void post_common(Communication& self, Communicator& comm) { - NVF_ERROR( - std::find( - self.params().team.begin(), - self.params().team.end(), - comm.deviceId()) != self.params().team.end(), - "current device index ", - comm.deviceId(), - " must be present in the communication's team"); -} - inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { dst.copy_(src, /* non-blocking */ true); } -} // namespace +inline bool hasRoot(CommunicationType type) { + return type == CommunicationType::Gather || + type == CommunicationType::Scatter || + type == CommunicationType::Broadcast || + type == CommunicationType::SendRecv; +} + +inline bool isReduction(CommunicationType type) { + return type == CommunicationType::Reduce || + type == CommunicationType::Allreduce || + type == CommunicationType::ReduceScatter; +} + +inline std::string typeToString(CommunicationType type) { + switch (type) { + case CommunicationType::Gather: + return "Gather"; + case CommunicationType::Allgather: + return "Allgather"; + case CommunicationType::Scatter: + return "Scatter"; + case CommunicationType::Reduce: + return "Reduce"; + case CommunicationType::Allreduce: + return "Allreduce"; + case CommunicationType::ReduceScatter: + return "ReduceScatter"; + case CommunicationType::Broadcast: + return "Broadcast"; + case CommunicationType::SendRecv: + return "SendRecv"; + default: + NVF_ERROR(false); + return ""; + } +} -Communication::Communication(CommParams params, std::string name, bool has_root) - : params_(std::move(params)), - collective_type_(std::move(name)), - has_root_(has_root) { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); +inline void assertValid( + const CommParams& params, + const DeviceIdxType my_device_index) { + assertBuffersHaveSameSize(params.src_bufs, params.dst_bufs); NVF_ERROR( - std::unique(params_.team.begin(), params_.team.end()) == - params_.team.end(), + std::adjacent_find(params.team.cbegin(), params.team.cend()) == + params.team.cend(), "the communication must not involve the same device more than once"); - NVF_ERROR(!params_.team.empty(), "the team size must be greater than 0"); - if (has_root_) { - auto it = std::find(params_.team.begin(), params_.team.end(), params_.root); + NVF_ERROR(!params.team.empty(), "the team size must be greater than 0"); + NVF_ERROR( + std::find(params.team.begin(), params.team.end(), my_device_index) != + params.team.end(), + "current device index ", + my_device_index, + " must be present in the communication's team"); + if (hasRoot(params.type)) { + auto it = std::find(params.team.begin(), params.team.end(), params.root); NVF_ERROR( - it != params_.team.end(), + it != params.team.end(), "root (device ", - params_.root, + params.root, ") must be present in the communication's team"); - // pytorch's process group expects the root to be specified - // as an integer between 0 and world_size-1. We choose it to be - // the device's relative index within the team - root_relative_index_ = std::distance(params_.team.begin(), it); } + bool is_root = (my_device_index == params.root); + switch (params.type) { + case CommunicationType::Gather: + assertBufferCount(params.src_bufs, 1); + assertBufferCount(params.dst_bufs, is_root ? params.team.size() : 0); + break; + case CommunicationType::Allgather: + assertBufferCount(params.src_bufs, 1); + assertBufferCount(params.dst_bufs, params.team.size()); + break; + case CommunicationType::Scatter: + assertBufferCount(params.dst_bufs, 1); + assertBufferCount(params.src_bufs, is_root ? params.team.size() : 0); + break; + case CommunicationType::Reduce: + assertBufferCount(params.src_bufs, 1); + assertBufferCount(params.dst_bufs, is_root ? 1 : 0); + break; + case CommunicationType::Allreduce: + assertBufferCount(params.dst_bufs, 1); + assertBufferCount(params.src_bufs, 1); + break; + case CommunicationType::ReduceScatter: + assertBufferCount(params.dst_bufs, 1); + assertBufferCount(params.src_bufs, params.team.size()); + break; + case CommunicationType::Broadcast: + if (is_root) { + assertBufferCount(params.src_bufs, 1); + NVF_ERROR( + params.dst_bufs.size() < 2, "there must be at most 2 buffer(s)"); + } else { + assertBufferCount(params.src_bufs, 0); + assertBufferCount(params.dst_bufs, 1); + } + break; + case CommunicationType::SendRecv: + NVF_ERROR( + params.team.size() == 1 || params.team.size() == 2, + "the team size should be 1 or 2"); + if (is_root) { + assertBufferCount(params.src_bufs, 1); + assertBufferCount(params.dst_bufs, (params.team.size() == 1) ? 1 : 0); + } else { + assertBufferCount(params.src_bufs, 0); + assertBufferCount(params.dst_bufs, 1); + } + break; + } +} + +// pytorch's process group expects the root to be specified +// as an integer between 0 and world_size-1. We choose it to be +// the device's relative index within the team +DeviceIdxType getRootRelativeIndex(const CommParams& params) { + auto it = std::find(params.team.begin(), params.team.end(), params.root); + return std::distance(params.team.begin(), it); } +} // namespace + +Communication::Communication(IrBuilderPasskey passkey, CommParams params) + : Expr(passkey), params_(std::move(params)) {} + +Communication::Communication(const Communication* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), params_(src->params()) {} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Communication) + std::string Communication::toString(int indent) const { std::stringstream ss; std::string ext_indent(" ", indent); std::string indent1 = ext_indent + " "; std::string indent2 = ext_indent + " "; - ss << ext_indent << "Communication " << collective_type_ << ": {\n"; + ss << ext_indent << "Communication " << typeToString(params_.type) << ": {\n"; - if (has_root_) { + if (hasRoot(params_.type)) { ss << indent1 << "root: " << params_.root << ",\n"; } ss << indent1 << "team: {"; @@ -107,214 +201,132 @@ std::string Communication::toString(int indent) const { return ss.str(); } -Broadcast::Broadcast(CommParams params) : Communication(params, "broadcast") {} - -c10::intrusive_ptr Broadcast::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); - - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.src_bufs, 1); - if (params_.dst_bufs.size() == 1) { - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); - } else { - assertBufferCount(params_.dst_bufs, 0); - } - } else { - assertBufferCount(params_.src_bufs, 0); - assertBufferCount(params_.dst_bufs, 1); - } - - if (params_.team.size() == 1) { - return nullptr; - } - - return comm.getBackendForTeam(params_.team, backend) - ->broadcast( - comm.deviceId() == params_.root ? params_.src_bufs : params_.dst_bufs, - {.rootRank = root_relative_index_}); -} - -Gather::Gather(CommParams params) : Communication(params, "gather") { - assertBufferCount(params_.src_bufs, 1); +std::string Communication::toInlineString(int indent_size) const { + return toString(indent_size); } -c10::intrusive_ptr Gather::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list; - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.dst_bufs, params_.team.size()); - buf_list = {std::move(params_.dst_bufs)}; - } else { - assertBufferCount(params_.dst_bufs, 0); +// TODO add checking symbolic representation of src and dst buffers +bool Communication::sameAs(const Statement* other) const { + if (other == this) { + return true; } - auto work = - comm.getBackendForTeam(params_.team, backend) - ->gather( - buf_list, params_.src_bufs, {.rootRank = root_relative_index_}); - if (comm.deviceId() == params_.root) { - params_.dst_bufs = std::move(buf_list.back()); + if (!other->isA()) { + return false; } - return work; -} + const auto& p1 = this->params(); + const auto& p2 = other->as()->params(); -Allgather::Allgather(CommParams params) - : Communication(params, "allgather", false) { - assertBufferCount(params_.src_bufs, 1); - assertBufferCount(params_.dst_bufs, params_.team.size()); + return ( + p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) && + p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp)); } -c10::intrusive_ptr Allgather::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); +c10::intrusive_ptr postCommunication( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend) { + auto params = communication->params(); + assertValid(params, my_device_index); + bool is_root = (my_device_index == params.root); // This is used to change the representation of the buffers to match c10d // ProcessGroup API std::vector> buf_list; - buf_list = {std::move(params_.dst_bufs)}; - auto work = comm.getBackendForTeam(params_.team, backend) - ->allgather(buf_list, params_.src_bufs, {}); - params_.dst_bufs = std::move(buf_list.back()); - return work; -} - -Scatter::Scatter(CommParams params) : Communication(params, "scatter") { - assertBufferCount(params_.dst_bufs, 1); -} - -c10::intrusive_ptr Scatter::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list; - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.src_bufs, params_.team.size()); - buf_list = {std::move(params_.src_bufs)}; - } else { - assertBufferCount(params_.src_bufs, 0); - } - auto work = - comm.getBackendForTeam(params_.team, backend) - ->scatter( - params_.dst_bufs, buf_list, {.rootRank = root_relative_index_}); - if (comm.deviceId() == params_.root) { - params_.src_bufs = std::move(buf_list.back()); - } - return work; -} - -Reduce::Reduce(CommParams params) : Communication(params, "reduce") { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); - assertBufferCount(params_.src_bufs, 1); -} - -c10::intrusive_ptr Reduce::post( - Communicator& comm, - std::optional backend) { - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.dst_bufs, 1); - } else { - assertBufferCount(params_.dst_bufs, 0); - } - post_common(*this, comm); - auto& buf = - (comm.deviceId() == params_.root) ? params_.dst_bufs : params_.src_bufs; - c10d::ReduceOptions options = { - .reduceOp = params_.redOp, .rootRank = root_relative_index_}; - auto team_backend = comm.getBackendForTeam(params_.team, backend); -#if defined(USE_C10D_NCCL) - auto nccl_backend = dynamic_cast(team_backend.get()); - if (nccl_backend) { + switch (params.type) { + case CommunicationType::Gather: { + if (is_root) { + buf_list = {params.dst_bufs}; + } + auto work = backend->gather( + buf_list, + params.src_bufs, + {.rootRank = getRootRelativeIndex(params)}); + return work; + } + case CommunicationType::Allgather: { + // This is used to change the representation of the buffers to match c10d + // ProcessGroup API + buf_list = {params.dst_bufs}; + auto work = backend->allgather(buf_list, params.src_bufs, {}); + return work; + } + case CommunicationType::Scatter: { + // This is used to change the representation of the buffers to match c10d + // ProcessGroup API + if (is_root) { + buf_list = {params.src_bufs}; + } + auto work = backend->scatter( + params.dst_bufs, + buf_list, + {.rootRank = getRootRelativeIndex(params)}); + return work; + } + case CommunicationType::Reduce: { + auto& buf = (is_root) ? params.dst_bufs : params.src_bufs; + c10d::ReduceOptions options = { + .reduceOp = params.redOp, .rootRank = getRootRelativeIndex(params)}; +#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL) + auto nccl_backend = dynamic_cast(backend.get()); + if (nccl_backend) { #if NVF_TORCH_VERSION_NO_LESS(2, 3, 0) - // API change https://github.com/pytorch/pytorch/pull/119421 - return nccl_backend->_reduce_oop( - buf.at(0), params_.src_bufs.at(0), options); + // API change https://github.com/pytorch/pytorch/pull/119421 + return nccl_backend->_reduce_oop( + buf.at(0), params.src_bufs.at(0), options); #else - return nccl_backend->_reduce_oop(buf, params_.src_bufs, options); + return nccl_backend->_reduce_oop(buf, params.src_bufs, options); #endif - } + } #endif - if (comm.deviceId() == params_.root) { - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); - } - return team_backend->reduce(buf, options); -} - -Allreduce::Allreduce(CommParams params) - : Communication(params, "allreduce", false) { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); - assertBufferCount(params_.src_bufs, 1); - assertBufferCount(params_.dst_bufs, 1); -} - -c10::intrusive_ptr Allreduce::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); - return comm.getBackendForTeam(params_.team, backend) - ->allreduce(params_.dst_bufs, {.reduceOp = params_.redOp}); -} - -ReduceScatter::ReduceScatter(CommParams params) - : Communication(params, "reduce_scatter", false) { - assertBufferCount(params_.src_bufs, params_.team.size()); - assertBufferCount(params_.dst_bufs, 1); -} - -c10::intrusive_ptr ReduceScatter::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list = {std::move(params_.src_bufs)}; - auto work = comm.getBackendForTeam(params_.team, backend) - ->reduce_scatter( - params_.dst_bufs, buf_list, {.reduceOp = params_.redOp}); - params_.src_bufs = std::move(buf_list.back()); - return work; -} - -SendRecv::SendRecv(CommParams params) : Communication(params, "send/recv") { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); - NVF_ERROR( - params_.team.size() == 1 || params_.team.size() == 2, - "the team size should be 1 or 2"); -} - -c10::intrusive_ptr SendRecv::post( - Communicator& comm, - std::optional backend) { - post_common(*this, comm); - - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.src_bufs, 1); - if (params_.team.size() == 1) { - assertBufferCount(params_.dst_bufs, 1); - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); - return nullptr; - } else { - assertBufferCount(params_.dst_bufs, 0); + if (is_root) { + doLocalCopy(params.dst_bufs.at(0), params.src_bufs.at(0)); + } + return backend->reduce(buf, options); + } + case CommunicationType::Allreduce: { + doLocalCopy(params.dst_bufs.at(0), params.src_bufs.at(0)); + return backend->allreduce(params.dst_bufs, {.reduceOp = params.redOp}); + } + case CommunicationType::ReduceScatter: { + // This is used to change the representation of the buffers to match c10d + // ProcessGroup API + buf_list = {params.src_bufs}; + auto work = backend->reduce_scatter( + params.dst_bufs, buf_list, {.reduceOp = params.redOp}); + return work; } - } else { - assertBufferCount(params_.src_bufs, 0); - assertBufferCount(params_.dst_bufs, 1); + case CommunicationType::Broadcast: { + if (is_root && params.dst_bufs.size() == 1) { + doLocalCopy(params.dst_bufs.at(0), params.src_bufs.at(0)); + } + if (params.team.size() == 1) { + return nullptr; + } + return backend->broadcast( + is_root ? params.src_bufs : params.dst_bufs, + {.rootRank = getRootRelativeIndex(params)}); + } + case CommunicationType::SendRecv: { + if (is_root && params.team.size() == 1) { + doLocalCopy(params.dst_bufs.at(0), params.src_bufs.at(0)); + return nullptr; + } + const DeviceIdxType sender = params.root; + const DeviceIdxType receiver = + (params.team.at(0) == sender) ? params.team.at(1) : params.team.at(0); + std::vector& tensor = + params.dst_bufs.empty() ? params.src_bufs : params.dst_bufs; + if (my_device_index == sender) { + return backend->send(tensor, static_cast(receiver), /*tag*/ 0); + } else if (my_device_index == receiver) { + return backend->recv(tensor, static_cast(sender), /*tag*/ 0); + } else { + return nullptr; + } + } + default: + NVF_ERROR(false, "Wrong communication type: ", typeToString(params.type)); + return nullptr; } - - return comm.sendRecv( - (params_.team.at(0) == params_.root) ? params_.team.at(1) - : params_.team.at(0), - params_.root, - params_.dst_bufs.empty() ? params_.src_bufs : params_.dst_bufs, - backend); } } // namespace nvfuser diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index f8b6f558810..8272609bc77 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -7,6 +7,8 @@ // clang-format on #pragma once +#include +#include #include #include #include @@ -15,11 +17,23 @@ namespace nvfuser { +enum class CommunicationType { + Gather, + Allgather, + Scatter, + Reduce, + Allreduce, + ReduceScatter, + Broadcast, + SendRecv +}; + /* This struct gathers all the parameters necessary for the construction a communication */ struct CommParams { + CommunicationType type; DeviceIdxType root = -1; std::vector src_bufs; std::vector dst_bufs; @@ -50,160 +64,64 @@ otherwise an error is thrown. NOTE: pytorch's NCCL process group API needs buffers on root for scatter/gather operation. -*/ - -class Communication { - public: - virtual ~Communication() = default; - - std::string toString(int indent = 0) const; - - const auto& params() const { - return params_; - } - - // Triggers the execution of the communication. This is a non-blocking call. - // The communication can be posted multiple times - virtual c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) = 0; - - protected: - // argument "name" is only used for printing - // argument "has_root" indicates if the communication is rooted - Communication(CommParams params, std::string name, bool has_root = true); - - // store the arguments of the communication - CommParams params_; - // stores the relative index of the root in the team - DeviceIdxType root_relative_index_ = -1; - - private: - // used for printing - std::string collective_type_; - // indicates if the communication is rooted - bool has_root_ = true; -}; -/* +(*) Broadcast Copies the root's src buffer to each device's dst buffer - Requirements: - the root is set and belongs to the team - the root has one src buffer, and no or one dst buffer - non-roots have no src buffer and one dst buffer - all buffers have the same size -*/ -class Broadcast : public Communication { - public: - Broadcast(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* + +(*) Gather Copies each device's source buffer to the root's respective src buffer. The order of the sender devices matches the order of the root's buffers. - Requirements: - the root is set and belongs to the team - the root has one src buffer and dst buffers - non-roots have one src buffer and no dst buffer - all buffers have the same size -*/ -class Gather : public Communication { - public: - Gather(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* +(*) Allgather Copies each device's src buffer to each device's respective src buffer. The order of the devices matches the order of the buffers - Requirements: - all device have one src buffer and dst buffers - all buffers have the same size -*/ -class Allgather : public Communication { - public: - Allgather(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* +(*) Scatter Copies each root's src buffer to each device's dst buffer. The order of the buffers matches the order of the receiver devices - Requirements: - the root is set and belongs to the team - the root has src buffers and one dst buffer - non-roots have no src buffer and one dst buffer - all buffers have the same size -*/ -class Scatter : public Communication { - public: - Scatter(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* +(*) Reduce Reduce the src buffers to the root's dst buffer. - Requirements: - the root is set and belongs to the team - the root has one src buffers and one dst buffer - non-roots have one src buffer and no dst buffer - all buffers have the same size -*/ -class Reduce : public Communication { - public: - Reduce(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* +(*) Allreduce Reduce the src buffers to the dst buffer. - Requirements: - all devices have one src buffer and one dst buffer - all buffers have the same size -*/ -class Allreduce : public Communication { - public: - Allreduce(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* +(*) ReduceScatter Reduce all the src buffers and shard the result to the dst buffers. - Requirements: - all devices have src buffer and one dst buffer - all buffers have the same size -*/ -class ReduceScatter : public Communication { - public: - ReduceScatter(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; -}; -/* +(*) SendRecv Copies the sender's src buffers to the receiver's dst buffer It is equivalent to a Broadcast with a team of size == 2 @@ -218,12 +136,42 @@ case of a local copy) - If team is of size 2, the unique non-root have no src buffer and one dst buffer */ -class SendRecv : public Communication { + +class NVF_API Communication : public Expr { public: - SendRecv(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - std::optional backend = std::nullopt) override; + using Expr::Expr; + Communication(IrBuilderPasskey passkey, CommParams params); + Communication(const Communication* src, IrCloner* ir_cloner); + + Communication(const Communication& other) = delete; + Communication& operator=(const Communication& other) = delete; + Communication(Communication&& other) = delete; + Communication& operator=(Communication&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "Communication"; + } + + bool sameAs(const Statement* other) const override; + + auto params() const { + return params_; + } + + private: + // store the arguments of the communication + CommParams params_; }; +// Triggers the execution of the communication. This is a non-blocking call. +// The communication can be posted multiple times +c10::intrusive_ptr postCommunication( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend); + } // namespace nvfuser diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index e0162c8bc8e..9a972c8a372 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -234,24 +234,6 @@ c10::intrusive_ptr Communicator::getBackendForTeam( return backends_.at(team_key); } -c10::intrusive_ptr Communicator::sendRecv( - DeviceIdxType receiver, - DeviceIdxType sender, - std::vector& tensors, - std::optional backend, - int tag) { - NVF_ERROR( - deviceId() == sender || deviceId() == receiver, - "only sender or receiver should post the sendRecv"); - NVF_ERROR(sender != receiver, "cannot send to self"); - - auto world = getWorld(backend); - if (deviceId() == sender) { - return world->send(tensors, static_cast(dIdToRank(receiver)), tag); - } - return world->recv(tensors, static_cast(dIdToRank(sender)), tag); -} - c10::intrusive_ptr Communicator::getWorld( std::optional backend) { std::vector all_ranks(size_); diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 53efe005328..cfe1c1bcb14 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -73,14 +73,6 @@ class Communicator { default_backend_ = backend; } - // performs a send/receive p2p data transfer - c10::intrusive_ptr sendRecv( - DeviceIdxType receiver, - DeviceIdxType sender, - std::vector& tensor, - std::optional backend = std::nullopt, - int tag = 0); - // performs a blocking barrier in the communicator void barrier(std::optional backend = std::nullopt) { getWorld(backend)->barrier()->wait(); diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 448e232098a..8f7f65c3065 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -149,7 +149,7 @@ void MultiDeviceExecutor::postKernel( } } -void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { +void MultiDeviceExecutor::postResharding(SegmentedGroup* group) { // Lower the group into a vector of Communications NVF_ERROR( group->exprs().size() == 1, @@ -174,8 +174,10 @@ void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { lowerCommunication(comm_.deviceId(), expr, input_tensor, output_tensor); // post and wait communications - for (auto& communication : communications) { - auto work = communication->post(comm_); + for (auto communication : communications) { + auto backend = + comm_.getBackendForTeam(communication->params().team, std::nullopt); + auto work = postCommunication(communication, comm_.deviceId(), backend); if (work) { work->wait(); } @@ -213,7 +215,7 @@ std::vector MultiDeviceExecutor::runWithInput( if (!is_resharding_.at(group)) { postKernel(group, launch_params); } else { - postCommunication(group); + postResharding(group); } } diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index 9c277464dd2..d43c8842cbd 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -116,7 +116,7 @@ class MultiDeviceExecutor { // params_.use_fusion_executor_cache = false void postKernel(SegmentedGroup* group, const LaunchParams& launch_params); // execute a SegmentedGroup representing inter-device communication - void postCommunication(SegmentedGroup* group); + void postResharding(SegmentedGroup* group); // Stores concrete computed values, std::unordered_map val_to_IValue_; diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 3d814378dfc..f4b1a8c6e07 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include #include @@ -106,6 +107,8 @@ CommParams createParamsForGatherScatter( bool is_scatter) { const DeviceMesh& mesh = root_tv->getDeviceMesh(); CommParams params; + params.type = + is_scatter ? CommunicationType::Scatter : CommunicationType::Gather; params.root = root; params.team = mesh.vector(); bool is_root_in_mesh = mesh.has(root); @@ -142,7 +145,7 @@ void lowerToScatter( TensorView* output_tv, at::Tensor input_tensor, at::Tensor output_tensor, - std::vector>& comms) { + std::vector& comms) { // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); auto root = input_tv->getDeviceMesh().vector().at(0); @@ -151,7 +154,7 @@ void lowerToScatter( } auto params = createParamsForGatherScatter( my_device_index, root, output_tv, input_tensor, output_tensor, true); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } /* @@ -166,7 +169,7 @@ void lowerToGather( TensorView* output_tv, at::Tensor input_tensor, at::Tensor output_tensor, - std::vector>& comms) { + std::vector& comms) { // we create as many 'Gathers' as there are devices in the receiver mesh const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); for (auto root : output_tv->getDeviceMesh().vector()) { @@ -175,7 +178,7 @@ void lowerToGather( } auto params = createParamsForGatherScatter( my_device_index, root, input_tv, output_tensor, input_tensor, false); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } } @@ -186,13 +189,14 @@ void lowerToAllgather( TensorView* output_tv, at::Tensor input_tensor, at::Tensor output_tensor, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } CommParams params; + params.type = CommunicationType::Allgather; params.team = mesh.vector(); for (auto i : c10::irange(mesh.vector().size())) { params.dst_bufs.push_back( @@ -200,7 +204,7 @@ void lowerToAllgather( } params.src_bufs = {input_tensor}; - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } // Creates and set the CommParams for a Broadcast or Send/Recv communication @@ -216,6 +220,7 @@ CommParams createParamsForBroadcastOrP2P( if (!mesh.has(root)) { params.team.push_back(root); } + params.type = CommunicationType::Broadcast; if (my_device_index == root) { params.src_bufs = {input_tensor}; @@ -234,19 +239,13 @@ void lowerToBroadcastOrP2P( const DeviceMesh& mesh, // receiver devices at::Tensor input_tensor, at::Tensor output_tensor, - std::vector>& comms) { + std::vector& comms) { if (!isDeviceInvolved(my_device_index, root, mesh)) { return; } auto params = createParamsForBroadcastOrP2P( my_device_index, root, mesh, input_tensor, output_tensor); - std::shared_ptr comm; - if (mesh.vector().size() == 1) { - comm = std::make_shared(std::move(params)); - } else { - comm = std::make_shared(std::move(params)); - } - comms.push_back(comm); + comms.push_back(IrBuilder::create(std::move(params))); } // Adds several Broadcast or Send/Recv communications to the vector 'comms' @@ -260,7 +259,7 @@ void lowerToBroadcastOrP2P( at::Tensor input_tensor, at::Tensor output_tensor, bool is_sharded, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); if (is_sharded) { @@ -300,6 +299,7 @@ CommParams createParamsForReduce( BinaryOpType op_type) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); CommParams params; + params.type = CommunicationType::Reduce; params.root = root; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); @@ -333,7 +333,7 @@ void lowerToReduce( at::Tensor input_tensor, at::Tensor output_tensor, BinaryOpType op_type, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); // we create as many Reduces as there are devices in the receiver mesh @@ -349,7 +349,7 @@ void lowerToReduce( input_tensor, output_tensor, op_type); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } } @@ -360,17 +360,18 @@ void lowerToAllreduce( at::Tensor input_tensor, at::Tensor output_tensor, BinaryOpType op_type, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } CommParams params; + params.type = CommunicationType::Allreduce; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); params.dst_bufs = {output_tensor}; params.src_bufs = {input_tensor.view(output_tensor.sizes())}; - comms.push_back(std::make_shared(params)); + comms.push_back(IrBuilder::create(params)); } void lowerToReduceScatter( @@ -380,12 +381,13 @@ void lowerToReduceScatter( at::Tensor input_tensor, at::Tensor output_tensor, BinaryOpType op_type, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } CommParams params; + params.type = CommunicationType::ReduceScatter; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); params.dst_bufs = {output_tensor}; @@ -404,7 +406,7 @@ void lowerToReduceScatter( params.src_bufs.push_back(slice); } - comms.push_back(std::make_shared(params)); + comms.push_back(IrBuilder::create(params)); } } // namespace @@ -418,12 +420,12 @@ void lowerToReduceScatter( sources *) Leverage the topology to ensure that the senders and recerivers are close */ -std::vector> lowerCommunication( +std::vector lowerCommunication( DeviceIdxType my_device_index, Expr* c, at::Tensor input_tensor, at::Tensor output_tensor) { - std::vector> comms; + std::vector comms; NVF_ERROR( c->inputs().size() == 1 && c->inputs().at(0)->isA() && c->outputs().size() == 1 && c->outputs().at(0)->isA(), diff --git a/csrc/multidevice/lower_communication.h b/csrc/multidevice/lower_communication.h index 3942ae09dd8..f59fccc266b 100644 --- a/csrc/multidevice/lower_communication.h +++ b/csrc/multidevice/lower_communication.h @@ -19,7 +19,7 @@ bool isLowerableToCommunication(Expr* expr); // Lower a PipelineCommunication into a series of Communication, given a // device_index. -std::vector> lowerCommunication( +std::vector lowerCommunication( DeviceIdxType device_index, Expr* c, at::Tensor input_tensor, diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 0e741c8b731..2f5b5ef1310 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -7,6 +7,8 @@ // clang-format on #include +#include +#include #include #include #include @@ -32,6 +34,8 @@ class CommunicationTest c10d::ReduceOp::RedOpType::SUM; CommParams params; std::vector all_ranks; + c10::intrusive_ptr backend; + IrContainer container; }; CommunicationTest::CommunicationTest() { @@ -45,6 +49,7 @@ void CommunicationTest::SetUp() { if (!communicator->isBackendAvailable(GetParam())) { GTEST_SKIP() << "Backend not available"; } + backend = communicator->getBackendForTeam(all_ranks, GetParam()); } void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { @@ -61,6 +66,7 @@ void CommunicationTest::resetDstBuffers() { } TEST_P(CommunicationTest, Gather) { + params.type = CommunicationType::Gather; params.root = root; params.team = all_ranks; params.src_bufs = {at::empty(tensor_size, tensor_options)}; @@ -69,7 +75,7 @@ TEST_P(CommunicationTest, Gather) { params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); } } - auto communication = Gather(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -77,7 +83,8 @@ TEST_P(CommunicationTest, Gather) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); if (communicator->deviceId() == root) { @@ -91,13 +98,14 @@ TEST_P(CommunicationTest, Gather) { } TEST_P(CommunicationTest, Allgather) { + params.type = CommunicationType::Allgather; params.team = all_ranks; params.src_bufs = { at::empty(tensor_size, tensor_options) * communicator->deviceId()}; for (int64_t i = 0; i < communicator->size(); i++) { params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); } - auto communication = Allgather(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -105,7 +113,8 @@ TEST_P(CommunicationTest, Allgather) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); for (int i : c10::irange(communicator->size())) { @@ -117,6 +126,7 @@ TEST_P(CommunicationTest, Allgather) { } TEST_P(CommunicationTest, Scatter) { + params.type = CommunicationType::Scatter; params.root = root; params.team = all_ranks; if (communicator->deviceId() == root) { @@ -126,7 +136,7 @@ TEST_P(CommunicationTest, Scatter) { } } params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - auto communication = Scatter(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -135,7 +145,8 @@ TEST_P(CommunicationTest, Scatter) { at::arange(tensor_size, tensor_options) + (i + 1) * j); } - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); auto obtained = params.dst_bufs.at(0); @@ -146,6 +157,7 @@ TEST_P(CommunicationTest, Scatter) { } TEST_P(CommunicationTest, Broadcast) { + params.type = CommunicationType::Broadcast; params.root = root; params.team = all_ranks; if (communicator->deviceId() == root) { @@ -153,7 +165,7 @@ TEST_P(CommunicationTest, Broadcast) { } params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - auto communication = Broadcast(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -161,7 +173,8 @@ TEST_P(CommunicationTest, Broadcast) { params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); } - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); if (communicator->size() > 1) { work->wait(); } @@ -173,6 +186,9 @@ TEST_P(CommunicationTest, Broadcast) { } TEST_P(CommunicationTest, SendRecv) { + if (GetParam() == CommunicatorBackend::ucc) { + GTEST_SKIP() << "Disabling because of UCC hangs, see issue #2091"; + } if (communicator->size() < 2 || torch::cuda::device_count() < 2) { GTEST_SKIP() << "This test needs at least 2 GPUs and 2 ranks."; } @@ -183,6 +199,7 @@ TEST_P(CommunicationTest, SendRecv) { return; } + params.type = CommunicationType::SendRecv; params.root = sender; params.team = {0, 1}; if (communicator->deviceId() == sender) { @@ -190,7 +207,7 @@ TEST_P(CommunicationTest, SendRecv) { } else { params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); } - auto communication = SendRecv(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -198,7 +215,8 @@ TEST_P(CommunicationTest, SendRecv) { params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); } - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); if (communicator->deviceId() == receiver) { @@ -215,17 +233,18 @@ TEST_P(CommunicationTest, SendRecvToSelf) { return; } + params.type = CommunicationType::SendRecv; params.root = sender; params.team = {0}; params.src_bufs.push_back(at::empty(tensor_size, tensor_options)); params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); - auto communication = SendRecv(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); - communication.post(*communicator, GetParam()); + postCommunication(communication, communicator->deviceId(), backend); auto obtained = params.dst_bufs.at(0); auto ref = at::arange(tensor_size, tensor_options) + j; @@ -234,6 +253,7 @@ TEST_P(CommunicationTest, SendRecvToSelf) { } TEST_P(CommunicationTest, Reduce) { + params.type = CommunicationType::Reduce; params.redOp = red_op; params.root = root; params.team = all_ranks; @@ -241,7 +261,7 @@ TEST_P(CommunicationTest, Reduce) { if (communicator->deviceId() == root) { params.dst_bufs = {at::empty(tensor_size, tensor_options)}; } - auto communication = Reduce(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -249,7 +269,8 @@ TEST_P(CommunicationTest, Reduce) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); if (communicator->deviceId() == root) { @@ -263,11 +284,12 @@ TEST_P(CommunicationTest, Reduce) { } TEST_P(CommunicationTest, Allreduce) { + params.type = CommunicationType::Allreduce; params.redOp = red_op; params.team = all_ranks; params.src_bufs = {at::empty(tensor_size, tensor_options)}; params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - auto communication = Allreduce(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -275,7 +297,8 @@ TEST_P(CommunicationTest, Allreduce) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); auto obtained = params.dst_bufs.at(0); @@ -287,6 +310,7 @@ TEST_P(CommunicationTest, Allreduce) { } TEST_P(CommunicationTest, ReduceScatter) { + params.type = CommunicationType::ReduceScatter; params.redOp = red_op; params.root = root; params.team = all_ranks; @@ -294,7 +318,7 @@ TEST_P(CommunicationTest, ReduceScatter) { params.src_bufs.push_back(at::empty(tensor_size, tensor_options)); } params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - auto communication = ReduceScatter(params); + auto communication = IrBuilder::create(&container, params); for (int j : c10::irange(number_of_repetitions)) { resetDstBuffers(); @@ -304,7 +328,8 @@ TEST_P(CommunicationTest, ReduceScatter) { (communicator->deviceId() + 1) * (i + j)); } - auto work = communication.post(*communicator, GetParam()); + auto work = + postCommunication(communication, communicator->deviceId(), backend); work->wait(); auto obtained = params.dst_bufs.at(0); From fddc1ab96d3b2147b351d613c782c950e4700e9b Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 6 May 2024 08:35:55 -0700 Subject: [PATCH 2/4] minor comments --- csrc/multidevice/communication.cpp | 66 +++++++++++-------- csrc/multidevice/communication.h | 6 +- csrc/multidevice/executor.cpp | 7 +- csrc/multidevice/executor.h | 2 +- tests/cpp/test_multidevice_communications.cpp | 34 +++++----- 5 files changed, 63 insertions(+), 52 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 8eedacd963d..a1d55afeefd 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -14,6 +14,37 @@ #include namespace nvfuser { + +std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { + switch (type) { + case CommunicationType::Gather: + os << "Gather"; + break; + case CommunicationType::Allgather: + os << "Allgather"; + break; + case CommunicationType::Scatter: + os << "Scatter"; + break; + case CommunicationType::Reduce: + os << "Reduce"; + break; + case CommunicationType::Allreduce: + os << "Allreduce"; + break; + case CommunicationType::ReduceScatter: + os << "ReduceScatter"; + break; + case CommunicationType::Broadcast: + os << "Broadcast"; + break; + case CommunicationType::SendRecv: + os << "SendRecv"; + break; + } + return os; +} + namespace { inline void assertBufferCount( @@ -59,37 +90,14 @@ inline bool isReduction(CommunicationType type) { type == CommunicationType::ReduceScatter; } -inline std::string typeToString(CommunicationType type) { - switch (type) { - case CommunicationType::Gather: - return "Gather"; - case CommunicationType::Allgather: - return "Allgather"; - case CommunicationType::Scatter: - return "Scatter"; - case CommunicationType::Reduce: - return "Reduce"; - case CommunicationType::Allreduce: - return "Allreduce"; - case CommunicationType::ReduceScatter: - return "ReduceScatter"; - case CommunicationType::Broadcast: - return "Broadcast"; - case CommunicationType::SendRecv: - return "SendRecv"; - default: - NVF_ERROR(false); - return ""; - } -} - inline void assertValid( const CommParams& params, const DeviceIdxType my_device_index) { assertBuffersHaveSameSize(params.src_bufs, params.dst_bufs); + std::unordered_set team_without_duplicates( + params.team.begin(), params.team.end()); NVF_ERROR( - std::adjacent_find(params.team.cbegin(), params.team.cend()) == - params.team.cend(), + team_without_duplicates.size() == params.team.size(), "the communication must not involve the same device more than once"); NVF_ERROR(!params.team.empty(), "the team size must be greater than 0"); NVF_ERROR( @@ -181,7 +189,7 @@ std::string Communication::toString(int indent) const { std::string indent1 = ext_indent + " "; std::string indent2 = ext_indent + " "; - ss << ext_indent << "Communication " << typeToString(params_.type) << ": {\n"; + ss << ext_indent << "Communication " << params_.type << ": {\n"; if (hasRoot(params_.type)) { ss << indent1 << "root: " << params_.root << ",\n"; @@ -221,7 +229,7 @@ bool Communication::sameAs(const Statement* other) const { p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp)); } -c10::intrusive_ptr postCommunication( +c10::intrusive_ptr postSingleCommunication( Communication* communication, DeviceIdxType my_device_index, c10::intrusive_ptr backend) { @@ -324,7 +332,7 @@ c10::intrusive_ptr postCommunication( } } default: - NVF_ERROR(false, "Wrong communication type: ", typeToString(params.type)); + NVF_ERROR(false, "Wrong communication type: ", params.type); return nullptr; } } diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 8272609bc77..91965fcdc81 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -28,6 +28,8 @@ enum class CommunicationType { SendRecv }; +std::ostream& operator<<(std::ostream& os, const CommunicationType& type); + /* This struct gathers all the parameters necessary for the construction a communication @@ -137,7 +139,7 @@ case of a local copy) buffer */ -class NVF_API Communication : public Expr { +class Communication : public Expr { public: using Expr::Expr; Communication(IrBuilderPasskey passkey, CommParams params); @@ -169,7 +171,7 @@ class NVF_API Communication : public Expr { // Triggers the execution of the communication. This is a non-blocking call. // The communication can be posted multiple times -c10::intrusive_ptr postCommunication( +c10::intrusive_ptr postSingleCommunication( Communication* communication, DeviceIdxType my_device_index, c10::intrusive_ptr backend); diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 8f7f65c3065..f6f24cb8c24 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -149,7 +149,7 @@ void MultiDeviceExecutor::postKernel( } } -void MultiDeviceExecutor::postResharding(SegmentedGroup* group) { +void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { // Lower the group into a vector of Communications NVF_ERROR( group->exprs().size() == 1, @@ -177,7 +177,8 @@ void MultiDeviceExecutor::postResharding(SegmentedGroup* group) { for (auto communication : communications) { auto backend = comm_.getBackendForTeam(communication->params().team, std::nullopt); - auto work = postCommunication(communication, comm_.deviceId(), backend); + auto work = + postSingleCommunication(communication, comm_.deviceId(), backend); if (work) { work->wait(); } @@ -215,7 +216,7 @@ std::vector MultiDeviceExecutor::runWithInput( if (!is_resharding_.at(group)) { postKernel(group, launch_params); } else { - postResharding(group); + postCommunication(group); } } diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index d43c8842cbd..9c277464dd2 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -116,7 +116,7 @@ class MultiDeviceExecutor { // params_.use_fusion_executor_cache = false void postKernel(SegmentedGroup* group, const LaunchParams& launch_params); // execute a SegmentedGroup representing inter-device communication - void postResharding(SegmentedGroup* group); + void postCommunication(SegmentedGroup* group); // Stores concrete computed values, std::unordered_map val_to_IValue_; diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 2f5b5ef1310..a468efd21bb 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -83,8 +83,8 @@ TEST_P(CommunicationTest, Gather) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); if (communicator->deviceId() == root) { @@ -113,8 +113,8 @@ TEST_P(CommunicationTest, Allgather) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); for (int i : c10::irange(communicator->size())) { @@ -145,8 +145,8 @@ TEST_P(CommunicationTest, Scatter) { at::arange(tensor_size, tensor_options) + (i + 1) * j); } - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); auto obtained = params.dst_bufs.at(0); @@ -173,8 +173,8 @@ TEST_P(CommunicationTest, Broadcast) { params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); } - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); if (communicator->size() > 1) { work->wait(); } @@ -215,8 +215,8 @@ TEST_P(CommunicationTest, SendRecv) { params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); } - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); if (communicator->deviceId() == receiver) { @@ -244,7 +244,7 @@ TEST_P(CommunicationTest, SendRecvToSelf) { resetDstBuffers(); params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); - postCommunication(communication, communicator->deviceId(), backend); + postSingleCommunication(communication, communicator->deviceId(), backend); auto obtained = params.dst_bufs.at(0); auto ref = at::arange(tensor_size, tensor_options) + j; @@ -269,8 +269,8 @@ TEST_P(CommunicationTest, Reduce) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); if (communicator->deviceId() == root) { @@ -297,8 +297,8 @@ TEST_P(CommunicationTest, Allreduce) { at::arange(tensor_size, tensor_options) + (communicator->deviceId() + 1) * j); - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); auto obtained = params.dst_bufs.at(0); @@ -328,8 +328,8 @@ TEST_P(CommunicationTest, ReduceScatter) { (communicator->deviceId() + 1) * (i + j)); } - auto work = - postCommunication(communication, communicator->deviceId(), backend); + auto work = postSingleCommunication( + communication, communicator->deviceId(), backend); work->wait(); auto obtained = params.dst_bufs.at(0); From 3af73760898824a55312502d2275adb1b78eae67 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 May 2024 07:05:07 -0700 Subject: [PATCH 3/4] change return type of getRootRelativeIndex from DeviceIdxType to int64_t --- csrc/multidevice/communication.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index a1d55afeefd..ab8bb3bfabe 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -168,7 +168,7 @@ inline void assertValid( // pytorch's process group expects the root to be specified // as an integer between 0 and world_size-1. We choose it to be // the device's relative index within the team -DeviceIdxType getRootRelativeIndex(const CommParams& params) { +int64_t getRootRelativeIndex(const CommParams& params) { auto it = std::find(params.team.begin(), params.team.end(), params.root); return std::distance(params.team.begin(), it); } From 2c7ca1eec436f68ae5687b89d424120167356119 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 10 May 2024 23:08:46 +0000 Subject: [PATCH 4/4] Minor fixes. --- csrc/multidevice/communication.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 4275b8bea8b..b7ffc3d0bfd 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -155,7 +155,7 @@ std::string Communication::toString(const int indent_size) const { << std::endl; if (hasRoot(params_.type)) { - indent(ss, indent_size + 1) << "root: " << params_.root << ",\n"; + indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl; } indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl; indent(ss, indent_size) << "}";