diff --git a/csrc/dispatch.h b/csrc/dispatch.h index ad55bd24f94..a99e7cee736 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -107,7 +107,8 @@ class Val; f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ - f(MatmulOp); + f(MatmulOp); \ + f(Communication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ f(Allocate); \ f(Asm); \ diff --git a/csrc/ir/all_nodes.h b/csrc/ir/all_nodes.h index accbb8544f9..b59192a22e2 100644 --- a/csrc/ir/all_nodes.h +++ b/csrc/ir/all_nodes.h @@ -11,3 +11,4 @@ #include #include #include +#include diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 35eb1cd94bb..9ec8c86fb46 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -490,7 +490,7 @@ using newObjectFuncType = Expr*( //! - Constructors need to register with the Fusion after inputs/outputs //! are defined //! - Implementation of bool sameAs(...) -//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val +//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Expr //! 3) Default mutator function should be added to mutator.h/.cpp //! 4) Printing functions should be added to ir/iostream.h/.cpp //! 
5) Lower case convenience functions should be added to arith.h/.cpp (If diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 0bb0ba3cdf8..b7ffc3d0bfd 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -5,7 +5,9 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include +#include #include #if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL) #include @@ -13,6 +15,37 @@ #include namespace nvfuser { + +std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { + switch (type) { + case CommunicationType::Gather: + os << "Gather"; + break; + case CommunicationType::Allgather: + os << "Allgather"; + break; + case CommunicationType::Scatter: + os << "Scatter"; + break; + case CommunicationType::Reduce: + os << "Reduce"; + break; + case CommunicationType::Allreduce: + os << "Allreduce"; + break; + case CommunicationType::ReduceScatter: + os << "ReduceScatter"; + break; + case CommunicationType::Broadcast: + os << "Broadcast"; + break; + case CommunicationType::SendRecv: + os << "SendRecv"; + break; + } + return os; +} + namespace { inline void assertBufferCount( @@ -43,17 +76,6 @@ inline void assertBuffersHaveSameSize( } } -void post_common(Communication& self, Communicator& comm) { - NVF_ERROR( - std::find( - self.params().team.begin(), - self.params().team.end(), - comm.deviceId()) != self.params().team.end(), - "current device index ", - comm.deviceId(), - " must be present in the communication's team"); -} - inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { dst.view_as(src).copy_(src, /*non_blocking=*/true); } @@ -76,37 +98,63 @@ T getInitialValue(c10d::ReduceOp::RedOpType op) { } } -} // namespace +bool hasRoot(CommunicationType type) { + return type == CommunicationType::Gather || + type == CommunicationType::Scatter || + type == CommunicationType::Broadcast || + type == CommunicationType::SendRecv; +} 
-Communication::Communication(CommParams params, std::string name, bool has_root) - : params_(std::move(params)), - collective_type_(std::move(name)), - has_root_(has_root) { +bool isReduction(CommunicationType type) { + return type == CommunicationType::Reduce || + type == CommunicationType::Allreduce || + type == CommunicationType::ReduceScatter; +} + +void assertValid(const CommParams& params) { + std::unordered_set team_without_duplicates( + params.team.begin(), params.team.end()); NVF_ERROR( - std::unique(params_.team.begin(), params_.team.end()) == - params_.team.end(), + team_without_duplicates.size() == params.team.size(), "the communication must not involve the same device more than once"); - NVF_ERROR(!params_.team.empty(), "the team size must be greater than 0"); - if (has_root_) { - auto it = std::find(params_.team.begin(), params_.team.end(), params_.root); + NVF_ERROR(!params.team.empty(), "the team size must be greater than 0"); + if (hasRoot(params.type)) { + auto it = std::find(params.team.begin(), params.team.end(), params.root); NVF_ERROR( - it != params_.team.end(), + it != params.team.end(), "root (device ", - params_.root, + params.root, ") must be present in the communication's team"); - // pytorch's process group expects the root to be specified - // as an integer between 0 and world_size-1. We choose it to be - // the device's relative index within the team - root_relative_index_ = std::distance(params_.team.begin(), it); } } +// pytorch's process group expects the root to be specified +// as an integer between 0 and world_size-1. 
We choose it to be +// the device's relative index within the team +int64_t getRootRelativeIndex(const CommParams& params) { + auto it = std::find(params.team.begin(), params.team.end(), params.root); + return std::distance(params.team.begin(), it); +} + +} // namespace + +Communication::Communication(IrBuilderPasskey passkey, CommParams params) + : Expr(passkey), params_(std::move(params)) { + assertValid(params_); +} + +Communication::Communication(const Communication* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), params_(src->params()) {} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Communication) + std::string Communication::toString(const int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << "Communication " << collective_type_ << ": {" + indent(ss, indent_size) << "Communication " << params_.type << ": {" << std::endl; - if (has_root_) { + + if (hasRoot(params_.type)) { indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl; } indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl; @@ -115,17 +163,36 @@ std::string Communication::toString(const int indent_size) const { return ss.str(); } -Broadcast::Broadcast(CommParams params) : Communication(params, "broadcast") {} +std::string Communication::toInlineString(int indent_size) const { + return toString(indent_size); +} -c10::intrusive_ptr Broadcast::post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); +// TODO add checking symbolic representation of src and dst buffers +bool Communication::sameAs(const Statement* other) const { + if (other == this) { + return true; + } + if (!other->isA()) { + return false; + } + const auto& p1 = this->params(); + const auto& p2 = other->as()->params(); + + return ( + p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) && + p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp)); +} - if (comm.deviceId() == 
params_.root) { - if (params_.is_root_in_mesh) { +namespace { +c10::intrusive_ptr postBroadcast( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + const CommParams& params = communication->params(); + if (my_device_index == params.root) { + if (params.is_root_in_mesh) { // Do a local copy and the subsequent broadcast will be in place. Consider // ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the // local copy to complete. @@ -136,25 +203,23 @@ c10::intrusive_ptr Broadcast::post( } } - if (params_.team.size() == 1) { + if (params.team.size() == 1) { return nullptr; } std::vector tensors({output_tensor}); - return comm.getBackendForTeam(params_.team, backend) - ->broadcast(tensors, {.rootRank = root_relative_index_}); + return backend->broadcast( + tensors, {.rootRank = getRootRelativeIndex(params)}); } -Gather::Gather(CommParams params) : Communication(params, "gather") {} - -c10::intrusive_ptr Gather::post( - Communicator& comm, +c10::intrusive_ptr postGather( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); - - if (comm.deviceId() == params_.root && !params_.is_root_in_mesh) { + at::Tensor output_tensor) { + const CommParams& params = communication->params(); + if (my_device_index == params.root && !params.is_root_in_mesh) { // This is likely a suboptimal way to allocate tensors for nccl. 
To benefit // from zero copy // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html), @@ -165,13 +230,14 @@ c10::intrusive_ptr Gather::post( } std::vector input_tensors({input_tensor}); + auto root_relative_index = getRootRelativeIndex(params); std::vector> output_tensors; - if (comm.deviceId() == params_.root) { + if (my_device_index == params.root) { output_tensors.resize(1); int64_t j = 0; - for (auto i : c10::irange(params_.team.size())) { - if (root_relative_index_ == static_cast(i) && - !params_.is_root_in_mesh) { + for (auto i : c10::irange(params.team.size())) { + if (root_relative_index == static_cast(i) && + !params.is_root_in_mesh) { output_tensors[0].push_back(input_tensor); continue; } @@ -179,56 +245,52 @@ c10::intrusive_ptr Gather::post( j++; } - assertBufferCount(output_tensors[0], params_.team.size()); + assertBufferCount(output_tensors[0], params.team.size()); assertBuffersHaveSameSize(input_tensors, output_tensors[0]); } - return comm.getBackendForTeam(params_.team, backend) - ->gather( - output_tensors, input_tensors, {.rootRank = root_relative_index_}); + return backend->gather( + output_tensors, input_tensors, {.rootRank = root_relative_index}); } -Allgather::Allgather(CommParams params) - : Communication(params, "allgather", false) {} - -c10::intrusive_ptr Allgather::post( - Communicator& comm, +c10::intrusive_ptr postAllgather( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); + at::Tensor output_tensor) { + const CommParams& params = communication->params(); std::vector input_tensors({input_tensor}); std::vector> output_tensors(1); output_tensors[0] = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(output_tensors[0], params_.team.size()); + assertBufferCount(output_tensors[0], params.team.size()); 
assertBuffersHaveSameSize(input_tensors, output_tensors[0]); - return comm.getBackendForTeam(params_.team, backend) - ->allgather(output_tensors, input_tensors, {}); + return backend->allgather(output_tensors, input_tensors, {}); } -Scatter::Scatter(CommParams params) : Communication(params, "scatter") {} - -c10::intrusive_ptr Scatter::post( - Communicator& comm, +c10::intrusive_ptr postScatter( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); + at::Tensor output_tensor) { + const CommParams& params = communication->params(); - if (comm.deviceId() == params_.root && !params_.is_root_in_mesh) { + if (my_device_index == params.root && !params.is_root_in_mesh) { output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); } std::vector output_tensors({output_tensor}); + auto root_relative_index = getRootRelativeIndex(params); std::vector> input_tensors; - if (comm.deviceId() == params_.root) { + if (my_device_index == params.root) { input_tensors.resize(1); int64_t j = 0; - for (auto i : c10::irange(params_.team.size())) { - if (root_relative_index_ == static_cast(i) && - !params_.is_root_in_mesh) { + for (auto i : c10::irange(params.team.size())) { + if (root_relative_index == static_cast(i) && + !params.is_root_in_mesh) { input_tensors.front().push_back(output_tensor); continue; } @@ -236,34 +298,32 @@ c10::intrusive_ptr Scatter::post( j++; } - assertBufferCount(input_tensors[0], params_.team.size()); + assertBufferCount(input_tensors[0], params.team.size()); assertBuffersHaveSameSize(input_tensors[0], output_tensors); } - return comm.getBackendForTeam(params_.team, backend) - ->scatter( - output_tensors, input_tensors, {.rootRank = root_relative_index_}); + return backend->scatter( + output_tensors, input_tensors, {.rootRank = root_relative_index}); } -Reduce::Reduce(CommParams params) : Communication(params, 
"reduce") {} - -c10::intrusive_ptr Reduce::post( - Communicator& comm, +c10::intrusive_ptr postReduce( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); + at::Tensor output_tensor) { + const CommParams& params = communication->params(); at::Tensor tensor; - if (comm.deviceId() == params_.root) { - if (params_.is_root_in_mesh) { + if (my_device_index == params.root) { + if (params.is_root_in_mesh) { doLocalCopy(output_tensor, input_tensor); tensor = output_tensor; } else { NVF_ERROR( output_tensor.scalar_type() == at::kFloat, "only float tensors are supported"); - output_tensor.fill_(getInitialValue(params_.redOp)); + output_tensor.fill_(getInitialValue(params.redOp)); tensor = output_tensor; } } else { @@ -272,81 +332,124 @@ c10::intrusive_ptr Reduce::post( std::vector tensors({tensor}); c10d::ReduceOptions options = { - .reduceOp = params_.redOp, .rootRank = root_relative_index_}; + .reduceOp = params.redOp, .rootRank = getRootRelativeIndex(params)}; // TODO: avoid local copy by using out-of-place reduction. 
- return comm.getBackendForTeam(params_.team, backend) - ->reduce(tensors, options); + return backend->reduce(tensors, options); } -Allreduce::Allreduce(CommParams params) - : Communication(params, "allreduce", false) {} - -c10::intrusive_ptr Allreduce::post( - Communicator& comm, +c10::intrusive_ptr postAllreduce( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); + at::Tensor output_tensor) { + const CommParams& params = communication->params(); doLocalCopy(output_tensor, input_tensor); std::vector output_tensors({output_tensor}); - return comm.getBackendForTeam(params_.team, backend) - ->allreduce(output_tensors, {.reduceOp = params_.redOp}); + return backend->allreduce(output_tensors, {.reduceOp = params.redOp}); } -ReduceScatter::ReduceScatter(CommParams params) - : Communication(params, "reduce_scatter", false) {} - -c10::intrusive_ptr ReduceScatter::post( - Communicator& comm, +c10::intrusive_ptr postReduceScatter( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); + at::Tensor output_tensor) { + const CommParams& params = communication->params(); std::vector> input_tensors(1); NVF_ERROR( - params_.scattered_axis >= 0, + params.scattered_axis >= 0, "scattered_axis is expected to be non-negative: ", - params_.scattered_axis) + params.scattered_axis) input_tensors[0] = - at::split(input_tensor, /*split_size=*/1, /*dim=*/params_.scattered_axis); + at::split(input_tensor, /*split_size=*/1, params.scattered_axis); std::vector output_tensors({output_tensor}); - assertBufferCount(input_tensors[0], params_.team.size()); - return comm.getBackendForTeam(params_.team, backend) - ->reduce_scatter( - output_tensors, input_tensors, {.reduceOp = params_.redOp}); + 
assertBufferCount(input_tensors[0], params.team.size()); + return backend->reduce_scatter( + output_tensors, input_tensors, {.reduceOp = params.redOp}); } -SendRecv::SendRecv(CommParams params) : Communication(params, "send/recv") { +c10::intrusive_ptr postSendRecv( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + const CommParams& params = communication->params(); + NVF_ERROR( - params_.team.size() == 1 || params_.team.size() == 2, + params.team.size() == 1 || params.team.size() == 2, "the team size should be 1 or 2"); -} - -c10::intrusive_ptr SendRecv::post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend) { - post_common(*this, comm); - if (params_.team.size() == 1) { + if (params.team.size() == 1) { doLocalCopy(output_tensor, input_tensor); return nullptr; } - std::vector tensors( - {comm.deviceId() == params_.root ? input_tensor : output_tensor}); - return comm.sendRecv( - /*receiver=*/(params_.team.at(0) == params_.root) ? params_.team.at(1) - : params_.team.at(0), - /*sender=*/params_.root, - tensors, - backend); + const DeviceIdxType sender = params.root; + const DeviceIdxType receiver = + params.team.at(0) == sender ? 
params.team.at(1) : params.team.at(0); + + std::vector tensors; + if (my_device_index == sender) { + tensors = {input_tensor}; + return backend->send(tensors, static_cast(receiver), /*tag=*/0); + } else { + NVF_ERROR(my_device_index == receiver); + tensors = {output_tensor}; + return backend->recv(tensors, static_cast(sender), /*tag=*/0); + } +} +} // namespace + +c10::intrusive_ptr postSingleCommunication( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + const CommParams& params = communication->params(); + NVF_ERROR( + std::find(params.team.begin(), params.team.end(), my_device_index) != + params.team.end(), + "current device index ", + my_device_index, + " must be present in the communication's team"); + + switch (communication->params().type) { + case CommunicationType::Gather: + return postGather( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::Allgather: + return postAllgather( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::Scatter: + return postScatter( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::Reduce: + return postReduce( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::Allreduce: + return postAllreduce( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::ReduceScatter: + return postReduceScatter( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::Broadcast: + return postBroadcast( + communication, my_device_index, backend, input_tensor, output_tensor); + case CommunicationType::SendRecv: + return postSendRecv( + communication, my_device_index, backend, input_tensor, output_tensor); + default: + NVF_ERROR(false, "Wrong communication type: ", params.type); 
+ return nullptr; + } } // namespace nvfuser diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index e4ce834fa2c..bf8a74f194a 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -7,6 +7,8 @@ // clang-format on #pragma once +#include +#include #include #include #ifdef NVFUSER_DISTRIBUTED @@ -19,9 +21,22 @@ namespace nvfuser { -// This struct gathers all the parameters necessary for the -// construction a communication +enum class CommunicationType { + Gather, + Allgather, + Scatter, + Reduce, + Allreduce, + ReduceScatter, + Broadcast, + SendRecv +}; + +std::ostream& operator<<(std::ostream& os, const CommunicationType& type); + +// This struct gathers all the parameters needed to construct a Communication. struct CommParams { + CommunicationType type; DeviceIdxType root = -1; bool is_root_in_mesh = true; Team team; // should not have duplicates and should contain both the root and @@ -52,197 +67,98 @@ struct CommParams { // NOTE: pytorch's NCCL process group API needs buffers on root for // scatter/gather operation. -class Communication { +// (*) Broadcast +// Copies the root's src buffer to each device's dst buffer +// Requirements: +// - the root is set and belongs to the team +// - the root has one src buffer, and no or one dst buffer +// - non-roots have no src buffer and one dst buffer +// - all buffers have the same size +// (*) Gather +// Copies each device's source buffer to the root's respective dst +// buffer. The order of the sender devices matches the order of the +// root's buffers. +// Requirements: +// - the root is set and belongs to the team +// - the root has one src buffer and dst buffers +// - non-roots have one src buffer and no dst buffer +// - all buffers have the same size +// (*) Allgather +// Copies each device's src buffer to each device's respective dst +// buffer. 
The order of the devices matches the order of the +// buffers +// Requirements: +// - all devices have one src buffer and dst buffers +// - all buffers have the same size +// (*) Scatter +// Copies the root's src buffers to each device's dst buffer. +// The order of the buffers matches the order of the receiver devices +// Requirements: +// - the root is set and belongs to the team +// - the root has src buffers and one dst buffer +// - non-roots have no src buffer and one dst buffer +// - all buffers have the same size +// (*) Reduce +// Reduce the src buffers to the root's dst buffer. +// Requirements: +// - the root is set and belongs to the team +// - the root has one src buffer and one dst buffer +// - non-roots have one src buffer and no dst buffer +// - all buffers have the same size +// (*) Allreduce +// Reduce the src buffers to the dst buffer. +// Requirements: +// - all devices have one src buffer and one dst buffer +// - all buffers have the same size +// (*) ReduceScatter +// Reduce all the src buffers and shard the result to the dst buffers. 
+Requirements: +// - all devices have src buffer and one dst buffer +// - all buffers have the same size +// (*) SendRecv +// Copies the sender's src buffers to the receiver's dst buffer +// It is equivalent to a Broadcast with a team of size == 2 +class Communication : public Expr { public: - virtual ~Communication() = default; + using Expr::Expr; + Communication(IrBuilderPasskey passkey, CommParams params); + Communication(const Communication* src, IrCloner* ir_cloner); - std::string toString(int indent = 0) const; + Communication(const Communication& other) = delete; + Communication& operator=(const Communication& other) = delete; + Communication(Communication&& other) = delete; + Communication& operator=(Communication&& other) = delete; - const auto& params() const { - return params_; + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { return "Communication"; } - // Triggers the execution of the communication. This is a non-blocking call. - // The communication can be posted multiple times - virtual c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) = 0; + // TODO: add params_ (or flattened parameters) as data attributes so this and + // the constructor that takes IrCloner aren't needed. + bool sameAs(const Statement* other) const override; - protected: - // argument "name" is only used for printing - // argument "has_root" indicates if the communication is rooted - Communication(CommParams params, std::string name, bool has_root = true); + // TODO: const CommParams&. 
+ auto params() const { + return params_; + } + private: // store the arguments of the communication CommParams params_; - // stores the relative index of the root in the team - DeviceIdxType root_relative_index_ = -1; - - private: - // used for printing - std::string collective_type_; - // FIXME: this seems to be redundant with `root_relative_index_`. - // indicates if the communication is rooted - bool has_root_ = true; -}; - -/* -Copies the root's src buffer to each device's dst buffer - -Requirements: - - the root is set and belongs to the team - - the root has one src buffer, and no or one dst buffer - - non-roots have no src buffer and one dst buffer - - all buffers have the same size -*/ -class Broadcast : public Communication { - public: - Broadcast(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; - -/* -Copies each device's source buffer to the root's respective src -buffer. The order of the sender devices matches the order of the -root's buffers. - -Requirements: - - the root is set and belongs to the team - - the root has one src buffer and dst buffers - - non-roots have one src buffer and no dst buffer - - all buffers have the same size -*/ -class Gather : public Communication { - public: - Gather(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; - -/* -Copies each device's src buffer to each device's respective src -buffer. 
The order of the devices matches the order of the -buffers - -Requirements: - - all device have one src buffer and dst buffers - - all buffers have the same size -*/ -class Allgather : public Communication { - public: - Allgather(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; }; -/* -Copies each root's src buffer to each device's dst buffer. -The order of the buffers matches the order of the receiver devices - -Requirements: - - the root is set and belongs to the team - - the root has src buffers and one dst buffer - - non-roots have no src buffer and one dst buffer - - all buffers have the same size -*/ -class Scatter : public Communication { - public: - Scatter(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; - -/* -Reduce the src buffers to the root's dst buffer. - -Requirements: - - the root is set and belongs to the team - - the root has one src buffers and one dst buffer - - non-roots have one src buffer and no dst buffer - - all buffers have the same size -*/ -class Reduce : public Communication { - public: - Reduce(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; - -/* -Reduce the src buffers to the dst buffer. - -Requirements: - - all devices have one src buffer and one dst buffer - - all buffers have the same size -*/ -class Allreduce : public Communication { - public: - Allreduce(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; - -/* -Reduce all the src buffers and shard the result to the dst buffers. 
- -Requirements: - - all devices have src buffer and one dst buffer - - all buffers have the same size -*/ -class ReduceScatter : public Communication { - public: - ReduceScatter(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; - -/* -Copies the sender's src buffers to the receiver's dst buffer -It is equivalent to a Broadcast with a team of size == 2 - -Requirements: - - the team must be of size 2 or 1 (in which case the SendRecv reduces to a -local copy) - - all buffers have the same size - - the root is set and belongs to the team. The "root" corresponds to the -sender - - If the team size the root has one src buffers and no dst buffer (or one in -case of a local copy) - - If team is of size 2, the unique non-root have no src buffer and one dst -buffer -*/ -class SendRecv : public Communication { - public: - SendRecv(CommParams params); - c10::intrusive_ptr post( - Communicator& comm, - at::Tensor input_tensor, - at::Tensor output_tensor, - std::optional backend = std::nullopt) override; -}; +// Triggers the execution of the communication. This is a non-blocking call. +// The communication can be posted multiple times +// TODO: c10d::Backend* should be sufficient. 
+c10::intrusive_ptr postSingleCommunication( + Communication* communication, + DeviceIdxType my_device_index, + c10::intrusive_ptr backend, + at::Tensor input_tensor, + at::Tensor output_tensor); } // namespace nvfuser diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index b113edd61d6..7ce1bcbbdd8 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -244,24 +244,6 @@ c10::intrusive_ptr Communicator::getBackendForTeam( return backends_.at(team_key); } -c10::intrusive_ptr Communicator::sendRecv( - DeviceIdxType receiver, - DeviceIdxType sender, - std::vector& tensors, - std::optional backend, - int tag) { - NVF_ERROR( - deviceId() == sender || deviceId() == receiver, - "only sender or receiver should post the sendRecv"); - NVF_ERROR(sender != receiver, "cannot send to self"); - - auto world = getWorld(backend); - if (deviceId() == sender) { - return world->send(tensors, static_cast(dIdToRank(receiver)), tag); - } - return world->recv(tensors, static_cast(dIdToRank(sender)), tag); -} - c10::intrusive_ptr Communicator::getWorld( std::optional backend) { std::vector all_ranks(size_); diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index f8a960bffd6..8a57369ab99 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -80,14 +80,6 @@ class Communicator { default_backend_ = backend; } - // performs a send/receive p2p data transfer - c10::intrusive_ptr sendRecv( - DeviceIdxType receiver, - DeviceIdxType sender, - std::vector& tensor, - std::optional backend = std::nullopt, - int tag = 0); - // performs a blocking barrier in the communicator void barrier(std::optional backend = std::nullopt) { getWorld(backend)->barrier()->wait(); diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 35b35cebac9..fb33eac3bf2 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -176,9 +176,11 @@ void 
MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { } // post and wait communications - for (auto& communication : communications) { - c10::intrusive_ptr work = - communication->post(comm_, input_tensor, output_tensor); + for (Communication* communication : communications) { + c10::intrusive_ptr backend = + comm_.getBackendForTeam(communication->params().team, std::nullopt); + c10::intrusive_ptr work = postSingleCommunication( + communication, comm_.deviceId(), backend, input_tensor, output_tensor); if (work != nullptr) { work->wait(); } diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 7fb6d4c7f2d..6de4bd78376 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include #include @@ -66,6 +67,8 @@ CommParams createParamsForGatherScatter( bool is_scatter) { const DeviceMesh& mesh = root_tv->getDeviceMesh(); CommParams params; + params.type = + is_scatter ? 
CommunicationType::Scatter : CommunicationType::Gather; params.root = root; params.team = mesh.vector(); if (!mesh.has(root)) { @@ -80,7 +83,7 @@ void lowerToScatter( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - std::vector>& comms) { + std::vector& comms) { // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); auto root = input_tv->getDeviceMesh().vector().at(0); @@ -89,7 +92,7 @@ void lowerToScatter( } auto params = createParamsForGatherScatter(my_device_index, root, output_tv, true); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } /* @@ -102,7 +105,7 @@ void lowerToGather( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - std::vector>& comms) { + std::vector& comms) { // we create as many 'Gathers' as there are devices in the receiver mesh const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); for (auto root : output_tv->getDeviceMesh().vector()) { @@ -111,7 +114,7 @@ void lowerToGather( } auto params = createParamsForGatherScatter(my_device_index, root, input_tv, false); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } } @@ -120,15 +123,16 @@ void lowerToAllgather( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } CommParams params; + params.type = CommunicationType::Allgather; params.team = mesh.vector(); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } // Creates and set the CommParams for a Broadcast or Send/Recv communication @@ -144,6 +148,7 @@ CommParams createParamsForBroadcastOrP2P( params.is_root_in_mesh = false; params.team.push_back(root); } 
+ params.type = CommunicationType::Broadcast; return params; } @@ -153,18 +158,12 @@ void lowerToBroadcastOrP2P( DeviceIdxType my_device_index, DeviceIdxType root, const DeviceMesh& mesh, // receiver devices - std::vector>& comms) { + std::vector& comms) { if (!isDeviceInvolved(my_device_index, root, mesh)) { return; } auto params = createParamsForBroadcastOrP2P(my_device_index, root, mesh); - std::shared_ptr comm; - if (mesh.vector().size() == 1) { - comm = std::make_shared(std::move(params)); - } else { - comm = std::make_shared(std::move(params)); - } - comms.push_back(comm); + comms.push_back(IrBuilder::create(std::move(params))); } // Adds several Broadcast or Send/Recv communications to the vector 'comms' @@ -176,7 +175,7 @@ void lowerToBroadcastOrP2P( TensorView* input_tv, TensorView* output_tv, bool is_sharded, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); if (is_sharded) { @@ -207,6 +206,7 @@ CommParams createParamsForReduce( BinaryOpType op_type) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); CommParams params; + params.type = CommunicationType::Reduce; params.root = root; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); @@ -223,7 +223,7 @@ void lowerToReduce( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); // we create as many Reduces as there are devices in the receiver mesh @@ -233,7 +233,7 @@ void lowerToReduce( } auto params = createParamsForReduce( my_device_index, root, input_tv, output_tv, op_type); - comms.push_back(std::make_shared(std::move(params))); + comms.push_back(IrBuilder::create(std::move(params))); } } @@ -242,16 +242,17 @@ void lowerToAllreduce( TensorView* input_tv, TensorView* output_tv, 
BinaryOpType op_type, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } CommParams params; + params.type = CommunicationType::Allreduce; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); - comms.push_back(std::make_shared(params)); + comms.push_back(IrBuilder::create(params)); } void lowerToReduceScatter( @@ -259,13 +260,14 @@ void lowerToReduceScatter( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector>& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } CommParams params; + params.type = CommunicationType::ReduceScatter; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); auto reduction_axis = output_tv->getReductionAxis().value(); @@ -279,7 +281,7 @@ void lowerToReduceScatter( } params.scattered_axis = scattered_axis; - comms.push_back(std::make_shared(params)); + comms.push_back(IrBuilder::create(params)); } } // namespace @@ -293,10 +295,10 @@ void lowerToReduceScatter( sources *) Leverage the topology to ensure that the senders and recerivers are close */ -std::vector> lowerCommunication( +std::vector lowerCommunication( DeviceIdxType my_device_index, Expr* c) { - std::vector> comms; + std::vector comms; NVF_ERROR( c->inputs().size() == 1 && c->inputs().at(0)->isA() && c->outputs().size() == 1 && c->outputs().at(0)->isA(), diff --git a/csrc/multidevice/lower_communication.h b/csrc/multidevice/lower_communication.h index 172afc9807b..f6dcf8879e3 100644 --- a/csrc/multidevice/lower_communication.h +++ b/csrc/multidevice/lower_communication.h @@ -19,7 +19,7 @@ bool isLowerableToCommunication(Expr* expr); // Lower a PipelineCommunication into a series of Communication, given a // device_index. 
-std::vector> lowerCommunication( +std::vector lowerCommunication( DeviceIdxType device_index, Expr* c); diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 8c9c042e6e0..eaa0b3a3e93 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -7,6 +7,8 @@ // clang-format on #include +#include +#include #include #include #include @@ -35,11 +37,14 @@ class CommunicationTest c10d::ReduceOp::RedOpType::SUM; CommParams params; std::vector all_ranks; + c10::intrusive_ptr backend; + IrContainer container; }; CommunicationTest::CommunicationTest() { all_ranks = std::vector(communicator->size()); std::iota(all_ranks.begin(), all_ranks.end(), 0); + backend = communicator->getBackendForTeam(all_ranks, GetParam()); } void CommunicationTest::SetUp() { @@ -58,9 +63,10 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { } TEST_P(CommunicationTest, Gather) { + params.type = CommunicationType::Gather; params.root = root; params.team = all_ranks; - auto communication = Gather(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = @@ -69,8 +75,12 @@ TEST_P(CommunicationTest, Gather) { input_tensor.copy_( at::arange(tensor_size, tensor_options).unsqueeze(0) + (communicator->deviceId() + 1) * repetition); - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); if (communicator->deviceId() == root) { @@ -83,8 +93,9 @@ TEST_P(CommunicationTest, Gather) { } TEST_P(CommunicationTest, Allgather) { + params.type = CommunicationType::Allgather; params.team = all_ranks; - auto communication = Allgather(params); + auto communication = IrBuilder::create(&container, 
params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = @@ -94,8 +105,12 @@ TEST_P(CommunicationTest, Allgather) { at::arange(tensor_size, tensor_options).unsqueeze(0) + (communicator->deviceId() + 1) * repetition); - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); at::Tensor ref = at::arange(tensor_size, tensor_options).unsqueeze(0) + @@ -106,9 +121,10 @@ TEST_P(CommunicationTest, Allgather) { } TEST_P(CommunicationTest, Scatter) { + params.type = CommunicationType::Scatter; params.root = root; params.team = all_ranks; - auto communication = Scatter(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; if (communicator->deviceId() == root) { @@ -125,8 +141,12 @@ TEST_P(CommunicationTest, Scatter) { repetition); } - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); auto ref = at::arange(tensor_size, tensor_options).unsqueeze(0) + @@ -136,9 +156,10 @@ TEST_P(CommunicationTest, Scatter) { } TEST_P(CommunicationTest, Broadcast) { + params.type = CommunicationType::Broadcast; params.root = root; params.team = all_ranks; - auto communication = Broadcast(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; if (communicator->deviceId() == root) { @@ -150,8 +171,12 @@ TEST_P(CommunicationTest, Broadcast) { input_tensor.copy_(at::arange(tensor_size, tensor_options) + repetition); } - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + 
input_tensor, + output_tensor); if (work != nullptr) { work->wait(); } @@ -162,6 +187,9 @@ TEST_P(CommunicationTest, Broadcast) { } TEST_P(CommunicationTest, SendRecv) { + if (GetParam() == CommunicatorBackend::ucc) { + GTEST_SKIP() << "Disabling because of UCC hangs, see issue #2091"; + } if (communicator->size() < 2 || torch::cuda::device_count() < 2) { GTEST_SKIP() << "This test needs at least 2 GPUs and 2 ranks."; } @@ -173,9 +201,10 @@ TEST_P(CommunicationTest, SendRecv) { return; } + params.type = CommunicationType::SendRecv; params.root = sender; params.team = {0, 1}; - auto communication = SendRecv(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor; at::Tensor output_tensor; @@ -191,8 +220,12 @@ TEST_P(CommunicationTest, SendRecv) { input_tensor.copy_(at::arange(tensor_size, tensor_options) + repetition); } - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); if (communicator->deviceId() == receiver) { @@ -209,9 +242,10 @@ TEST_P(CommunicationTest, SendRecvToSelf) { return; } + params.type = CommunicationType::SendRecv; params.root = sender; params.team = {0}; - auto communication = SendRecv(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({tensor_size}, tensor_options); at::Tensor output_tensor = at::empty_like(input_tensor); @@ -219,7 +253,12 @@ TEST_P(CommunicationTest, SendRecvToSelf) { for (auto repetition : c10::irange(num_repetitions)) { input_tensor.copy_(at::arange(tensor_size, tensor_options) + repetition); - communication.post(*communicator, input_tensor, output_tensor, GetParam()); + postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); auto ref = at::arange(tensor_size, tensor_options) + repetition; 
validate(output_tensor, ref); @@ -227,10 +266,11 @@ TEST_P(CommunicationTest, SendRecvToSelf) { } TEST_P(CommunicationTest, Reduce) { + params.type = CommunicationType::Reduce; params.redOp = red_op; params.root = root; params.team = all_ranks; - auto communication = Reduce(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); @@ -240,8 +280,12 @@ TEST_P(CommunicationTest, Reduce) { at::arange(tensor_size, tensor_options).unsqueeze(0) + (communicator->deviceId() + 1) * repetition); - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); if (communicator->deviceId() == root) { @@ -254,9 +298,10 @@ TEST_P(CommunicationTest, Reduce) { } TEST_P(CommunicationTest, Allreduce) { + params.type = CommunicationType::Allreduce; params.redOp = red_op; params.team = all_ranks; - auto communication = Allreduce(params); + auto communication = IrBuilder::create(&container, params); at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); @@ -265,8 +310,12 @@ TEST_P(CommunicationTest, Allreduce) { at::arange(tensor_size, tensor_options).unsqueeze(0) + (communicator->deviceId() + 1) * repetition); - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); const int s = communicator->size(); @@ -277,11 +326,12 @@ TEST_P(CommunicationTest, Allreduce) { } TEST_P(CommunicationTest, ReduceScatter) { + params.type = CommunicationType::ReduceScatter; params.redOp = red_op; params.root = root; params.team 
= all_ranks; params.scattered_axis = 1; - auto communication = ReduceScatter(params); + auto communication = IrBuilder::create(&container, params); const int num_devices = communicator->size(); const int device_id = communicator->deviceId(); @@ -299,8 +349,12 @@ TEST_P(CommunicationTest, ReduceScatter) { unsharded_input_tensor.copy_(at::randint( 2, {num_devices, num_devices, tensor_size}, tensor_options)); - auto work = communication.post( - *communicator, input_tensor, output_tensor, GetParam()); + auto work = postSingleCommunication( + communication, + communicator->deviceId(), + backend, + input_tensor, + output_tensor); work->wait(); auto ref =