diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 4f4db1711ab..e5694c8340c 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -32,15 +32,17 @@ inline void assertBuffersHaveSameSize( if (bufs1.empty() && bufs2.empty()) { return; } - auto sizes = (bufs1.empty() ? bufs2 : bufs1).at(0).sizes(); - for (auto& bufs : {bufs1, bufs2}) { - for (auto& buf : bufs) { - NVF_ERROR(buf.sizes() == sizes, "all buffers must have the same size"); + const auto numel = (bufs1.empty() ? bufs2 : bufs1).at(0).numel(); + for (const auto& bufs : {bufs1, bufs2}) { + for (const auto& buf : bufs) { + NVF_ERROR( + buf.numel() == numel, + "all buffers must have the same number of elements"); } } } -inline void post_common(Communication& self, Communicator& comm) { +void post_common(Communication& self, Communicator& comm) { NVF_ERROR( std::find( self.params().team.begin(), @@ -52,7 +54,25 @@ inline void post_common(Communication& self, Communicator& comm) { } inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) { - dst.copy_(src, /* non-blocking */ true); + dst.view_as(src).copy_(src, /*non_blocking=*/true); +} + +template +T getInitialValue(c10d::ReduceOp::RedOpType op) { + // TODO: add other ops + switch (op) { + case c10d::ReduceOp::RedOpType::SUM: + return 0; + case c10d::ReduceOp::RedOpType::PRODUCT: + return 1; + case c10d::ReduceOp::RedOpType::MAX: + return std::numeric_limits::min(); + case c10d::ReduceOp::RedOpType::MIN: + return std::numeric_limits::max(); + default: + NVF_ERROR(false, "unsupported reduction op type"); + return 0; + } } } // namespace @@ -61,7 +81,6 @@ Communication::Communication(CommParams params, std::string name, bool has_root) : params_(std::move(params)), collective_type_(std::move(name)), has_root_(has_root) { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); NVF_ERROR( std::unique(params_.team.begin(), params_.team.end()) == params_.team.end(), @@ 
-97,11 +116,6 @@ std::string Communication::toString(int indent) const { ss << r << ", "; } ss << indent1 << "}\n"; - ss << indent1 << "src_bufs: {"; - for (auto& t : params_.src_bufs) { - ss << "\n" << t; - } - ss << "\n" << indent1 << "}\n"; ss << ext_indent << "}"; return ss.str(); @@ -111,180 +125,209 @@ Broadcast::Broadcast(CommParams params) : Communication(params, "broadcast") {} c10::intrusive_ptr Broadcast::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); if (comm.deviceId() == params_.root) { - assertBufferCount(params_.src_bufs, 1); - if (params_.dst_bufs.size() == 1) { - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); + if (params_.is_root_in_mesh) { + // Do a local copy and the subsequent broadcast will be in place. Consider + // ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the + // local copy to complete. + doLocalCopy(output_tensor, input_tensor); } else { - assertBufferCount(params_.dst_bufs, 0); + // `output_tensor` isn't allocated for this device. + output_tensor = input_tensor; } - } else { - assertBufferCount(params_.src_bufs, 0); - assertBufferCount(params_.dst_bufs, 1); } if (params_.team.size() == 1) { return nullptr; } + std::vector tensors({output_tensor}); return comm.getBackendForTeam(params_.team, backend) - ->broadcast( - comm.deviceId() == params_.root ? 
params_.src_bufs : params_.dst_bufs, - {.rootRank = root_relative_index_}); + ->broadcast(tensors, {.rootRank = root_relative_index_}); } -Gather::Gather(CommParams params) : Communication(params, "gather") { - assertBufferCount(params_.src_bufs, 1); -} +Gather::Gather(CommParams params) : Communication(params, "gather") {} c10::intrusive_ptr Gather::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list; - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.dst_bufs, params_.team.size()); - buf_list = {std::move(params_.dst_bufs)}; - } else { - assertBufferCount(params_.dst_bufs, 0); + + if (comm.deviceId() == params_.root && !params_.is_root_in_mesh) { + // This is likely a suboptimal way to allocate tensors for nccl. To benefit + // from zero copy + // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html), + // tensors for nccl should be `ncclMemAlloc`ed and be `ncclCommRegister`ed. + // https://github.com/pytorch/pytorch/issues/124807 is one proposal trying + // to partially address this problem. 
+ input_tensor = at::empty_like(output_tensor.slice(0, 0, 1)); } - auto work = - comm.getBackendForTeam(params_.team, backend) - ->gather( - buf_list, params_.src_bufs, {.rootRank = root_relative_index_}); + std::vector input_tensors({input_tensor}); + + std::vector> output_tensors; if (comm.deviceId() == params_.root) { - params_.dst_bufs = std::move(buf_list.back()); + output_tensors.resize(1); + int64_t j = 0; + for (auto i : c10::irange(params_.team.size())) { + if (root_relative_index_ == static_cast(i) && + !params_.is_root_in_mesh) { + output_tensors[0].push_back(input_tensor); + continue; + } + output_tensors[0].push_back(output_tensor.slice(0, j, j + 1)); + j++; + } + + assertBufferCount(output_tensors[0], params_.team.size()); + assertBuffersHaveSameSize(input_tensors, output_tensors[0]); } - return work; + + return comm.getBackendForTeam(params_.team, backend) + ->gather( + output_tensors, input_tensors, {.rootRank = root_relative_index_}); } Allgather::Allgather(CommParams params) - : Communication(params, "allgather", false) { - assertBufferCount(params_.src_bufs, 1); - assertBufferCount(params_.dst_bufs, params_.team.size()); -} + : Communication(params, "allgather", false) {} c10::intrusive_ptr Allgather::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list; - buf_list = {std::move(params_.dst_bufs)}; - auto work = comm.getBackendForTeam(params_.team, backend) - ->allgather(buf_list, params_.src_bufs, {}); - params_.dst_bufs = std::move(buf_list.back()); - return work; -} -Scatter::Scatter(CommParams params) : Communication(params, "scatter") { - assertBufferCount(params_.dst_bufs, 1); + std::vector input_tensors({input_tensor}); + std::vector> output_tensors(1); + output_tensors[0] = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); + + 
assertBufferCount(output_tensors[0], params_.team.size()); + assertBuffersHaveSameSize(input_tensors, output_tensors[0]); + return comm.getBackendForTeam(params_.team, backend) + ->allgather(output_tensors, input_tensors, {}); } +Scatter::Scatter(CommParams params) : Communication(params, "scatter") {} + c10::intrusive_ptr Scatter::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list; - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.src_bufs, params_.team.size()); - buf_list = {std::move(params_.src_bufs)}; - } else { - assertBufferCount(params_.src_bufs, 0); + + if (comm.deviceId() == params_.root && !params_.is_root_in_mesh) { + output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); } - auto work = - comm.getBackendForTeam(params_.team, backend) - ->scatter( - params_.dst_bufs, buf_list, {.rootRank = root_relative_index_}); + std::vector output_tensors({output_tensor}); + + std::vector> input_tensors; if (comm.deviceId() == params_.root) { - params_.src_bufs = std::move(buf_list.back()); + input_tensors.resize(1); + int64_t j = 0; + for (auto i : c10::irange(params_.team.size())) { + if (root_relative_index_ == static_cast(i) && + !params_.is_root_in_mesh) { + input_tensors.front().push_back(output_tensor); + continue; + } + input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); + j++; + } + + assertBufferCount(input_tensors[0], params_.team.size()); + assertBuffersHaveSameSize(input_tensors[0], output_tensors); } - return work; -} -Reduce::Reduce(CommParams params) : Communication(params, "reduce") { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); - assertBufferCount(params_.src_bufs, 1); + return comm.getBackendForTeam(params_.team, backend) + ->scatter( + output_tensors, input_tensors, {.rootRank = 
root_relative_index_}); } +Reduce::Reduce(CommParams params) : Communication(params, "reduce") {} + c10::intrusive_ptr Reduce::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { + post_common(*this, comm); + + at::Tensor tensor; if (comm.deviceId() == params_.root) { - assertBufferCount(params_.dst_bufs, 1); + if (params_.is_root_in_mesh) { + doLocalCopy(output_tensor, input_tensor); + tensor = output_tensor; + } else { + NVF_ERROR( + output_tensor.scalar_type() == at::kFloat, + "only float tensors are supported"); + output_tensor.fill_(getInitialValue(params_.redOp)); + tensor = output_tensor; + } } else { - assertBufferCount(params_.dst_bufs, 0); + tensor = input_tensor; } - post_common(*this, comm); - auto& buf = - (comm.deviceId() == params_.root) ? params_.dst_bufs : params_.src_bufs; + std::vector tensors({tensor}); + c10d::ReduceOptions options = { .reduceOp = params_.redOp, .rootRank = root_relative_index_}; - auto team_backend = comm.getBackendForTeam(params_.team, backend); -#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL) - auto nccl_backend = dynamic_cast(team_backend.get()); - if (nccl_backend) { -#if NVF_TORCH_VERSION_NO_LESS(2, 3, 0) - // API change https://github.com/pytorch/pytorch/pull/119421 - return nccl_backend->_reduce_oop( - buf.at(0), params_.src_bufs.at(0), options); -#else - return nccl_backend->_reduce_oop(buf, params_.src_bufs, options); -#endif - } -#endif - if (comm.deviceId() == params_.root) { - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); - } - return team_backend->reduce(buf, options); + // TODO: avoid local copy by using out-of-place reduction. 
+ return comm.getBackendForTeam(params_.team, backend) + ->reduce(tensors, options); } Allreduce::Allreduce(CommParams params) - : Communication(params, "allreduce", false) { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); - assertBufferCount(params_.src_bufs, 1); - assertBufferCount(params_.dst_bufs, 1); -} + : Communication(params, "allreduce", false) {} c10::intrusive_ptr Allreduce::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); + + doLocalCopy(output_tensor, input_tensor); + std::vector output_tensors({output_tensor}); + return comm.getBackendForTeam(params_.team, backend) - ->allreduce(params_.dst_bufs, {.reduceOp = params_.redOp}); + ->allreduce(output_tensors, {.reduceOp = params_.redOp}); } ReduceScatter::ReduceScatter(CommParams params) - : Communication(params, "reduce_scatter", false) { - assertBufferCount(params_.src_bufs, params_.team.size()); - assertBufferCount(params_.dst_bufs, 1); -} + : Communication(params, "reduce_scatter", false) {} c10::intrusive_ptr ReduceScatter::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); - // This is used to change the representation of the buffers to match c10d - // ProcessGroup API - std::vector> buf_list = {std::move(params_.src_bufs)}; - auto work = comm.getBackendForTeam(params_.team, backend) - ->reduce_scatter( - params_.dst_bufs, buf_list, {.reduceOp = params_.redOp}); - params_.src_bufs = std::move(buf_list.back()); - return work; + + std::vector> input_tensors(1); + NVF_ERROR( + params_.scattered_axis >= 0, + "scattered_axis is expected to be non-negative: ", + params_.scattered_axis) + input_tensors[0] = + at::split(input_tensor, /*split_size=*/1, /*dim=*/params_.scattered_axis); + + std::vector output_tensors({output_tensor}); + + 
assertBufferCount(input_tensors[0], params_.team.size()); + return comm.getBackendForTeam(params_.team, backend) + ->reduce_scatter( + output_tensors, input_tensors, {.reduceOp = params_.redOp}); } SendRecv::SendRecv(CommParams params) : Communication(params, "send/recv") { - assertBuffersHaveSameSize(params_.src_bufs, params_.dst_bufs); NVF_ERROR( params_.team.size() == 1 || params_.team.size() == 2, "the team size should be 1 or 2"); @@ -292,28 +335,23 @@ SendRecv::SendRecv(CommParams params) : Communication(params, "send/recv") { c10::intrusive_ptr SendRecv::post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend) { post_common(*this, comm); - if (comm.deviceId() == params_.root) { - assertBufferCount(params_.src_bufs, 1); - if (params_.team.size() == 1) { - assertBufferCount(params_.dst_bufs, 1); - doLocalCopy(params_.dst_bufs.at(0), params_.src_bufs.at(0)); - return nullptr; - } else { - assertBufferCount(params_.dst_bufs, 0); - } - } else { - assertBufferCount(params_.src_bufs, 0); - assertBufferCount(params_.dst_bufs, 1); + if (params_.team.size() == 1) { + doLocalCopy(output_tensor, input_tensor); + return nullptr; } + std::vector tensors( + {comm.deviceId() == params_.root ? input_tensor : output_tensor}); return comm.sendRecv( - (params_.team.at(0) == params_.root) ? params_.team.at(1) - : params_.team.at(0), - params_.root, - params_.dst_bufs.empty() ? params_.src_bufs : params_.dst_bufs, + /*receiver=*/(params_.team.at(0) == params_.root) ? 
params_.team.at(1) + : params_.team.at(0), + /*sender=*/params_.root, + tensors, backend); } diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 8594c89ff5b..e4ce834fa2c 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -19,43 +19,39 @@ namespace nvfuser { -/* - This struct gathers all the parameters necessary for the - construction a communication -*/ +// This struct gathers all the parameters necessary for the +// construction of a communication struct CommParams { DeviceIdxType root = -1; - std::vector src_bufs; - std::vector dst_bufs; - Team team; // should not have duplicate + bool is_root_in_mesh = true; + Team team; // should not have duplicates and should contain both the root and + // the mesh c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED; + int64_t scattered_axis = -1; }; -/* -The class "Communication" represents a MPI-style communication -communication operation to be executed on the network. The base class -Communication should not be used directly but through its derived classes: -Broadcast, Gather, Scatter, Allgather, and SendRecv. Other collectives will be -added later. - -CommParams contains the arguments for the communication constructors. -Note that each process (associated with a device index given by -communicator.deviceId()) will fill CommParams with different arguments, -depending on the role they play in this communication. For example, the root of -a Gather communication will have destination buffers, whereas -non-root will have no destination buffers. Also, the ranks not participating in -the communication should not instantiate it. - -The method "post" triggers the execution of the communication. This call is -non-blocking. The communication can be posted multiple times. -It is assumed that the current device_index (given by -communicator.deviceId()) belongs to the team of the communication, -otherwise an error is thrown. 
- -NOTE: pytorch's NCCL process group API needs buffers on root for -scatter/gather operation. -*/ - +// The class "Communication" represents an MPI-style communication +// operation to be executed on the network. The base class +// Communication should not be used directly but through its derived classes: +// Broadcast, Gather, Scatter, Allgather, and SendRecv. Other collectives will +// be added later. + +// CommParams contains the arguments for the communication constructors. +// Note that each process (associated with a device index given by +// communicator.deviceId()) will fill CommParams with different arguments, +// depending on the role they play in this communication. For example, the root +// of a Gather communication will have destination buffers, whereas +// non-root will have no destination buffers. Also, the ranks not participating +// in the communication should not instantiate it. + +// The method "post" triggers the execution of the communication. This call is +// non-blocking. The communication can be posted multiple times. +// It is assumed that the current device_index (given by +// communicator.deviceId()) belongs to the team of the communication, +// otherwise an error is thrown. + +// NOTE: pytorch's NCCL process group API needs buffers on root for +// scatter/gather operation. class Communication { public: virtual ~Communication() = default; @@ -70,6 +66,8 @@ class Communication { // The communication can be posted multiple times virtual c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) = 0; protected: @@ -85,6 +83,7 @@ class Communication { private: // used for printing std::string collective_type_; + // FIXME: this seems to be redundant with `root_relative_index_`. 
// indicates if the communication is rooted bool has_root_ = true; }; @@ -103,6 +102,8 @@ class Broadcast : public Communication { Broadcast(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -122,6 +123,8 @@ class Gather : public Communication { Gather(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -139,6 +142,8 @@ class Allgather : public Communication { Allgather(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -157,6 +162,8 @@ class Scatter : public Communication { Scatter(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -174,6 +181,8 @@ class Reduce : public Communication { Reduce(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -189,6 +198,8 @@ class Allreduce : public Communication { Allreduce(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -204,6 +215,8 @@ class ReduceScatter : public Communication { ReduceScatter(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; @@ -227,6 +240,8 @@ class SendRecv : public Communication { SendRecv(CommParams params); c10::intrusive_ptr post( Communicator& comm, + at::Tensor input_tensor, + at::Tensor output_tensor, std::optional backend = std::nullopt) override; }; 
diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 448e232098a..35b35cebac9 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -160,23 +160,26 @@ void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) { NVF_ERROR( expr->outputs().size() == 1, "Communication must have exactly one output"); + + auto communications = lowerCommunication(comm_.deviceId(), expr); + + // Compute input_tensor and output_tensor. auto input_val = expr->inputs().at(0); auto output_val = expr->outputs().at(0); - at::Tensor input_tensor, output_tensor; + at::Tensor input_tensor; if (val_to_IValue_.find(input_val) != val_to_IValue_.end()) { input_tensor = val_to_IValue_.at(input_val).toTensor(); } + at::Tensor output_tensor; if (val_to_IValue_.find(output_val) != val_to_IValue_.end()) { output_tensor = val_to_IValue_.at(output_val).toTensor(); } - auto communications = - lowerCommunication(comm_.deviceId(), expr, input_tensor, output_tensor); - // post and wait communications for (auto& communication : communications) { - auto work = communication->post(comm_); - if (work) { + c10::intrusive_ptr work = + communication->post(comm_, input_tensor, output_tensor); + if (work != nullptr) { work->wait(); } } diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 3d814378dfc..7fb6d4c7f2d 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -17,28 +17,6 @@ namespace nvfuser { namespace { -template -inline T getInitialValue(BinaryOpType op) { - // TODO: add other ops - switch (op) { - case BinaryOpType::Add: - // case BinaryOpType::BitwiseOr: - // case BinaryOpType::BitwiseXor: - return 0; - case BinaryOpType::Mul: - return 1; - case BinaryOpType::Max: - return std::numeric_limits::min(); - case BinaryOpType::Min: - return std::numeric_limits::max(); - // case BinaryOpType::BitwiseAnd: - // return ~(T)0; - default: - NVF_ERROR(false, 
"invalid binary op type"); - return 0; - } -} - // TODO: handle `c10d::RedOpType::reduceOp::AVG` and // `c10d::RedOpType::reduceOp::PREMUL_SUM` inline c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { @@ -77,22 +55,6 @@ inline bool isDeviceInvolved( return sender_mesh.has(my_device_index) || receiver_mesh.has(my_device_index); } -// Creates a dummy tensor for scatter/gather communications, -// see 'createParamsForGatherScatter' -inline at::Tensor createDummyTensor(at::Tensor reference) { - return at::empty_like(reference, reference.options()); -} - -inline at::Tensor createDummyTensor( - at::Tensor reference, - BinaryOpType op_type) { - // TODO: support other types - NVF_ERROR( - reference.scalar_type() == at::kFloat, - "only float tensors are supported"); - return createDummyTensor(reference).fill_(getInitialValue(op_type)); -} - // Utility function used for setting up a scatter or gather communication // params. Since most of the steps are somewhat similar/opposite in those // cases, we gathered the two implementations into one function. The argument @@ -101,37 +63,15 @@ CommParams createParamsForGatherScatter( DeviceIdxType my_device_index, DeviceIdxType root, TensorView* root_tv, // is_scatter ? input_tv : output_tv - at::Tensor root_buf, // is_scatter? input buf : output buf - at::Tensor buf, // is_scatter? output buf : input buf bool is_scatter) { const DeviceMesh& mesh = root_tv->getDeviceMesh(); CommParams params; params.root = root; params.team = mesh.vector(); - bool is_root_in_mesh = mesh.has(root); - if (!is_root_in_mesh) { + if (!mesh.has(root)) { + params.is_root_in_mesh = false; params.team.push_back(root); } - - if (mesh.has(my_device_index)) { - ((is_scatter) ? params.dst_bufs : params.src_bufs) = {buf}; - } - - if (my_device_index == root) { - for (auto i : c10::irange(mesh.vector().size())) { - auto sliced_buf = root_buf.slice(0, i, i + 1); - ((is_scatter) ? 
params.src_bufs : params.dst_bufs).push_back(sliced_buf); - } - // The scatter/gather semantics imposes the root to be both - // sender and receiver. If the root is not in the mesh, we thus - // have to artificially make it send and receive a dummy buffer - // Since it is an "inplace" operation, this should not cause any overhead - if (!is_root_in_mesh) { - at::Tensor dummy = createDummyTensor(root_buf.slice(0, 0, 1)); - params.src_bufs.push_back(dummy); - params.dst_bufs.push_back(dummy); - } - } return params; } @@ -140,8 +80,6 @@ void lowerToScatter( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, std::vector>& comms) { // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); @@ -149,8 +87,8 @@ void lowerToScatter( if (!isDeviceInvolved(my_device_index, root, receiver_mesh)) { return; } - auto params = createParamsForGatherScatter( - my_device_index, root, output_tv, input_tensor, output_tensor, true); + auto params = + createParamsForGatherScatter(my_device_index, root, output_tv, true); comms.push_back(std::make_shared(std::move(params))); } @@ -164,8 +102,6 @@ void lowerToGather( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, std::vector>& comms) { // we create as many 'Gathers' as there are devices in the receiver mesh const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); @@ -173,8 +109,8 @@ void lowerToGather( if (!isDeviceInvolved(my_device_index, root, sender_mesh)) { continue; } - auto params = createParamsForGatherScatter( - my_device_index, root, input_tv, output_tensor, input_tensor, false); + auto params = + createParamsForGatherScatter(my_device_index, root, input_tv, false); comms.push_back(std::make_shared(std::move(params))); } } @@ -184,8 +120,6 @@ void lowerToAllgather( DeviceIdxType 
my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, std::vector>& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { @@ -194,12 +128,6 @@ void lowerToAllgather( CommParams params; params.team = mesh.vector(); - for (auto i : c10::irange(mesh.vector().size())) { - params.dst_bufs.push_back( - output_tensor.index({at::indexing::Slice(i, i + 1), "..."})); - } - params.src_bufs = {input_tensor}; - comms.push_back(std::make_shared(std::move(params))); } @@ -207,23 +135,16 @@ void lowerToAllgather( CommParams createParamsForBroadcastOrP2P( DeviceIdxType my_device_index, DeviceIdxType root, - const DeviceMesh& mesh, // receiver devices - at::Tensor input_tensor, - at::Tensor output_tensor) { + // receiver devices + const DeviceMesh& mesh) { CommParams params; params.root = root; params.team = mesh.vector(); if (!mesh.has(root)) { + params.is_root_in_mesh = false; params.team.push_back(root); } - if (my_device_index == root) { - params.src_bufs = {input_tensor}; - } - if (mesh.has(my_device_index)) { - params.dst_bufs = {output_tensor}; - } - return params; } @@ -232,14 +153,11 @@ void lowerToBroadcastOrP2P( DeviceIdxType my_device_index, DeviceIdxType root, const DeviceMesh& mesh, // receiver devices - at::Tensor input_tensor, - at::Tensor output_tensor, std::vector>& comms) { if (!isDeviceInvolved(my_device_index, root, mesh)) { return; } - auto params = createParamsForBroadcastOrP2P( - my_device_index, root, mesh, input_tensor, output_tensor); + auto params = createParamsForBroadcastOrP2P(my_device_index, root, mesh); std::shared_ptr comm; if (mesh.vector().size() == 1) { comm = std::make_shared(std::move(params)); @@ -257,8 +175,6 @@ void lowerToBroadcastOrP2P( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, bool is_sharded, std::vector>& comms) { const DeviceMesh& sender_mesh = 
input_tv->getDeviceMesh(); @@ -274,19 +190,12 @@ void lowerToBroadcastOrP2P( my_device_index, sender_mesh.vector().at(i), DeviceMesh({receiver_mesh.vector().at(i)}), - input_tensor, - output_tensor, comms); } } else { // we arbitrarily choose the first device of the sender mesh to be the root lowerToBroadcastOrP2P( - my_device_index, - sender_mesh.vector().at(0), - receiver_mesh, - input_tensor, - output_tensor, - comms); + my_device_index, sender_mesh.vector().at(0), receiver_mesh, comms); } } @@ -295,34 +204,17 @@ CommParams createParamsForReduce( DeviceIdxType root, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, BinaryOpType op_type) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); CommParams params; params.root = root; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); - bool is_root_in_mesh = mesh.has(root); - if (!is_root_in_mesh) { + if (!mesh.has(root)) { + params.is_root_in_mesh = false; params.team.push_back(root); } - - auto sharded_dim = output_tv->getReductionAxis().value(); - if (mesh.has(my_device_index)) { - params.src_bufs = {input_tensor.squeeze(sharded_dim)}; - } - - if (my_device_index == root) { - params.dst_bufs = {output_tensor}; - // The reduce semantics imposes the root to be both - // sender and receiver. If the root is not in the mesh, we thus - // have to artificially make it send and receive a dummy buffer - if (!is_root_in_mesh) { - at::Tensor dummy = createDummyTensor(output_tensor, op_type); - params.src_bufs.push_back(dummy); - } - } + // FIXME: we may want to store sharded_dim to params for speed. 
return params; } @@ -330,8 +222,6 @@ void lowerToReduce( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, BinaryOpType op_type, std::vector>& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); @@ -342,13 +232,7 @@ void lowerToReduce( continue; } auto params = createParamsForReduce( - my_device_index, - root, - input_tv, - output_tv, - input_tensor, - output_tensor, - op_type); + my_device_index, root, input_tv, output_tv, op_type); comms.push_back(std::make_shared(std::move(params))); } } @@ -357,19 +241,16 @@ void lowerToAllreduce( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, BinaryOpType op_type, std::vector>& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } + CommParams params; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); - params.dst_bufs = {output_tensor}; - params.src_bufs = {input_tensor.view(output_tensor.sizes())}; comms.push_back(std::make_shared(params)); } @@ -377,18 +258,16 @@ void lowerToReduceScatter( DeviceIdxType my_device_index, TensorView* input_tv, TensorView* output_tv, - at::Tensor input_tensor, - at::Tensor output_tensor, BinaryOpType op_type, std::vector>& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); if (!mesh.has(my_device_index)) { return; } + CommParams params; params.redOp = getC10dReduceOpType(op_type); params.team = mesh.vector(); - params.dst_bufs = {output_tensor}; auto reduction_axis = output_tv->getReductionAxis().value(); auto scattered_axis = getShardedAxis(output_tv); // The output tensor is sharded on scattered_axis and needs to be mapped @@ -398,11 +277,7 @@ void lowerToReduceScatter( if (reduction_axis <= scattered_axis) { scattered_axis++; } - for (auto i : c10::irange(mesh.vector().size())) { - auto slice = - 
input_tensor.slice(scattered_axis, i, i + 1).squeeze(reduction_axis); - params.src_bufs.push_back(slice); - } + params.scattered_axis = scattered_axis; comms.push_back(std::make_shared(params)); } @@ -420,9 +295,7 @@ void lowerToReduceScatter( */ std::vector> lowerCommunication( DeviceIdxType my_device_index, - Expr* c, - at::Tensor input_tensor, - at::Tensor output_tensor) { + Expr* c) { std::vector> comms; NVF_ERROR( c->inputs().size() == 1 && c->inputs().at(0)->isA() && @@ -454,19 +327,6 @@ std::vector> lowerCommunication( original_expr->toString()); bool is_reduction = original_expr->isA(); - auto input_sharded_dim = getShardedAxis(input_tv); - auto output_sharded_dim = getShardedAxis(output_tv); - NVF_ERROR( - !is_input_sharded || !input_tensor.numel() || - static_cast(input_tensor.size(input_sharded_dim)) == 1, - "Sharded dimension should have allocation size 1, but is ", - input_tensor.size(input_sharded_dim)); - NVF_ERROR( - !is_output_sharded || !output_tensor.numel() || is_reduction || - static_cast(output_tensor.size(output_sharded_dim)) == 1, - "Sharded dimension should have allocation size 1, but is ", - output_tensor.size(output_sharded_dim)); - if (is_reduction) { BinaryOpType op_type = output_tv->definition()->as()->getReductionOpType(); @@ -480,70 +340,26 @@ std::vector> lowerCommunication( "ReduceScatter operation must have the same sender and receiver device mesh. 
" "Insert a Set operation before or after the reduction to reshard ot another device mesh"); lowerToReduceScatter( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - op_type, - comms); + my_device_index, input_tv, output_tv, op_type, comms); } else { if (same_mesh) { - lowerToAllreduce( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - op_type, - comms); + lowerToAllreduce(my_device_index, input_tv, output_tv, op_type, comms); } else { - lowerToReduce( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - op_type, - comms); + lowerToReduce(my_device_index, input_tv, output_tv, op_type, comms); } } } else { if (!is_input_sharded && is_output_sharded) { - lowerToScatter( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - comms); + lowerToScatter(my_device_index, input_tv, output_tv, comms); } else if (is_input_sharded && !is_output_sharded) { if (same_mesh) { - lowerToAllgather( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - comms); + lowerToAllgather(my_device_index, input_tv, output_tv, comms); } else { - lowerToGather( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - comms); + lowerToGather(my_device_index, input_tv, output_tv, comms); } } else { lowerToBroadcastOrP2P( - my_device_index, - input_tv, - output_tv, - input_tensor, - output_tensor, - is_input_sharded, - comms); + my_device_index, input_tv, output_tv, is_input_sharded, comms); } } return comms; diff --git a/csrc/multidevice/lower_communication.h b/csrc/multidevice/lower_communication.h index 3942ae09dd8..172afc9807b 100644 --- a/csrc/multidevice/lower_communication.h +++ b/csrc/multidevice/lower_communication.h @@ -13,16 +13,14 @@ namespace nvfuser { -// returns whether we support transforming a given expression into a series -// of communication +// Returns whether we support transforming a given expression into a series +// of 
communications. bool isLowerableToCommunication(Expr* expr); // Lower a PipelineCommunication into a series of Communication, given a // device_index. std::vector> lowerCommunication( DeviceIdxType device_index, - Expr* c, - at::Tensor input_tensor, - at::Tensor output_tensor); + Expr* c); } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 0e741c8b731..8c9c042e6e0 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -23,11 +23,14 @@ class CommunicationTest void SetUp() override; void validate(at::Tensor obtained, at::Tensor expected); - void resetDstBuffers(); static constexpr DeviceIdxType root = 0; static constexpr int tensor_size = 1024; - static constexpr int number_of_repetitions = 8; + // This is so we test having multiple in-flight collectives on the same + // buffers. This emulates more accurately the type of workload we are + // targeting. + static constexpr int num_repetitions = 8; + // TODO: test other reduction op types. 
static constexpr c10d::ReduceOp::RedOpType red_op = c10d::ReduceOp::RedOpType::SUM; CommParams params; @@ -54,121 +57,107 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { << obtained; } -void CommunicationTest::resetDstBuffers() { - for (auto& buf : params.dst_bufs) { - buf.copy_(at::full(tensor_size, nan(""), tensor_options)); - } -} - TEST_P(CommunicationTest, Gather) { params.root = root; params.team = all_ranks; - params.src_bufs = {at::empty(tensor_size, tensor_options)}; - if (communicator->deviceId() == root) { - for (int64_t i = 0; i < communicator->size(); i++) { - params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); - } - } auto communication = Gather(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - params.src_bufs.at(0).copy_( - at::arange(tensor_size, tensor_options) + - (communicator->deviceId() + 1) * j); - - auto work = communication.post(*communicator, GetParam()); + at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); + at::Tensor output_tensor = + at::empty({communicator->size(), tensor_size}, tensor_options); + for (auto repetition : c10::irange(num_repetitions)) { + input_tensor.copy_( + at::arange(tensor_size, tensor_options).unsqueeze(0) + + (communicator->deviceId() + 1) * repetition); + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); if (communicator->deviceId() == root) { - for (int i : c10::irange(communicator->size())) { - auto obtained = params.dst_bufs.at(i); - auto ref = at::arange(tensor_size, tensor_options) + (i + 1) * j; - validate(obtained, ref); - } + at::Tensor ref = at::arange(tensor_size, tensor_options).unsqueeze(0) + + at::arange(1, communicator->size() + 1, tensor_options).unsqueeze(1) * + repetition; + validate(output_tensor, ref); } } } TEST_P(CommunicationTest, Allgather) { params.team = all_ranks; - params.src_bufs = { - at::empty(tensor_size, tensor_options) * 
communicator->deviceId()}; - for (int64_t i = 0; i < communicator->size(); i++) { - params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); - } auto communication = Allgather(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - params.src_bufs.at(0).copy_( - at::arange(tensor_size, tensor_options) + - (communicator->deviceId() + 1) * j); + at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); + at::Tensor output_tensor = + at::empty({communicator->size(), tensor_size}, tensor_options); + for (auto repetition : c10::irange(num_repetitions)) { + input_tensor.copy_( + at::arange(tensor_size, tensor_options).unsqueeze(0) + + (communicator->deviceId() + 1) * repetition); - auto work = communication.post(*communicator, GetParam()); + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); - for (int i : c10::irange(communicator->size())) { - auto obtained = params.dst_bufs.at(i); - auto ref = at::arange(tensor_size, tensor_options) + (i + 1) * j; - validate(obtained, ref); - } + at::Tensor ref = at::arange(tensor_size, tensor_options).unsqueeze(0) + + at::arange(1, communicator->size() + 1, tensor_options).unsqueeze(1) * + repetition; + validate(output_tensor, ref); } } TEST_P(CommunicationTest, Scatter) { params.root = root; params.team = all_ranks; + auto communication = Scatter(params); + + at::Tensor input_tensor; if (communicator->deviceId() == root) { - for (int64_t i = 0; i < communicator->size(); i++) { - params.src_bufs.push_back( - at::empty(tensor_size, tensor_options) * static_cast(i)); - } + input_tensor = + at::empty({communicator->size(), tensor_size}, tensor_options); } - params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - auto communication = Scatter(params); + at::Tensor output_tensor = at::empty({1, tensor_size}, tensor_options); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - for (int i : 
c10::irange(params.src_bufs.size())) { - params.src_bufs.at(i).copy_( - at::arange(tensor_size, tensor_options) + (i + 1) * j); + for (auto repetition : c10::irange(num_repetitions)) { + if (communicator->deviceId() == root) { + input_tensor.copy_( + at::arange(tensor_size, tensor_options).unsqueeze(0) + + at::arange(1, communicator->size() + 1, tensor_options).unsqueeze(1) * + repetition); } - auto work = communication.post(*communicator, GetParam()); + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); - auto obtained = params.dst_bufs.at(0); - auto ref = at::arange(tensor_size, tensor_options) + - (communicator->deviceId() + 1) * j; - validate(obtained, ref); + auto ref = at::arange(tensor_size, tensor_options).unsqueeze(0) + + (communicator->deviceId() + 1) * repetition; + validate(output_tensor, ref); } } TEST_P(CommunicationTest, Broadcast) { params.root = root; params.team = all_ranks; - if (communicator->deviceId() == root) { - params.src_bufs = {at::empty(tensor_size, tensor_options)}; - } - params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - auto communication = Broadcast(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); + at::Tensor input_tensor; + if (communicator->deviceId() == root) { + input_tensor = at::empty({tensor_size}, tensor_options); + } + at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); + for (auto repetition : c10::irange(num_repetitions)) { if (communicator->deviceId() == root) { - params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); + input_tensor.copy_(at::arange(tensor_size, tensor_options) + repetition); } - auto work = communication.post(*communicator, GetParam()); - if (communicator->size() > 1) { + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); + if (work != nullptr) { work->wait(); } - auto obtained = params.dst_bufs.at(0); - auto ref = 
at::arange(tensor_size, tensor_options) + j; - validate(obtained, ref); + auto ref = at::arange(tensor_size, tensor_options) + repetition; + validate(output_tensor, ref); } } @@ -177,59 +166,63 @@ TEST_P(CommunicationTest, SendRecv) { GTEST_SKIP() << "This test needs at least 2 GPUs and 2 ranks."; } - DeviceIdxType sender = 0; - DeviceIdxType receiver = 1; - if (communicator->deviceId() > 1) { // only devices 0 and 1 participate + constexpr DeviceIdxType sender = 0; + constexpr DeviceIdxType receiver = 1; + if (communicator->deviceId() > 1) { + // Only devices 0 and 1 participate. return; } params.root = sender; params.team = {0, 1}; + auto communication = SendRecv(params); + + at::Tensor input_tensor; + at::Tensor output_tensor; if (communicator->deviceId() == sender) { - params.src_bufs.push_back(at::empty(tensor_size, tensor_options)); + input_tensor = at::empty({tensor_size}, tensor_options); } else { - params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); + NVF_ERROR(communicator->deviceId() == receiver); + output_tensor = at::empty({tensor_size}, tensor_options); } - auto communication = SendRecv(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); + for (auto repetition : c10::irange(num_repetitions)) { if (communicator->deviceId() == sender) { - params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); + input_tensor.copy_(at::arange(tensor_size, tensor_options) + repetition); } - auto work = communication.post(*communicator, GetParam()); + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); if (communicator->deviceId() == receiver) { - auto obtained = params.dst_bufs.at(0); - auto ref = at::arange(tensor_size, tensor_options) + j; - validate(obtained, ref); + auto ref = at::arange(tensor_size, tensor_options) + repetition; + validate(output_tensor, ref); } } } TEST_P(CommunicationTest, SendRecvToSelf) { - DeviceIdxType sender = 0; - if 
(communicator->deviceId() > 0) { // only device 0 participates + constexpr DeviceIdxType sender = 0; + if (communicator->deviceId() > 0) { + // Only device 0 participates. return; } params.root = sender; params.team = {0}; - params.src_bufs.push_back(at::empty(tensor_size, tensor_options)); - params.dst_bufs.push_back(at::empty(tensor_size, tensor_options)); auto communication = SendRecv(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - params.src_bufs.at(0).copy_(at::arange(tensor_size, tensor_options) + j); + at::Tensor input_tensor = at::empty({tensor_size}, tensor_options); + at::Tensor output_tensor = at::empty_like(input_tensor); - communication.post(*communicator, GetParam()); + for (auto repetition : c10::irange(num_repetitions)) { + input_tensor.copy_(at::arange(tensor_size, tensor_options) + repetition); - auto obtained = params.dst_bufs.at(0); - auto ref = at::arange(tensor_size, tensor_options) + j; - validate(obtained, ref); + communication.post(*communicator, input_tensor, output_tensor, GetParam()); + + auto ref = at::arange(tensor_size, tensor_options) + repetition; + validate(output_tensor, ref); } } @@ -237,27 +230,25 @@ TEST_P(CommunicationTest, Reduce) { params.redOp = red_op; params.root = root; params.team = all_ranks; - params.src_bufs = {at::empty(tensor_size, tensor_options)}; - if (communicator->deviceId() == root) { - params.dst_bufs = {at::empty(tensor_size, tensor_options)}; - } auto communication = Reduce(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - params.src_bufs.at(0).copy_( - at::arange(tensor_size, tensor_options) + - (communicator->deviceId() + 1) * j); + at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); + at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); + + for (auto repetition : c10::irange(num_repetitions)) { + input_tensor.copy_( + at::arange(tensor_size, tensor_options).unsqueeze(0) + + 
(communicator->deviceId() + 1) * repetition); - auto work = communication.post(*communicator, GetParam()); + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); if (communicator->deviceId() == root) { - auto obtained = params.dst_bufs.at(0); - int S = communicator->size(); - auto ref = - at::arange(tensor_size, tensor_options) * S + S * (S + 1) / 2 * j; - validate(obtained, ref); + const int s = communicator->size(); + auto ref = at::arange(tensor_size, tensor_options) * s + + s * (s + 1) / 2 * repetition; + validate(output_tensor, ref); } } } @@ -265,24 +256,23 @@ TEST_P(CommunicationTest, Reduce) { TEST_P(CommunicationTest, Allreduce) { params.redOp = red_op; params.team = all_ranks; - params.src_bufs = {at::empty(tensor_size, tensor_options)}; - params.dst_bufs = {at::empty(tensor_size, tensor_options)}; auto communication = Allreduce(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - params.src_bufs.at(0).copy_( - at::arange(tensor_size, tensor_options) + - (communicator->deviceId() + 1) * j); + at::Tensor input_tensor = at::empty({1, tensor_size}, tensor_options); + at::Tensor output_tensor = at::empty({tensor_size}, tensor_options); + for (auto repetition : c10::irange(num_repetitions)) { + input_tensor.copy_( + at::arange(tensor_size, tensor_options).unsqueeze(0) + + (communicator->deviceId() + 1) * repetition); - auto work = communication.post(*communicator, GetParam()); + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); - auto obtained = params.dst_bufs.at(0); - int S = communicator->size(); - auto ref = - at::arange(tensor_size, tensor_options) * S + S * (S + 1) / 2 * j; - validate(obtained, ref); + const int s = communicator->size(); + auto ref = at::arange(tensor_size, tensor_options) * s + + s * (s + 1) / 2 * repetition; + validate(output_tensor, ref); } } @@ -290,28 +280,32 @@ TEST_P(CommunicationTest, 
ReduceScatter) { params.redOp = red_op; params.root = root; params.team = all_ranks; - for (int64_t i = 0; i < communicator->size(); i++) { - params.src_bufs.push_back(at::empty(tensor_size, tensor_options)); - } - params.dst_bufs = {at::empty(tensor_size, tensor_options)}; + params.scattered_axis = 1; auto communication = ReduceScatter(params); - for (int j : c10::irange(number_of_repetitions)) { - resetDstBuffers(); - for (int i : c10::irange(communicator->size())) { - params.src_bufs.at(i).copy_( - at::arange(tensor_size, tensor_options) + - (communicator->deviceId() + 1) * (i + j)); - } + const int num_devices = communicator->size(); + const int device_id = communicator->deviceId(); + at::Tensor unsharded_input_tensor = + at::empty({num_devices, num_devices, tensor_size}, tensor_options); + at::Tensor input_tensor = + unsharded_input_tensor.slice(0, device_id, device_id + 1); + at::Tensor output_tensor = at::empty({1, tensor_size}, tensor_options); + + for (auto repetition : c10::irange(num_repetitions)) { + std::ignore = repetition; - auto work = communication.post(*communicator, GetParam()); + // Create a tensor with integer values to avoid rounding error so we can + // validate using `equal` for more confidence. + unsharded_input_tensor.copy_(at::randint( + 2, {num_devices, num_devices, tensor_size}, tensor_options)); + + auto work = communication.post( + *communicator, input_tensor, output_tensor, GetParam()); work->wait(); - auto obtained = params.dst_bufs.at(0); - int S = communicator->size(); - auto ref = at::arange(tensor_size, tensor_options) * S + - S * (S + 1) / 2 * (communicator->deviceId() + j); - validate(obtained, ref); + auto ref = + unsharded_input_tensor.sum({0}).slice(0, device_id, device_id + 1); + validate(output_tensor, ref); } }