Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 38 additions & 53 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {

namespace {

inline void assertBufferCount(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless proven to be important, I tend to leave the inline decision to the compiler. Otherwise, the function could organically grow into a giant one without people realizing the overhead of inlining.

const std::vector<at::Tensor>& bufs,
size_t count) {
void assertBufferCount(const std::vector<at::Tensor>& bufs, size_t count) {
NVF_ERROR(
bufs.size() == count,
"there must be ",
Expand All @@ -60,7 +58,7 @@ inline void assertBufferCount(
" were given");
}

inline void assertBuffersHaveSameSize(
void assertBuffersHaveSameSize(
const std::vector<at::Tensor>& bufs1,
const std::vector<at::Tensor>& bufs2) {
if (bufs1.empty() && bufs2.empty()) {
Expand All @@ -76,7 +74,7 @@ inline void assertBuffersHaveSameSize(
}
}

inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) {
void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) {
dst.view_as(src).copy_(src, /*non_blocking=*/true);
}

Expand All @@ -100,7 +98,7 @@ T getInitialValue(c10d::ReduceOp::RedOpType op) {

bool hasRoot(CommunicationType type) {
return type == CommunicationType::Gather ||
type == CommunicationType::Scatter ||
type == CommunicationType::Scatter || type == CommunicationType::Reduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv;
}
Expand All @@ -111,36 +109,27 @@ bool isReduction(CommunicationType type) {
type == CommunicationType::ReduceScatter;
}

void assertValid(const CommParams& params) {
std::unordered_set<DeviceIdxType> team_without_duplicates(
params.team.begin(), params.team.end());
NVF_ERROR(
team_without_duplicates.size() == params.team.size(),
"the communication must not involve the same device more than once");
Comment on lines -115 to -119
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could move this check to Communication::team() but it looks unnecessary because

  1. DeviceMesh already checks against duplicates, and
  2. Communication::team() adds the root only when it's not in the mesh.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds ok to me. I think only the CommunicationTests directly set the team, but the normal use case creates a team from the DeviceMesh, which checks against duplicates.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR changed the CommunicationTests as well to set the mesh instead of the team. The bottom line is that CommParams contains all the parameters (e.g. mesh) needed for constructing a Communication, and the other fields (e.g. team) are all computed from them.

NVF_ERROR(!params.team.empty(), "the team size must be greater than 0");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved to Communication's constructor.

if (hasRoot(params.type)) {
auto it = std::find(params.team.begin(), params.team.end(), params.root);
NVF_ERROR(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check could be moved to Communication::team(), but it would look a bit silly:

if (root isn't in team) {
  team.push_back(root);
}
assert root is in team  // obviously true

it != params.team.end(),
"root (device ",
params.root,
") must be present in the communication's team");
}
}

// pytorch's process group expects the root to be specified
// as an integer between 0 and world_size-1. We choose it to be
// the device's relative index within the team
int64_t getRootRelativeIndex(const CommParams& params) {
auto it = std::find(params.team.begin(), params.team.end(), params.root);
return std::distance(params.team.begin(), it);
int64_t relative_index = params.mesh.idxOf(params.root);
// This assumes the root, if not in mesh, is added to the end of the team
// vector. Given the assumption is distantly enforced at
// Communication::team(), I slightly prefer making this function a method of
// Communication so relative_index can be computed from Communication::team()
// directly. Wdyt?
if (relative_index == -1) {
relative_index = params.mesh.size();
}
return relative_index;
}

} // namespace

Communication::Communication(IrBuilderPasskey passkey, CommParams params)
: Expr(passkey), params_(std::move(params)) {
assertValid(params_);
NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0.");
}

Communication::Communication(const Communication* src, IrCloner* ir_cloner)
Expand All @@ -153,10 +142,10 @@ std::string Communication::toString(const int indent_size) const {

indent(ss, indent_size) << "Communication " << params_.type << ": {"
<< std::endl;

if (hasRoot(params_.type)) {
indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl;
}
indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl;
indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl;
indent(ss, indent_size) << "}";

Expand All @@ -178,9 +167,10 @@ bool Communication::sameAs(const Statement* other) const {
const auto& p1 = this->params();
const auto& p2 = other->as<Communication>()->params();

return (
p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) &&
p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp));
return (p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) &&
p1.mesh == p2.mesh &&
(!isReduction(p1.type) || p1.redOp == p2.redOp)) &&
p1.team == p2.team;
}

namespace {
Expand All @@ -192,7 +182,7 @@ c10::intrusive_ptr<c10d::Work> postBroadcast(
at::Tensor output_tensor) {
const CommParams& params = communication->params();
if (my_device_index == params.root) {
if (params.is_root_in_mesh) {
if (communication->isRootInMesh()) {
// Do a local copy and the subsequent broadcast will be in place. Consider
// ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the
// local copy to complete.
Expand All @@ -203,7 +193,7 @@ c10::intrusive_ptr<c10d::Work> postBroadcast(
}
}

if (params.team.size() == 1) {
if (communication->team().size() == 1) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an intermediate change that turns out to be unnecessary for this PR. However, #2256 will remove CommParams, so I chose to keep this change, which brings the code closer to the end state.

return nullptr;
}

Expand All @@ -219,7 +209,7 @@ c10::intrusive_ptr<c10d::Work> postGather(
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();
if (my_device_index == params.root && !params.is_root_in_mesh) {
if (my_device_index == params.root && !communication->isRootInMesh()) {
// This is likely a suboptimal way to allocate tensors for nccl. To benefit
// from zero copy
// (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html),
Expand All @@ -235,17 +225,17 @@ c10::intrusive_ptr<c10d::Work> postGather(
if (my_device_index == params.root) {
output_tensors.resize(1);
int64_t j = 0;
for (auto i : c10::irange(params.team.size())) {
for (auto i : c10::irange(communication->team().size())) {
if (root_relative_index == static_cast<DeviceIdxType>(i) &&
!params.is_root_in_mesh) {
!communication->isRootInMesh()) {
output_tensors[0].push_back(input_tensor);
continue;
}
output_tensors[0].push_back(output_tensor.slice(0, j, j + 1));
j++;
}

assertBufferCount(output_tensors[0], params.team.size());
assertBufferCount(output_tensors[0], communication->team().size());
assertBuffersHaveSameSize(input_tensors, output_tensors[0]);
}

Expand All @@ -259,13 +249,11 @@ c10::intrusive_ptr<c10d::Work> postAllgather(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();

std::vector<at::Tensor> input_tensors({input_tensor});
std::vector<std::vector<at::Tensor>> output_tensors(1);
output_tensors[0] = at::split(output_tensor, /*split_size=*/1, /*dim=*/0);

assertBufferCount(output_tensors[0], params.team.size());
assertBufferCount(output_tensors[0], communication->team().size());
assertBuffersHaveSameSize(input_tensors, output_tensors[0]);
return backend->allgather(output_tensors, input_tensors, {});
}
Expand All @@ -278,7 +266,7 @@ c10::intrusive_ptr<c10d::Work> postScatter(
at::Tensor output_tensor) {
const CommParams& params = communication->params();

if (my_device_index == params.root && !params.is_root_in_mesh) {
if (my_device_index == params.root && !communication->isRootInMesh()) {
output_tensor = at::empty_like(input_tensor.slice(0, 0, 1));
}
std::vector<at::Tensor> output_tensors({output_tensor});
Expand All @@ -288,17 +276,17 @@ c10::intrusive_ptr<c10d::Work> postScatter(
if (my_device_index == params.root) {
input_tensors.resize(1);
int64_t j = 0;
for (auto i : c10::irange(params.team.size())) {
for (auto i : c10::irange(communication->team().size())) {
if (root_relative_index == static_cast<DeviceIdxType>(i) &&
!params.is_root_in_mesh) {
!communication->isRootInMesh()) {
input_tensors.front().push_back(output_tensor);
continue;
}
input_tensors.front().push_back(input_tensor.slice(0, j, j + 1));
j++;
}

assertBufferCount(input_tensors[0], params.team.size());
assertBufferCount(input_tensors[0], communication->team().size());
assertBuffersHaveSameSize(input_tensors[0], output_tensors);
}

Expand All @@ -316,7 +304,7 @@ c10::intrusive_ptr<c10d::Work> postReduce(

at::Tensor tensor;
if (my_device_index == params.root) {
if (params.is_root_in_mesh) {
if (communication->isRootInMesh()) {
doLocalCopy(output_tensor, input_tensor);
tensor = output_tensor;
} else {
Expand Down Expand Up @@ -369,7 +357,7 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], params.team.size());
assertBufferCount(input_tensors[0], communication->team().size());
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = params.redOp});
}
Expand All @@ -382,18 +370,15 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
at::Tensor output_tensor) {
const CommParams& params = communication->params();

NVF_ERROR(
params.team.size() == 1 || params.team.size() == 2,
"the team size should be 1 or 2");
NVF_ERROR(params.mesh.size() == 1, "The mesh size should be 1.");

if (params.team.size() == 1) {
if (communication->isRootInMesh()) {
doLocalCopy(output_tensor, input_tensor);
return nullptr;
}

const DeviceIdxType sender = params.root;
const DeviceIdxType receiver =
params.team.at(0) == sender ? params.team.at(1) : params.team.at(0);
const DeviceIdxType receiver = params.mesh.at(0);

std::vector<at::Tensor> tensors;
if (my_device_index == sender) {
Expand All @@ -414,9 +399,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();
const Team& team = communication->team();
NVF_ERROR(
std::find(params.team.begin(), params.team.end(), my_device_index) !=
params.team.end(),
std::find(team.begin(), team.end(), my_device_index) != team.end(),
"current device index ",
my_device_index,
" must be present in the communication's team");
Expand Down
22 changes: 16 additions & 6 deletions csrc/multidevice/communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ir/base_nodes.h>
#include <ir/builder.h>
#include <multidevice/communicator.h>
#include <multidevice/device_mesh.h>
#include <multidevice/multidevice.h>
#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/Types.hpp>
Expand Down Expand Up @@ -38,10 +39,12 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
struct CommParams {
CommunicationType type;
DeviceIdxType root = -1;
bool is_root_in_mesh = true;
Team team; // should not have duplicates and should contain both the root and
// the mesh
DeviceMesh mesh; // Might not contain `root`.
Team team; // All devices involved in this communication. It must include
// `root`. It can be a subset of `root`+`mesh` in case of 2D
// sharding.
c10d::ReduceOp::RedOpType redOp = c10d::ReduceOp::RedOpType::UNUSED;
// reduced_axis is always outermost.
int64_t scattered_axis = -1;
};

Expand Down Expand Up @@ -141,13 +144,20 @@ class Communication : public Expr {
// the constructor that takes IrCloner aren't needed.
bool sameAs(const Statement* other) const override;

// TODO: const CommParams&.
auto params() const {
const CommParams& params() const {
return params_;
}

bool isRootInMesh() const {
return params_.mesh.has(params_.root);
}

const Team& team() const {
return params_.team;
}

private:
// store the arguments of the communication
// Stores the arguments used to construct the communication.
CommParams params_;
};

Expand Down
2 changes: 1 addition & 1 deletion csrc/multidevice/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ void MultiDeviceExecutor::postCommunication(SegmentedGroup* group) {
// post and wait communications
for (Communication* communication : communications) {
c10::intrusive_ptr<c10d::Backend> backend =
comm_.getBackendForTeam(communication->params().team, std::nullopt);
comm_.getBackendForTeam(communication->team(), std::nullopt);
c10::intrusive_ptr<c10d::Work> work = postSingleCommunication(
communication, comm_.deviceId(), backend, input_tensor, output_tensor);
if (work != nullptr) {
Expand Down
Loading