From efe406904cfbd11e88f5a4b9b225dc35f73cf1e4 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 15 May 2024 21:37:11 +0000 Subject: [PATCH 1/4] Fix comments. --- csrc/multidevice/communication.h | 92 +++++++++++++++----------------- 1 file changed, 42 insertions(+), 50 deletions(-) diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 7457679af58..95380c92bd2 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -53,21 +53,54 @@ struct CommParams { // Communication should not be used directly but through its derived classes: // Broadcast, Gather, Scatter, Allgather, and SendRecv. Other collectives will // be added later. +class Communication : public Expr { + public: + using Expr::Expr; + Communication(IrBuilderPasskey passkey, CommParams params); + Communication(const Communication* src, IrCloner* ir_cloner); -// CommParams contains the arguments for the communication constructors. -// Note that each process (associated with a device index given by -// communicator.deviceId()) will fill CommParams with different arguments, -// depending on the role they play in this communication. For example, the root -// of a Gather communication will have destination buffers, whereas -// non-root will have no destination buffers. Also, the ranks not participating -// in the communication should not instantiate it. + Communication(const Communication& other) = delete; + Communication& operator=(const Communication& other) = delete; + Communication(Communication&& other) = delete; + Communication& operator=(Communication&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "Communication"; + } + + // TDOO: add params_ (or flattened parameters) as data attributes so this and + // the constructor that takes IrCloner aren't needed. + bool sameAs(const Statement* other) const override; + + const CommParams& params() const { + return params_; + } + + bool isRootInMesh() const { + return params_.mesh.has(params_.root); + } + + const Team& team(); + + private: + // Stores the arguments used to construct the communication. + CommParams params_; + // This can be computed from params_, but given how frequently this is used + // in the hot path, I'm currently storing it as a field that'll be computed by + // Communication::team(). + Team team_; +}; // The method "post" triggers the execution of the communication. This call is // non-blocking. The communication can be posted multiple times. // It is assumed that the current device_index (given by // communicator.deviceId()) belongs to the team of the communication, // otherwise an error is thrown. - +// // NOTE: pytorch's NCCL process group API needs buffers on root for // scatter/gather operation. // (*) Broadcast @@ -121,48 +154,7 @@ struct CommParams { // (*) SendRecv // Copies the sender's src buffers to the receiver's dst buffer // It is equivalent to a Broadcast with a team of size == 2 -class Communication : public Expr { - public: - using Expr::Expr; - Communication(IrBuilderPasskey passkey, CommParams params); - Communication(const Communication* src, IrCloner* ir_cloner); - - Communication(const Communication& other) = delete; - Communication& operator=(const Communication& other) = delete; - Communication(Communication&& other) = delete; - Communication& operator=(Communication&& other) = delete; - - NVFUSER_DECLARE_CLONE_AND_CREATE - - std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; - const char* getOpString() const override { - return "Communication"; - } - - // TDOO: add params_ (or flattened parameters) as data attributes so this and - // the constructor that takes IrCloner aren't needed. - bool sameAs(const Statement* other) const override; - - const CommParams& params() const { - return params_; - } - - bool isRootInMesh() const { - return params_.mesh.has(params_.root); - } - - const Team& team() const { - return params_.team; - } - - private: - // Stores the arguments used to construct the communication. - CommParams params_; -}; - -// Triggers the execution of the communication. This is a non-blocking call. -// The communication can be posted multiple times +// // TODO: c10d::Backend* should be sufficient. c10::intrusive_ptr postSingleCommunication( Communication* communication, From 8f5a178283be26c5877012cc7adf5456e9dbdc8c Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 15 May 2024 22:59:29 +0000 Subject: [PATCH 2/4] Remove CommParams and make the fields there data attributes. --- csrc/multidevice/communication.cpp | 144 ++++++++---------- csrc/multidevice/communication.h | 72 +++++---- csrc/multidevice/lower_communication.cpp | 59 +++---- tests/cpp/test_multidevice_communications.cpp | 78 ++++------ 4 files changed, 155 insertions(+), 198 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 6306e1d65d8..be890ba9098 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -109,50 +109,54 @@ bool isReduction(CommunicationType type) { type == CommunicationType::ReduceScatter; } -// pytorch's process group expects the root to be specified -// as an integer between 0 and world_size-1. We choose it to be -// the device's relative index within the team -int64_t getRootRelativeIndex(const CommParams& params) { - int64_t relative_index = params.mesh.idxOf(params.root); - // This assumes the root, if not in mesh, is added to the end of the team - // vector. Given the assumption is distantly enforced at - // Communication::team(), I slightly prefer making this function a method of - // Communication so relative_index can be computed from Communication::team() - // directly. Wdyt? - if (relative_index == -1) { - relative_index = params.mesh.size(); - } - return relative_index; -} - } // namespace -Communication::Communication(IrBuilderPasskey passkey, CommParams params) - : Expr(passkey), params_(std::move(params)) { - NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0."); +Communication::Communication( + IrBuilderPasskey passkey, + CommunicationType type, + DeviceMesh mesh, + Team team, + DeviceIdxType root, + RedOpType red_op, + int64_t scattered_axis) + : Expr(passkey) { + NVF_ERROR(mesh.size() > 0, "The mesh size must be greater than 0."); NVF_ERROR( - hasRoot(params_.type) == (params_.root >= 0), + hasRoot(type) == (root >= 0), "Root ", - params_.root, + root, " is not expected by CommunicationType ", - params_.type); + type); + NVF_ERROR(isReduction(type) == (red_op != RedOpType::UNUSED)) + NVF_ERROR( + (type == CommunicationType::ReduceScatter) == (scattered_axis >= 0)); + + addDataAttribute(type); + addDataAttribute(mesh); + addDataAttribute(team); + addDataAttribute(root); + addDataAttribute(red_op); + addDataAttribute(scattered_axis); } -Communication::Communication(const Communication* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), params_(src->params()) {} - NVFUSER_DEFINE_CLONE_AND_CREATE(Communication) +int64_t Communication::getRootRelativeIndex() { + auto i = std::find(team().begin(), team().end(), root()); + NVF_ERROR( + i != team().end(), "Unable to find root ", root(), " in team ", team()); + return std::distance(team().begin(), i); +} + std::string Communication::toString(const int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << "Communication " << params_.type << ": {" - << std::endl; - if (hasRoot(params_.type)) { - indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl; + indent(ss, indent_size) << "Communication " << type() << ": {" << std::endl; + if (hasRoot(type())) { + indent(ss, indent_size + 1) << "root: " << root() << "," << std::endl; } - indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl; - indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl; + indent(ss, indent_size + 1) << "mesh: " << mesh() << "," << std::endl; + indent(ss, indent_size + 1) << "team: " << team() << "," << std::endl; indent(ss, indent_size) << "}"; return ss.str(); @@ -162,23 +166,6 @@ std::string Communication::toInlineString(int indent_size) const { return toString(indent_size); } -// TODO add checking symbolic representation of src and dst buffers -bool Communication::sameAs(const Statement* other) const { - if (other == this) { - return true; - } - if (!other->isA()) { - return false; - } - const auto& p1 = this->params(); - const auto& p2 = other->as()->params(); - - return (p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) && - p1.mesh == p2.mesh && - (!isReduction(p1.type) || p1.redOp == p2.redOp)) && - p1.team == p2.team; -} - namespace { c10::intrusive_ptr postBroadcast( Communication* communication, @@ -186,8 +173,7 @@ c10::intrusive_ptr postBroadcast( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - if (my_device_index == params.root) { + if (my_device_index == communication->root()) { if (communication->isRootInMesh()) { // Do a local copy and the subsequent broadcast will be in place. Consider // ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the @@ -205,7 +191,7 @@ c10::intrusive_ptr postBroadcast( std::vector tensors({output_tensor}); return backend->broadcast( - tensors, {.rootRank = getRootRelativeIndex(params)}); + tensors, {.rootRank = communication->getRootRelativeIndex()}); } c10::intrusive_ptr postGather( @@ -214,8 +200,8 @@ c10::intrusive_ptr postGather( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - if (my_device_index == params.root && !communication->isRootInMesh()) { + if (my_device_index == communication->root() && + !communication->isRootInMesh()) { // This is likely a suboptimal way to allocate tensors for nccl. To benefit // from zero copy // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html), @@ -226,9 +212,9 @@ c10::intrusive_ptr postGather( } std::vector input_tensors({input_tensor}); - auto root_relative_index = getRootRelativeIndex(params); + auto root_relative_index = communication->getRootRelativeIndex(); std::vector> output_tensors; - if (my_device_index == params.root) { + if (my_device_index == communication->root()) { output_tensors.resize(1); int64_t j = 0; for (auto i : c10::irange(communication->team().size())) { @@ -270,16 +256,15 @@ c10::intrusive_ptr postScatter( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - - if (my_device_index == params.root && !communication->isRootInMesh()) { + if (my_device_index == communication->root() && + !communication->isRootInMesh()) { output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); } std::vector output_tensors({output_tensor}); - auto root_relative_index = getRootRelativeIndex(params); + auto root_relative_index = communication->getRootRelativeIndex(); std::vector> input_tensors; - if (my_device_index == params.root) { + if (my_device_index == communication->root()) { input_tensors.resize(1); int64_t j = 0; for (auto i : c10::irange(communication->team().size())) { @@ -306,10 +291,8 @@ c10::intrusive_ptr postReduce( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - at::Tensor tensor; - if (my_device_index == params.root) { + if (my_device_index == communication->root()) { if (communication->isRootInMesh()) { doLocalCopy(output_tensor, input_tensor); tensor = output_tensor; @@ -317,7 +300,7 @@ c10::intrusive_ptr postReduce( NVF_ERROR( output_tensor.scalar_type() == at::kFloat, "only float tensors are supported"); - output_tensor.fill_(getInitialValue(params.redOp)); + output_tensor.fill_(getInitialValue(communication->reduceOp())); tensor = output_tensor; } } else { @@ -326,7 +309,8 @@ c10::intrusive_ptr postReduce( std::vector tensors({tensor}); c10d::ReduceOptions options = { - .reduceOp = params.redOp, .rootRank = getRootRelativeIndex(params)}; + .reduceOp = communication->reduceOp(), + .rootRank = communication->getRootRelativeIndex()}; // TODO: avoid local copy by using out-of-place reduction. return backend->reduce(tensors, options); } @@ -337,12 +321,11 @@ c10::intrusive_ptr postAllreduce( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - doLocalCopy(output_tensor, input_tensor); std::vector output_tensors({output_tensor}); - return backend->allreduce(output_tensors, {.reduceOp = params.redOp}); + return backend->allreduce( + output_tensors, {.reduceOp = communication->reduceOp()}); } c10::intrusive_ptr postReduceScatter( @@ -351,21 +334,19 @@ c10::intrusive_ptr postReduceScatter( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - std::vector> input_tensors(1); + const auto scattered_axis = communication->scatteredAxis(); NVF_ERROR( - params.scattered_axis >= 0, + scattered_axis >= 0, "scattered_axis is expected to be non-negative: ", - params.scattered_axis) - input_tensors[0] = - at::split(input_tensor, /*split_size=*/1, params.scattered_axis); + scattered_axis) + input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis); std::vector output_tensors({output_tensor}); assertBufferCount(input_tensors[0], communication->team().size()); return backend->reduce_scatter( - output_tensors, input_tensors, {.reduceOp = params.redOp}); + output_tensors, input_tensors, {.reduceOp = communication->reduceOp()}); } c10::intrusive_ptr postSendRecv( @@ -374,17 +355,15 @@ c10::intrusive_ptr postSendRecv( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); - - NVF_ERROR(params.mesh.size() == 1, "The mesh size should be 1."); + NVF_ERROR(communication->mesh().size() == 1, "The mesh size should be 1."); if (communication->isRootInMesh()) { doLocalCopy(output_tensor, input_tensor); return nullptr; } - const DeviceIdxType sender = params.root; - const DeviceIdxType receiver = params.mesh.at(0); + const DeviceIdxType sender = communication->root(); + const DeviceIdxType receiver = communication->mesh().at(0); std::vector tensors; if (my_device_index == sender) { @@ -404,7 +383,6 @@ c10::intrusive_ptr postSingleCommunication( c10::intrusive_ptr backend, at::Tensor input_tensor, at::Tensor output_tensor) { - const CommParams& params = communication->params(); const Team& team = communication->team(); NVF_ERROR( std::find(team.begin(), team.end(), my_device_index) != team.end(), @@ -412,7 +390,7 @@ c10::intrusive_ptr postSingleCommunication( my_device_index, " must be present in the communication's team"); - switch (communication->params().type) { + switch (communication->type()) { case CommunicationType::Gather: return postGather( communication, my_device_index, backend, input_tensor, output_tensor); @@ -438,7 +416,7 @@ c10::intrusive_ptr postSingleCommunication( return postSendRecv( communication, my_device_index, backend, input_tensor, output_tensor); default: - NVF_ERROR(false, "Wrong communication type: ", params.type); + NVF_ERROR(false, "Wrong communication type: ", communication->type()); return nullptr; } } diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 95380c92bd2..da3d32a0e36 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -35,18 +35,7 @@ enum class CommunicationType { std::ostream& operator<<(std::ostream& os, const CommunicationType& type); -// This struct gathers all the parameters needed to construct a Communication. -struct CommParams { - CommunicationType type; - DeviceIdxType root = -1; - DeviceMesh mesh; // Might not contain `root`. - Team team; // All devices involved in this communication. It must include - // `root`. It can be a subset of `root`+`mesh` in case of 2D - // sharding. - c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED; - // reduced_axis is always outermost. - int64_t scattered_axis = -1; -}; +using RedOpType = c10d::ReduceOp::RedOpType; // The class "Communication" represents a MPI-style communication // communication operation to be executed on the network. The base class @@ -56,8 +45,22 @@ struct CommParams { class Communication : public Expr { public: using Expr::Expr; - Communication(IrBuilderPasskey passkey, CommParams params); - Communication(const Communication* src, IrCloner* ir_cloner); + // Only specify `root` for types that have root. + // Only specify `red_op` for reduction types. + // Only specify `scattered_axis` for ReduceScatter. + // + // TODO: pass in input/output TV and compute root, mesh and scatteredAxis from + // them. + Communication( + IrBuilderPasskey passkey, + CommunicationType type, + DeviceMesh mesh, // Might not contain `root`. + Team team, // All devices involved in this communication. It must include + // `root`. It can be a subset of `root`+`mesh` in case of 2D + // sharding. + DeviceIdxType root = -1, + RedOpType red_op = RedOpType::UNUSED, + int64_t scattered_axis = -1); Communication(const Communication& other) = delete; Communication& operator=(const Communication& other) = delete; @@ -72,27 +75,38 @@ class Communication : public Expr { return "Communication"; } - // TDOO: add params_ (or flattened parameters) as data attributes so this and - // the constructor that takes IrCloner aren't needed. - bool sameAs(const Statement* other) const override; + CommunicationType type() const { + return attribute(0); + } - const CommParams& params() const { - return params_; + const DeviceMesh& mesh() const { + return attribute(1); } - bool isRootInMesh() const { - return params_.mesh.has(params_.root); + const Team& team() const { + return attribute(2); } - const Team& team(); + DeviceIdxType root() const { + return attribute(3); + } + + RedOpType reduceOp() const { + return attribute(4); + } + + int64_t scatteredAxis() const { + return attribute(5); + } + + bool isRootInMesh() const { + return mesh().has(root()); + } - private: - // Stores the arguments used to construct the communication. - CommParams params_; - // This can be computed from params_, but given how frequently this is used - // in the hot path, I'm currently storing it as a field that'll be computed by - // Communication::team(). - Team team_; + // PyTorch's process group expects the root to be specified + // as an integer between 0 and world_size-1. We choose it to be + // the device's relative index within the team + int64_t getRootRelativeIndex(); }; // The method "post" triggers the execution of the communication. This call is diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 680f30bfb69..bc1b4d7f962 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -72,11 +72,8 @@ void lowerToScatter( if (!receiver_mesh.has(root)) { team.push_back(root); } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::Scatter, - .root = root, - .mesh = receiver_mesh, - .team = team})); + comms.push_back(IrBuilder::create( + CommunicationType::Scatter, receiver_mesh, team, root)); } /* @@ -100,11 +97,8 @@ void lowerToGather( if (!sender_mesh.has(root)) { team.push_back(root); } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::Gather, - .root = root, - .mesh = sender_mesh, - .team = team})); + comms.push_back(IrBuilder::create( + CommunicationType::Gather, sender_mesh, team, root)); } } @@ -119,10 +113,8 @@ void lowerToAllgather( return; } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::Allgather, - .mesh = mesh, - .team = mesh.vector()})); + comms.push_back(IrBuilder::create( + CommunicationType::Allgather, mesh, mesh.vector())); } // Adds one or zero Broadcast or Send/Recv communication to the vector 'comms' @@ -138,11 +130,8 @@ void lowerToBroadcastOrP2P( if (!mesh.has(root)) { team.push_back(root); } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::Broadcast, - .root = root, - .mesh = mesh, - .team = team})); + comms.push_back(IrBuilder::create( + CommunicationType::Broadcast, mesh, team, root)); } // Adds several Broadcast or Send/Recv communications to the vector 'comms' @@ -195,12 +184,8 @@ void lowerToReduce( if (!sender_mesh.has(root)) { team.push_back(root); } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::Reduce, - .root = root, - .mesh = sender_mesh, - .team = team, - .redOp = reduce_op_type})); + comms.push_back(IrBuilder::create( + CommunicationType::Reduce, sender_mesh, team, root, reduce_op_type)); } } @@ -215,11 +200,12 @@ void lowerToAllreduce( return; } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::Allreduce, - .mesh = mesh, - .team = mesh.vector(), - .redOp = getC10dReduceOpType(op_type)})); + comms.push_back(IrBuilder::create( + CommunicationType::Allreduce, + mesh, + mesh.vector(), + /*root=*/-1, + getC10dReduceOpType(op_type))); } void lowerToReduceScatter( @@ -243,12 +229,13 @@ void lowerToReduceScatter( scattered_axis++; } - comms.push_back(IrBuilder::create(CommParams{ - .type = CommunicationType::ReduceScatter, - .mesh = mesh, - .team = mesh.vector(), - .redOp = getC10dReduceOpType(op_type), - .scattered_axis = scattered_axis})); + comms.push_back(IrBuilder::create( + CommunicationType::ReduceScatter, + mesh, + mesh.vector(), + /*root=*/-1, + getC10dReduceOpType(op_type), + scattered_axis)); } } // namespace diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 53c94ab2a7b..15bdb622b34 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -63,12 +63,7 @@ void CommunicationTest::validate(at::Tensor obtained, at::Tensor expected) { TEST_P(CommunicationTest, Gather) { auto communication = IrBuilder::create( - &container, - CommParams{ - .type = CommunicationType::Gather, - .root = kRoot, - .mesh = full_mesh_, - .team = all_ranks_}); + &container, CommunicationType::Gather, full_mesh_, all_ranks_, kRoot); at::Tensor input_tensor = at::empty({1, kTensorSize}, tensor_options); at::Tensor output_tensor = @@ -96,11 +91,7 @@ TEST_P(CommunicationTest, Gather) { TEST_P(CommunicationTest, Allgather) { auto communication = IrBuilder::create( - &container, - CommParams{ - .type = CommunicationType::Allgather, - .mesh = full_mesh_, - .team = all_ranks_}); + &container, CommunicationType::Allgather, full_mesh_, all_ranks_); at::Tensor input_tensor = at::empty({1, kTensorSize}, tensor_options); at::Tensor output_tensor = @@ -127,12 +118,7 @@ TEST_P(CommunicationTest, Allgather) { TEST_P(CommunicationTest, Scatter) { auto communication = IrBuilder::create( - &container, - CommParams{ - .type = CommunicationType::Scatter, - .root = kRoot, - .mesh = full_mesh_, - .team = all_ranks_}); + &container, CommunicationType::Scatter, full_mesh_, all_ranks_, kRoot); at::Tensor input_tensor; if (communicator->deviceId() == kRoot) { @@ -165,12 +151,7 @@ TEST_P(CommunicationTest, Scatter) { TEST_P(CommunicationTest, Broadcast) { auto communication = IrBuilder::create( - &container, - CommParams{ - .type = CommunicationType::Broadcast, - .root = kRoot, - .mesh = full_mesh_, - .team = all_ranks_}); + &container, CommunicationType::Broadcast, full_mesh_, all_ranks_, kRoot); at::Tensor input_tensor; if (communicator->deviceId() == kRoot) { @@ -214,11 +195,10 @@ TEST_P(CommunicationTest, SendRecv) { auto communication = IrBuilder::create( &container, - CommParams{ - .type = CommunicationType::SendRecv, - .root = sender, - .mesh = {receiver}, - .team = {sender, receiver}}); + CommunicationType::SendRecv, + DeviceMesh({receiver}), + /*team=*/{sender, receiver}, + /*root=*/sender); at::Tensor input_tensor; at::Tensor output_tensor; @@ -258,11 +238,10 @@ TEST_P(CommunicationTest, SendRecvToSelf) { auto communication = IrBuilder::create( &container, - CommParams{ - .type = CommunicationType::SendRecv, - .root = sender, - .mesh = {sender}, - .team = {sender}}); + CommunicationType::SendRecv, + DeviceMesh({sender}), + /*team=*/{sender}, + /*root=*/sender); at::Tensor input_tensor = at::empty({kTensorSize}, tensor_options); at::Tensor output_tensor = at::empty_like(input_tensor); @@ -285,12 +264,11 @@ TEST_P(CommunicationTest, SendRecvToSelf) { TEST_P(CommunicationTest, Reduce) { auto communication = IrBuilder::create( &container, - CommParams{ - .type = CommunicationType::Reduce, - .root = kRoot, - .mesh = full_mesh_, - .team = all_ranks_, - .redOp = kReductionOp}); + CommunicationType::Reduce, + full_mesh_, + all_ranks_, + kRoot, + kReductionOp); at::Tensor input_tensor = at::empty({1, kTensorSize}, tensor_options); at::Tensor output_tensor = at::empty({kTensorSize}, tensor_options); @@ -320,11 +298,11 @@ TEST_P(CommunicationTest, Reduce) { TEST_P(CommunicationTest, Allreduce) { auto communication = IrBuilder::create( &container, - CommParams{ - .type = CommunicationType::Allreduce, - .mesh = full_mesh_, - .team = all_ranks_, - .redOp = kReductionOp}); + CommunicationType::Allreduce, + full_mesh_, + all_ranks_, + /*root=*/-1, + kReductionOp); at::Tensor input_tensor = at::empty({1, kTensorSize}, tensor_options); at::Tensor output_tensor = at::empty({kTensorSize}, tensor_options); @@ -351,12 +329,12 @@ TEST_P(CommunicationTest, Allreduce) { TEST_P(CommunicationTest, ReduceScatter) { auto communication = IrBuilder::create( &container, - CommParams{ - .type = CommunicationType::ReduceScatter, - .mesh = full_mesh_, - .team = all_ranks_, - .redOp = kReductionOp, - .scattered_axis = 1}); + CommunicationType::ReduceScatter, + full_mesh_, + all_ranks_, + /*root=*/-1, + kReductionOp, + /*scattered_axis=*/1); const int num_devices = communicator->size(); const int device_id = communicator->deviceId(); From f173f72715576bd90c2eb2483d0fcc2c9dee8376 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 3 Jun 2024 21:47:48 +0000 Subject: [PATCH 3/4] Fix build. --- csrc/host_ir/executor.cpp | 14 +++++++------- tests/cpp/test_multidevice_communications.cpp | 4 ++-- tests/cpp/test_multidevice_host_ir.cpp | 9 ++++----- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 5a7e4ecb03c..a6e3a0fa10e 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -19,7 +19,7 @@ HostIrExecutor::HostIrExecutor( HostIrExecutorParams params) : container_(std::move(container)), communicator_(communicator), - params_(params){}; + params_(params) {}; std::vector HostIrExecutor::runWithInput( std::unordered_map val_to_IValue) { @@ -108,12 +108,12 @@ void HostIrExecutor::postCommunication(PostOnStream* post_ir) { post_ir->hostOpToPost()->isA(), "op must be a Communication: ", post_ir->hostOpToPost()); - auto communication = post_ir->hostOpToPost()->as(); + auto* communication = post_ir->hostOpToPost()->as(); NVF_ERROR( std::find( - communication->params().team.begin(), - communication->params().team.end(), - communicator_->deviceId()) != communication->params().team.end(), + communication->team().begin(), + communication->team().end(), + communicator_->deviceId()) != communication->team().end(), "current device index ", communicator_->deviceId(), " must be present in the communication's team"); @@ -129,8 +129,8 @@ void HostIrExecutor::postCommunication(PostOnStream* post_ir) { output_tensor = val_to_IValue_.at(output_val).toTensor(); } - c10::intrusive_ptr backend = communicator_->getBackendForTeam( - communication->params().team, std::nullopt); + c10::intrusive_ptr backend = + communicator_->getBackendForTeam(communication->team(), std::nullopt); c10::intrusive_ptr work = postSingleCommunication( communication, communicator_->deviceId(), diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 15bdb622b34..7628357f16a 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -197,7 +197,7 @@ TEST_P(CommunicationTest, SendRecv) { &container, CommunicationType::SendRecv, DeviceMesh({receiver}), - /*team=*/{sender, receiver}, + /*team=*/Team({sender, receiver}), /*root=*/sender); at::Tensor input_tensor; @@ -240,7 +240,7 @@ TEST_P(CommunicationTest, SendRecvToSelf) { &container, CommunicationType::SendRecv, DeviceMesh({sender}), - /*team=*/{sender}, + /*team=*/Team({sender}), /*root=*/sender); at::Tensor input_tensor = at::empty({kTensorSize}, tensor_options); diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 819d57cfba9..1a7eab8d6c5 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -74,12 +74,11 @@ TEST_P(MultiDeviceHostIrTest, SingleFusionSingleComm) { static_cast(hic.get()), std::move(fusion)); // [Step 3b)] Create a Communication Ir - CommParams comm_params{ - .type = CommunicationType::Allgather, - .mesh = mesh, - .team = mesh.vector()}; auto communication = IrBuilder::create( - static_cast(hic.get()), comm_params); + static_cast(hic.get()), + CommunicationType::Allgather, + mesh, + mesh.vector()); // [Step 4)] Create TensorViews at the Host level IrCloner ir_cloner(hic.get()); From 9e19d67dc27852327b9ce6f499c3b2afb517bce2 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 3 Jun 2024 22:02:06 +0000 Subject: [PATCH 4/4] Fix lint. --- csrc/host_ir/executor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index a6e3a0fa10e..8b17d69a722 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -19,7 +19,7 @@ HostIrExecutor::HostIrExecutor( HostIrExecutorParams params) : container_(std::move(container)), communicator_(communicator), - params_(params) {}; + params_(params) {} std::vector HostIrExecutor::runWithInput( std::unordered_map val_to_IValue) {