diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 4ad1cab4818..b7ffc3d0bfd 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -294,8 +294,7 @@ c10::intrusive_ptr postScatter( input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back( - input_tensor.slice(0, j, j + 1).contiguous()); + input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); j++; } diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index fb33eac3bf2..eecc6b4873c 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -50,9 +50,12 @@ MultiDeviceExecutor::MultiDeviceExecutor( Communicator& comm, MultiDeviceExecutorParams params) : comm_(comm), params_(params) { + // Sharding PreSegmenter passes. + // Note: passes run before PreSegmenter optimization passes. propagateShardings(fusion.get()); insertReshardings(fusion.get()); insertShardedAxisReordering(fusion.get()); + setShardedAllocationDomain(fusion.get()); SegmentCandidateFinderOptions options{ .run_translate_welford = false, .run_combine_reductions = false, diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index ad49d53b2ab..4d79378d927 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -317,6 +317,12 @@ void insertReshardingsAfter(Fusion* fusion) { } } +void setShardedAllocationDomain(TensorView* tv) { + if (!tv->hasAllocation()) { + tv->setAllocationDomain(tv->getLeafDomain(), true); + } +} + } // namespace void propagateShardings(Fusion* fusion) { @@ -325,15 +331,16 @@ void propagateShardings(Fusion* fusion) { auto outputs = ir_utils::filterByType(expr->outputs()); TensorView* input_with_mesh = nullptr; for (auto tv : inputs) { - if (tv->hasDeviceMesh()) { + NVF_CHECK( + tv->hasDeviceMesh(), + "Tensor ", + tv->toString(), + " should be assigned a DeviceMesh"); + if (input_with_mesh == nullptr) { input_with_mesh = tv; - break; } } 
- NVF_ERROR( - input_with_mesh != nullptr, - "At least one input requires a DeviceMesh ", - expr->toString()); + std::vector outputs_without_mesh; for (auto tv : outputs) { if (!tv->hasDeviceMesh()) { @@ -360,7 +367,7 @@ void insertShardedAxisReordering(Fusion* fusion) { } NVF_ERROR( ir_utils::isTvOp(expr), - "Non-tv op is not supported : ", + "Non-tv op is not supported: ", expr->toString()); NVF_ERROR( expr->outputs().size() == 1, @@ -375,7 +382,8 @@ void insertShardedAxisReordering(Fusion* fusion) { auto [shard_additions, shard_deletions] = getShardingChanges(expr); NVF_ERROR( shard_additions.size() + shard_deletions.size() <= 1, - "Resharding expr can only support one axis") + "Resharding expr can only support one axis: ", + expr->toString()) // For gather operations i.e. ID goes from sharded to unsharded // this will rematerialize a sharded axis. @@ -464,6 +472,28 @@ void insertShardedAxisReordering(Fusion* fusion) { } } +void setShardedAllocationDomain(Fusion* fusion) { + for (Expr* expr : fusion->exprs()) { + if (!isResharding(expr)) { + continue; + } + for (TensorView* tv : ir_utils::filterByType<TensorView>(expr->inputs())) { + for (auto c : tv->getContiguity()) { + if (c.has_value()) { + NVF_CHECK( + c.value(), + "Resharding expression input must be contiguous: ", + expr); + } + } + setShardedAllocationDomain(tv); + } + for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) { + setShardedAllocationDomain(tv); + } + } +} + int64_t requestedNumberOfDevices(Fusion* fusion) { DeviceIdxType max_index = 0; for (auto tv : ir_utils::allTvs(fusion)) { diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index d60bca1e455..a07072404eb 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -105,6 +105,13 @@ void insertReshardings(Fusion* fusion); // to the front so that communication operations are contiguous.
void insertShardedAxisReordering(Fusion* fusion); +// Resharding expressions are mapped to collective libraries which expect +// contiguous tensors and output contiguous buffers. This pass checks that +// inputs are contiguous and sets the allocation domain of inputs and outputs of +// all resharding expressions. This pass should run after all passes that add or +// update resharding expressions. +void setShardedAllocationDomain(Fusion* fusion); + // Returns the index of the a sharded axis if none return -1. // TODO: Assumes no merges/splits on sharded axis. int64_t getShardedAxis(TensorView*); diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 2755d97c7ac..e0f31b3e114 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -23,8 +23,6 @@ #include #include #include -#include -#include #include #include #include @@ -36,8 +34,7 @@ namespace nvfuser { class DistributedMatmulTest : public MultiDeviceTest { protected: - DistributedMatmulTest() - : num_devices_(communicator->size()), optimization_guard_(false) {} + DistributedMatmulTest() : num_devices_(communicator->size()) {} void SetUp() { MultiDeviceTest::SetUp(); @@ -67,11 +64,6 @@ class DistributedMatmulTest : public MultiDeviceTest { atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kFloat); return std::make_tuple(a, b, c); } - - private: - preseg_passes::OptimizationPassGuard - optimization_guard_; - DisableOptionsGuard option_guard_; }; TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { @@ -107,6 +99,11 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { } b->setDeviceMesh(mesh); + // TODO: If c's allocation domain isn't set, it will fail validation at + // csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be + // from a contiguous inner most position. 
+ c->setAllocationDomain(c->getLeafDomain(), true); + auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K); in0 = in0.view({Mo, Mi, K}); out = out.view({Mo, Mi, N}); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 14651df4f7d..164dab45cd0 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -33,8 +33,9 @@ TEST_P(MultideviceShardingTest, UnshardedGlobalInput) { input_size[sharded_dim] = num_devices; input_size[sharded_output_dim] = num_devices; - TensorView* tv0 = creates_concrete_tensor ? makeConcreteTensor(input_size) - : makeSymbolicTensor(4); + TensorView* tv0 = creates_concrete_tensor + ? makeContigConcreteTensor(input_size) + : makeContigTensor(4); TensorView* tv1 = set(tv0); TensorView* tv2 = add(tv1, tv1); TensorView* tv3 = sum(tv2, {sharded_dim}); @@ -82,8 +83,8 @@ TEST_P(MultideviceShardingTest, ShardGlobalInput) { unsharded_input_size[sharded_dim] = num_devices; TensorView* tv0 = creates_concrete_tensor - ? makeConcreteTensor(unsharded_input_size) - : makeSymbolicTensor(unsharded_input_size.size()); + ? 
makeContigConcreteTensor(unsharded_input_size) + : makeContigTensor(unsharded_input_size.size()); TensorView* tv1 = set(tv0); TensorView* tv2 = add(tv1, tv1); fusion->addInput(tv0); diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp index 705ad6da1cc..860fd5ded4a 100644 --- a/tests/cpp/test_sharding.cpp +++ b/tests/cpp/test_sharding.cpp @@ -5,8 +5,10 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include + #include + +#include #include #include #include @@ -62,6 +64,58 @@ TEST_F(ShardingTest, PropagateSharding) { EXPECT_TRUE(getTvsWithDifferentSharding(a, tvs).empty()); } +void isContiguous(TensorView* tv) { + EXPECT_TRUE(tv->hasAllocation()); + auto contiguity = tv->getContiguity(); + auto alloc_domain = tv->getAllocationDomain(); + for (auto i : c10::irange(contiguity.size())) { + // TODO: This should eventually check that DeviceDim domains also have no + // value. + if (alloc_domain[i]->isReduction() || alloc_domain[i]->isBroadcast()) { + EXPECT_FALSE(contiguity[i].has_value()); + } else { + EXPECT_TRUE(contiguity[i].value_or(false)); + } + } +} + +TEST_F(ShardingTest, ShardedAllocationDomain) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(3); + TensorView* c = add(a, b); + TensorView* d = sum(c, {1}); + + DeviceMesh mesh = DeviceMesh::createForNumDevices(3); + for (auto tv : {a, b, c, d}) { + tv->setDeviceMesh(mesh); + } + + int sharded_dim = 1; + a->axis(sharded_dim)->parallelize(ParallelType::DIDx); + c->axis(sharded_dim)->parallelize(ParallelType::DIDx); + fusion.addInput(a); + fusion.addInput(b); + fusion.addOutput(d); + + propagateShardings(&fusion); + insertReshardings(&fusion); + insertShardedAxisReordering(&fusion); + setShardedAllocationDomain(&fusion); + for (auto expr : fusion.exprs()) { + if (isResharding(expr)) { + for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) { + isContiguous(tv); + } + for (auto tv : 
ir_utils::filterByType<TensorView>(expr->outputs())) { + isContiguous(tv); + } + } + } +} + TEST_P(ShardingTest, ComputeIndex) { const bool creates_concrete_tensor = GetParam(); std::unique_ptr fusion = std::make_unique();