Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ HostIrExecutor::HostIrExecutor(
HostIrExecutorParams params)
: container_(std::move(container)),
communicator_(communicator),
params_(params){};
params_(params) {}

std::vector<at::Tensor> HostIrExecutor::runWithInput(
std::unordered_map<Val*, c10::IValue> val_to_IValue) {
Expand Down Expand Up @@ -108,12 +108,12 @@ void HostIrExecutor::postCommunication(PostOnStream* post_ir) {
post_ir->hostOpToPost()->isA<Communication>(),
"op must be a Communication: ",
post_ir->hostOpToPost());
auto communication = post_ir->hostOpToPost()->as<Communication>();
auto* communication = post_ir->hostOpToPost()->as<Communication>();
NVF_ERROR(
std::find(
communication->params().team.begin(),
communication->params().team.end(),
communicator_->deviceId()) != communication->params().team.end(),
communication->team().begin(),
communication->team().end(),
communicator_->deviceId()) != communication->team().end(),
"current device index ",
communicator_->deviceId(),
" must be present in the communication's team");
Expand All @@ -129,8 +129,8 @@ void HostIrExecutor::postCommunication(PostOnStream* post_ir) {
output_tensor = val_to_IValue_.at(output_val).toTensor();
}

c10::intrusive_ptr<c10d::Backend> backend = communicator_->getBackendForTeam(
communication->params().team, std::nullopt);
c10::intrusive_ptr<c10d::Backend> backend =
communicator_->getBackendForTeam(communication->team(), std::nullopt);
c10::intrusive_ptr<c10d::Work> work = postSingleCommunication(
communication,
communicator_->deviceId(),
Expand Down
144 changes: 61 additions & 83 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,50 +109,54 @@ bool isReduction(CommunicationType type) {
type == CommunicationType::ReduceScatter;
}

// PyTorch's process group wants the root expressed as an integer in
// [0, world_size - 1]. The value chosen here is the root device's index
// relative to the team.
int64_t getRootRelativeIndex(const CommParams& params) {
  const int64_t index_in_mesh = params.mesh.idxOf(params.root);
  if (index_in_mesh != -1) {
    return index_in_mesh;
  }
  // An index of -1 is treated as "root is not part of the mesh". In that case
  // the root is assumed to have been appended at the end of the team vector
  // (an assumption only enforced at a distance, in Communication::team()), so
  // its relative index equals the mesh size.
  // NOTE(review): computing this from Communication::team() directly, as a
  // method of Communication, would remove the reliance on that assumption.
  return params.mesh.size();
}

} // namespace

Communication::Communication(IrBuilderPasskey passkey, CommParams params)
: Expr(passkey), params_(std::move(params)) {
NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0.");
Communication::Communication(
IrBuilderPasskey passkey,
CommunicationType type,
DeviceMesh mesh,
Team team,
DeviceIdxType root,
RedOpType red_op,
int64_t scattered_axis)
: Expr(passkey) {
NVF_ERROR(mesh.size() > 0, "The mesh size must be greater than 0.");
NVF_ERROR(
hasRoot(params_.type) == (params_.root >= 0),
hasRoot(type) == (root >= 0),
"Root ",
params_.root,
root,
" is not expected by CommunicationType ",
params_.type);
type);
NVF_ERROR(isReduction(type) == (red_op != RedOpType::UNUSED))
NVF_ERROR(
(type == CommunicationType::ReduceScatter) == (scattered_axis >= 0));

addDataAttribute(type);
addDataAttribute(mesh);
addDataAttribute(team);
addDataAttribute(root);
addDataAttribute(red_op);
addDataAttribute(scattered_axis);
}

// Cloning constructor: forwards `src` and `ir_cloner` to the Expr base class
// and initializes params_ from src->params().
Communication::Communication(const Communication* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner), params_(src->params()) {}

NVFUSER_DEFINE_CLONE_AND_CREATE(Communication)

// Returns the position of root() within team(), counted from the start of the
// team. Reports an error through NVF_ERROR when root() does not appear in
// team().
int64_t Communication::getRootRelativeIndex() {
  const auto& members = team();
  const auto root_pos = std::find(members.begin(), members.end(), root());
  NVF_ERROR(
      root_pos != members.end(),
      "Unable to find root ",
      root(),
      " in team ",
      members);
  return std::distance(members.begin(), root_pos);
}

std::string Communication::toString(const int indent_size) const {
std::stringstream ss;

indent(ss, indent_size) << "Communication " << params_.type << ": {"
<< std::endl;
if (hasRoot(params_.type)) {
indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl;
indent(ss, indent_size) << "Communication " << type() << ": {" << std::endl;
if (hasRoot(type())) {
indent(ss, indent_size + 1) << "root: " << root() << "," << std::endl;
}
indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl;
indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl;
indent(ss, indent_size + 1) << "mesh: " << mesh() << "," << std::endl;
indent(ss, indent_size + 1) << "team: " << team() << "," << std::endl;
indent(ss, indent_size) << "}";

return ss.str();
Expand All @@ -162,32 +166,14 @@ std::string Communication::toInlineString(int indent_size) const {
return toString(indent_size);
}

// TODO add checking symbolic representation of src and dst buffers
//
// Structural equality between this Communication and `other`. A statement is
// always the same as itself; a non-Communication statement never matches.
// Otherwise the parameters are compared field by field: type, mesh and team
// must always match, root only matters for communication types that have a
// root, and redOp only matters for reduction types.
bool Communication::sameAs(const Statement* other) const {
  if (other == this) {
    return true;
  }
  if (!other->isA<Communication>()) {
    return false;
  }

  const auto& lhs = this->params();
  const auto& rhs = other->as<Communication>()->params();

  if (!(lhs.type == rhs.type)) {
    return false;
  }
  // The root is only meaningful for rooted communication types.
  if (hasRoot(lhs.type) && !(lhs.root == rhs.root)) {
    return false;
  }
  if (!(lhs.mesh == rhs.mesh)) {
    return false;
  }
  // The reduction operator is only meaningful for reduction types.
  if (isReduction(lhs.type) && !(lhs.redOp == rhs.redOp)) {
    return false;
  }
  return lhs.team == rhs.team;
}

namespace {
c10::intrusive_ptr<c10d::Work> postBroadcast(
Communication* communication,
DeviceIdxType my_device_index,
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();
if (my_device_index == params.root) {
if (my_device_index == communication->root()) {
if (communication->isRootInMesh()) {
// Do a local copy and the subsequent broadcast will be in place. Consider
// ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the
Expand All @@ -205,7 +191,7 @@ c10::intrusive_ptr<c10d::Work> postBroadcast(

std::vector<at::Tensor> tensors({output_tensor});
return backend->broadcast(
tensors, {.rootRank = getRootRelativeIndex(params)});
tensors, {.rootRank = communication->getRootRelativeIndex()});
}

c10::intrusive_ptr<c10d::Work> postGather(
Expand All @@ -214,8 +200,8 @@ c10::intrusive_ptr<c10d::Work> postGather(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();
if (my_device_index == params.root && !communication->isRootInMesh()) {
if (my_device_index == communication->root() &&
!communication->isRootInMesh()) {
// This is likely a suboptimal way to allocate tensors for nccl. To benefit
// from zero copy
// (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html),
Expand All @@ -226,9 +212,9 @@ c10::intrusive_ptr<c10d::Work> postGather(
}
std::vector<at::Tensor> input_tensors({input_tensor});

auto root_relative_index = getRootRelativeIndex(params);
auto root_relative_index = communication->getRootRelativeIndex();
std::vector<std::vector<at::Tensor>> output_tensors;
if (my_device_index == params.root) {
if (my_device_index == communication->root()) {
output_tensors.resize(1);
int64_t j = 0;
for (auto i : c10::irange(communication->team().size())) {
Expand Down Expand Up @@ -270,16 +256,15 @@ c10::intrusive_ptr<c10d::Work> postScatter(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();

if (my_device_index == params.root && !communication->isRootInMesh()) {
if (my_device_index == communication->root() &&
!communication->isRootInMesh()) {
output_tensor = at::empty_like(input_tensor.slice(0, 0, 1));
}
std::vector<at::Tensor> output_tensors({output_tensor});

auto root_relative_index = getRootRelativeIndex(params);
auto root_relative_index = communication->getRootRelativeIndex();
std::vector<std::vector<at::Tensor>> input_tensors;
if (my_device_index == params.root) {
if (my_device_index == communication->root()) {
input_tensors.resize(1);
int64_t j = 0;
for (auto i : c10::irange(communication->team().size())) {
Expand All @@ -306,18 +291,16 @@ c10::intrusive_ptr<c10d::Work> postReduce(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();

at::Tensor tensor;
if (my_device_index == params.root) {
if (my_device_index == communication->root()) {
if (communication->isRootInMesh()) {
doLocalCopy(output_tensor, input_tensor);
tensor = output_tensor;
} else {
NVF_ERROR(
output_tensor.scalar_type() == at::kFloat,
"only float tensors are supported");
output_tensor.fill_(getInitialValue<float>(params.redOp));
output_tensor.fill_(getInitialValue<float>(communication->reduceOp()));
tensor = output_tensor;
}
} else {
Expand All @@ -326,7 +309,8 @@ c10::intrusive_ptr<c10d::Work> postReduce(
std::vector<at::Tensor> tensors({tensor});

c10d::ReduceOptions options = {
.reduceOp = params.redOp, .rootRank = getRootRelativeIndex(params)};
.reduceOp = communication->reduceOp(),
.rootRank = communication->getRootRelativeIndex()};
// TODO: avoid local copy by using out-of-place reduction.
return backend->reduce(tensors, options);
}
Expand All @@ -337,12 +321,11 @@ c10::intrusive_ptr<c10d::Work> postAllreduce(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();

doLocalCopy(output_tensor, input_tensor);
std::vector<at::Tensor> output_tensors({output_tensor});

return backend->allreduce(output_tensors, {.reduceOp = params.redOp});
return backend->allreduce(
output_tensors, {.reduceOp = communication->reduceOp()});
}

c10::intrusive_ptr<c10d::Work> postReduceScatter(
Expand All @@ -351,21 +334,19 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();

std::vector<std::vector<at::Tensor>> input_tensors(1);
const auto scattered_axis = communication->scatteredAxis();
NVF_ERROR(
params.scattered_axis >= 0,
scattered_axis >= 0,
"scattered_axis is expected to be non-negative: ",
params.scattered_axis)
input_tensors[0] =
at::split(input_tensor, /*split_size=*/1, params.scattered_axis);
scattered_axis)
input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], communication->team().size());
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = params.redOp});
output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
}

c10::intrusive_ptr<c10d::Work> postSendRecv(
Expand All @@ -374,17 +355,15 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();

NVF_ERROR(params.mesh.size() == 1, "The mesh size should be 1.");
NVF_ERROR(communication->mesh().size() == 1, "The mesh size should be 1.");

if (communication->isRootInMesh()) {
doLocalCopy(output_tensor, input_tensor);
return nullptr;
}

const DeviceIdxType sender = params.root;
const DeviceIdxType receiver = params.mesh.at(0);
const DeviceIdxType sender = communication->root();
const DeviceIdxType receiver = communication->mesh().at(0);

std::vector<at::Tensor> tensors;
if (my_device_index == sender) {
Expand All @@ -404,15 +383,14 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
c10::intrusive_ptr<c10d::Backend> backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
const CommParams& params = communication->params();
const Team& team = communication->team();
NVF_ERROR(
std::find(team.begin(), team.end(), my_device_index) != team.end(),
"current device index ",
my_device_index,
" must be present in the communication's team");

switch (communication->params().type) {
switch (communication->type()) {
case CommunicationType::Gather:
return postGather(
communication, my_device_index, backend, input_tensor, output_tensor);
Expand All @@ -438,7 +416,7 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
return postSendRecv(
communication, my_device_index, backend, input_tensor, output_tensor);
default:
NVF_ERROR(false, "Wrong communication type: ", params.type);
NVF_ERROR(false, "Wrong communication type: ", communication->type());
return nullptr;
}
}
Expand Down
Loading