Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,24 @@ void lowerToScatter(
TensorView* output_tv,
const HostIrLowerParams& params,
std::vector<Expr*>& comms) {
// we arbitrarily choose the first device of the sender mesh to be the root
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR(
receiver_mesh.rank() == 1,
"Gather only supported on a 1D mesh. Given ",
receiver_mesh);
auto root = input_tv->getDeviceMesh().at(0);

// Find a common device between input and receiver meshes to be the root
auto it = std::ranges::find_if(
input_tv->getDeviceMesh().vector(),
[&receiver_mesh](DeviceIdxType device) {
return receiver_mesh.has(device);
});
NVF_ERROR(
it != input_tv->getDeviceMesh().vector().end(),
"No common device found between input and receiver meshes");
DeviceIdxType root = *it;

Team team = receiver_mesh.vector();
if (!receiver_mesh.has(root)) {
team.push_back(root);
}
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::Scatter,
output_tv,
Expand Down
45 changes: 28 additions & 17 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,33 +378,44 @@ c10::intrusive_ptr<c10d::Work> postScatter(
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
if (my_device_index == communication->root() &&
!communication->out()->getDeviceMesh().has(communication->root())) {
output_tensor = at::empty_like(input_tensor.slice(0, 0, 1));
}
std::vector<at::Tensor> output_tensors({output_tensor});
NVF_ERROR(
isTvContiguous(communication->in()), "Input tensor is not contiguous");
NVF_ERROR(
isTvContiguous(communication->out()), "Output tensor is not contiguous");

auto output_device_mesh = communication->out()->getDeviceMesh();
NVF_ERROR(
output_device_mesh.has(communication->root()),
"communication->root() ",
communication->root(),
" is not in the output device mesh ",
output_device_mesh,
".");

auto root_relative_index = communication->getRootRelativeIndex();
std::vector<std::vector<at::Tensor>> input_tensors;

output_tensor = output_tensor.as_strided({output_tensor.numel()}, {1});
std::vector<at::Tensor> output_tensors({output_tensor});

if (my_device_index == communication->root()) {
auto splits = at::tensor_split(
input_tensor.as_strided({input_tensor.numel()}, {1}),
output_device_mesh.size(),
/*dim=*/0);

input_tensors.resize(1);
int64_t j = 0;
for (auto i : arange(communication->team().size())) {
if (root_relative_index == static_cast<DeviceIdxType>(i) &&
!communication->out()->getDeviceMesh().has(communication->root())) {
input_tensors.front().push_back(output_tensor);
continue;
}
input_tensors.front().push_back(input_tensor.slice(0, j, j + 1));
j++;
for (const auto& split : splits) {
input_tensors.front().push_back(split);
}

assertBufferCount(input_tensors[0], communication->team().size());
assertBufferCount(input_tensors[0], output_device_mesh.size());
assertBuffersHaveSameSize(input_tensors[0], output_tensors);
}

return backend->scatter(
output_tensors, input_tensors, {.rootRank = root_relative_index});
output_tensors,
input_tensors,
{.rootRank = communication->getRootRelativeIndex()});
}

c10::intrusive_ptr<c10d::Work> postReduce(
Expand Down
46 changes: 46 additions & 0 deletions tests/cpp/test_multidevice_lower_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,52 @@ TEST_P(LowerCollectiveTest, AllgatherLoopSplit_Noncontig) {
__FILE__);
}

// Verifies lowering of a Set op whose input lives on a single-device mesh
// ({0}) and whose output is sharded (DIDx) across the full device mesh via a
// loop-domain outer split — i.e. a case expected to lower to a Scatter
// communication. NOTE(review): the unsharded input is created as a transpose,
// so it is a non-contiguous {5, d*3} view — presumably this exercises the
// strided/loop-split path; confirm against the lowering's contiguity handling.
TEST_P(LowerCollectiveTest, ScatterLoopSplit) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
const auto d = communicator_->size();
auto full_mesh = DeviceMesh::createForNumDevices(d);

// Input resides only on device 0; output is distributed over all d devices.
DeviceMesh mesh_zero({0});
TensorView* tv0 = makeConcreteTensor({5, d * 3});
TensorView* tv1 = set(tv0);

// Split the size-(d*3) axis into [d, 3]; the device factor stays Serial on
// the input (it is not sharded) and is moved to the front by the reorder.
tv0->setDeviceMesh(mesh_zero);
tv0->outer_split(1, d);
tv0->axis(1)->parallelize(ParallelType::Serial);
tv0->reorder({2, 0, 1});

// Same split on the output, but the device factor is parallelized on DIDx,
// sharding the output across full_mesh.
tv1->setDeviceMesh(full_mesh);
tv1->outer_split(1, d);
tv1->axis(1)->parallelize(ParallelType::DIDx);
tv1->reorder({2, 0, 1});

fusion->addInput(tv0);
fusion->addOutput(tv1);

// Make the allocation domain follow the (reordered) loop domain so the
// in-memory layout matches the scheduled sharding.
for (auto tv : {tv0, tv1}) {
tv->setAllocationDomain(tv->getLoopDomain(), true);
}

// {d*3, 5} random tensor transposed to a non-contiguous {5, d*3} view.
at::Tensor unsharded_in_tensor =
at::randn({d * 3, 5}, tensor_options).transpose(0, 1);

// Each rank should receive its shard of axis 1 according to full_mesh.
at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh);

FusionExecutorCache executor_cache(std::move(fusion));
at::Tensor out_tensor =
executor_cache.runFusionWithInputs({unsharded_in_tensor})[0]
.as<at::Tensor>();

testValidate(
executor_cache.fusion(),
{out_tensor},
{unsharded_in_tensor},
{expected_output},
__LINE__,
__FILE__);
}

INSTANTIATE_TEST_SUITE_P(
HostIrLowering,
LowerCollectiveTest,
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_multidevice_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ INSTANTIATE_TEST_SUITE_P(
PipelineTestTwoStages,
testing::Combine(
testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
all_meshes,
all_meshes,
testing::Values(mesh0, mesh1),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this to avoid cases where root is not in output device mesh.

testing::Values(mesh2, mesh4, mesh5),
testing::Values(false),
testing::Values(true),
testing::Values(false),
Expand Down