diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 49e796a679e..8416ef6cdd1 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -54,17 +54,24 @@ void lowerToScatter( TensorView* output_tv, const HostIrLowerParams& params, std::vector& comms) { - // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); NVF_ERROR( receiver_mesh.rank() == 1, - "Gather only supported on a 1D mesh. Given ", + "Scatter only supported on a 1D mesh. Given ", receiver_mesh); - auto root = input_tv->getDeviceMesh().at(0); + + // Find a common device between input and receiver meshes to be the root + auto it = std::ranges::find_if( + input_tv->getDeviceMesh().vector(), + [&receiver_mesh](DeviceIdxType device) { + return receiver_mesh.has(device); + }); + NVF_ERROR( + it != input_tv->getDeviceMesh().vector().end(), + "No common device found between input and receiver meshes"); + DeviceIdxType root = *it; + + Team team = receiver_mesh.vector(); - if (!receiver_mesh.has(root)) { - team.push_back(root); - } comms.push_back(IrBuilder::create( CommunicationType::Scatter, output_tv, diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 041e13eb80c..3a2abebb264 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -378,33 +378,44 @@ c10::intrusive_ptr postScatter( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - if (my_device_index == communication->root() && - !communication->out()->getDeviceMesh().has(communication->root())) { - output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); - } - std::vector output_tensors({output_tensor}); + NVF_ERROR( + isTvContiguous(communication->in()), "Input tensor is not contiguous"); + NVF_ERROR( + isTvContiguous(communication->out()), "Output tensor is not contiguous"); + + const DeviceMesh& output_device_mesh =
communication->out()->getDeviceMesh(); + NVF_ERROR( + output_device_mesh.has(communication->root()), + "communication->root() ", + communication->root(), + " is not in the output device mesh ", + output_device_mesh, + "."); - auto root_relative_index = communication->getRootRelativeIndex(); std::vector> input_tensors; + + output_tensor = output_tensor.as_strided({output_tensor.numel()}, {1}); /* view the output buffer as a flat 1-D tensor with unit stride */ + std::vector output_tensors({output_tensor}); + if (my_device_index == communication->root()) { + auto splits = at::tensor_split( + input_tensor.as_strided({input_tensor.numel()}, {1}), + output_device_mesh.size(), + /*dim=*/0); /* one flat chunk per device in the output mesh */ + input_tensors.resize(1); - int64_t j = 0; - for (auto i : arange(communication->team().size())) { - if (root_relative_index == static_cast(i) && - !communication->out()->getDeviceMesh().has(communication->root())) { - input_tensors.front().push_back(output_tensor); - continue; - } - input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); - j++; + for (const auto& split : splits) { + input_tensors.front().push_back(split); } - assertBufferCount(input_tensors[0], communication->team().size()); + assertBufferCount(input_tensors[0], output_device_mesh.size()); assertBuffersHaveSameSize(input_tensors[0], output_tensors); } return backend->scatter( - output_tensors, input_tensors, {.rootRank = root_relative_index}); + output_tensors, + input_tensors, + {.rootRank = communication->getRootRelativeIndex()}); } c10::intrusive_ptr postReduce( diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index a9725615380..622d8d2450a 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -585,6 +585,52 @@ TEST_P(LowerCollectiveTest, AllgatherLoopSplit_Noncontig) { __FILE__); } +TEST_P(LowerCollectiveTest, ScatterLoopSplit) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto d =
communicator_->size(); + auto full_mesh = DeviceMesh::createForNumDevices(d); + + DeviceMesh mesh_zero({0}); + TensorView* tv0 = makeConcreteTensor({5, d * 3}); + TensorView* tv1 = set(tv0); + + tv0->setDeviceMesh(mesh_zero); + tv0->outer_split(1, d); + tv0->axis(1)->parallelize(ParallelType::Serial); + tv0->reorder({2, 0, 1}); + + tv1->setDeviceMesh(full_mesh); + tv1->outer_split(1, d); + tv1->axis(1)->parallelize(ParallelType::DIDx); + tv1->reorder({2, 0, 1}); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + + for (auto tv : {tv0, tv1}) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + at::Tensor unsharded_in_tensor = + at::randn({d * 3, 5}, tensor_options).transpose(0, 1); + + at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = + executor_cache.runFusionWithInputs({unsharded_in_tensor})[0] + .as(); + + testValidate( + executor_cache.fusion(), + {out_tensor}, + {unsharded_in_tensor}, + {expected_output}, + __LINE__, + __FILE__); +} + INSTANTIATE_TEST_SUITE_P( HostIrLowering, LowerCollectiveTest, diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 12dfed5dd43..05f1620f4c7 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -355,8 +355,8 @@ INSTANTIATE_TEST_SUITE_P( PipelineTestTwoStages, testing::Combine( testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc), - all_meshes, - all_meshes, + testing::Values(mesh0, mesh1), + testing::Values(mesh2, mesh4, mesh5), testing::Values(false), testing::Values(true), testing::Values(false),