From ab4489f6fa143b9629ac1a6d5af81951f56e5d8d Mon Sep 17 00:00:00 2001
From: mcowan
Date: Sat, 18 May 2024 05:46:11 +0000
Subject: [PATCH 1/8] set allocation domain of sharded tensors

---
 csrc/multidevice/executor.cpp         |  2 +-
 csrc/multidevice/utils.cpp            | 36 +++++++++++++++++++++------
 csrc/multidevice/utils.h              |  4 ++-
 tests/cpp/test_multidevice_matmul.cpp |  7 +-----
 tests/cpp/test_sharding.cpp           | 13 +++++++---
 5 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp
index fb33eac3bf2..51536d1ecc4 100644
--- a/csrc/multidevice/executor.cpp
+++ b/csrc/multidevice/executor.cpp
@@ -50,7 +50,7 @@ MultiDeviceExecutor::MultiDeviceExecutor(
     Communicator& comm,
     MultiDeviceExecutorParams params)
     : comm_(comm), params_(params) {
-  propagateShardings(fusion.get());
+  propagateShardingsAndSetAllocationDomain(fusion.get());
   insertReshardings(fusion.get());
   insertShardedAxisReordering(fusion.get());
   SegmentCandidateFinderOptions options{
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index ad49d53b2ab..12057268279 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -317,23 +317,41 @@ void insertReshardingsAfter(Fusion* fusion) {
   }
 }
 
+void setShardedAllocationDomain(TensorView* tv) {
+  if (!tv->hasAllocation()) {
+    tv->setAllocationDomain(tv->getLeafDomain(), tv->getContiguity());
+  }
+}
+
 } // namespace
 
-void propagateShardings(Fusion* fusion) {
+void propagateShardingsAndSetAllocationDomain(Fusion* fusion) {
+  // Set global input's allocation domain.
+  // Currently we only propagate shardings from consumer to producer,
+  // so a DeviceMesh is required.
+  for (auto global_input_tv :
+       ir_utils::filterByType<TensorView>(fusion->inputs())) {
+    NVF_ERROR(
+        global_input_tv->hasDeviceMesh(),
+        "Currently global inputs must be sharded sharded ",
+        global_input_tv);
+    setShardedAllocationDomain(global_input_tv);
+  }
+
   for (auto expr : fusion->exprs()) {
     auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
     auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
     TensorView* input_with_mesh = nullptr;
     for (auto tv : inputs) {
-      if (tv->hasDeviceMesh()) {
+      NVF_ERROR(
+          tv->hasDeviceMesh(),
+          "Currently require inputs are sharded ",
+          expr->toString());
+      if (input_with_mesh == nullptr) {
         input_with_mesh = tv;
-        break;
       }
     }
-    NVF_ERROR(
-        input_with_mesh != nullptr,
-        "At least one input requires a DeviceMesh ",
-        expr->toString());
+
     std::vector<TensorView*> outputs_without_mesh;
     for (auto tv : outputs) {
       if (!tv->hasDeviceMesh()) {
         outputs_without_mesh.push_back(tv);
       }
     }
     shardAllLike(input_with_mesh, outputs_without_mesh);
+    // All outputs have a sharding, so set the allocation domain.
+    for (auto tv : outputs) {
+      setShardedAllocationDomain(tv);
+    }
   }
 }
 
diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h
index d60bca1e455..7fae5b300bf 100644
--- a/csrc/multidevice/utils.h
+++ b/csrc/multidevice/utils.h
@@ -90,7 +90,9 @@ void unshard(TensorView*);
 // This assumes that all global inputs are sharded.
 // This cannot be done when the Op is inserted into the fusion, because
 // the multidevice scheduling hasn't been applied.
-void propagateShardings(Fusion* fusion);
+// After this step all TensorViews have a DeviceMesh, so the allocation
+// domain is set for all TensorViews as well if not explicitly set.
+void propagateShardingsAndSetAllocationDomain(Fusion* fusion);
 
 // Runs through the fusion and inserts a resharding Set Op after
 // any resharding Expr that is not directly lowerable to a series of
diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp
index 3c046a51abb..792acd8c0a8 100644
--- a/tests/cpp/test_multidevice_matmul.cpp
+++ b/tests/cpp/test_multidevice_matmul.cpp
@@ -37,7 +37,7 @@ namespace nvfuser {
 class DistributedMatmulTest : public MultiDeviceTest {
  protected:
   DistributedMatmulTest()
-      : num_devices_(communicator->size()), optimization_guard_(false) {
+      : num_devices_(communicator->size()){
     DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval);
   }
 
@@ -69,11 +69,6 @@ class DistributedMatmulTest : public MultiDeviceTest {
         atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kFloat);
     return std::make_tuple(a, b, c);
   }
-
- private:
-  preseg_passes::OptimizationPassGuard<preseg_passes::AllocationDomainPass>
-      optimization_guard_;
-  DisableOptionsGuard option_guard_;
 };
 
 TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index 705ad6da1cc..f3ff3b01060 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -5,8 +5,11 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
-#include
+
 #include
+
+#include
+#include
 #include
 #include
 #include
@@ -39,7 +42,7 @@ TEST_F(ShardingTest, IsSharded) {
   EXPECT_ANY_THROW(isSharded(c));
 }
 
-TEST_F(ShardingTest, PropagateSharding) {
+TEST_F(ShardingTest, TestPropagateShardingsAndSetAllocationDomain) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -57,9 +60,13 @@
   fusion.addOutput(c);
 
   // Expected behavior: a's shardings propagate to c.
-  propagateShardings(&fusion);
+  propagateShardingsAndSetAllocationDomain(&fusion);
   std::vector<TensorView*> tvs = {c};
   EXPECT_TRUE(getTvsWithDifferentSharding(a, tvs).empty());
+
+  for (auto tv : {a, b, c}) {
+    EXPECT_TRUE(tv->hasAllocation());
+  }
 }
 
 TEST_P(ShardingTest, ComputeIndex) {

From cb34b54bd8121a928d4bbd69be8eb21020a6790b Mon Sep 17 00:00:00 2001
From: mcowan
Date: Sat, 18 May 2024 05:47:43 +0000
Subject: [PATCH 2/8] name

---
 tests/cpp/test_sharding.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index f3ff3b01060..24babf113d4 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -42,7 +42,7 @@ TEST_F(ShardingTest, IsSharded) {
   EXPECT_ANY_THROW(isSharded(c));
 }
 
-TEST_F(ShardingTest, TestPropagateShardingsAndSetAllocationDomain) {
+TEST_F(ShardingTest, PropagateShardingsAndSetAllocationDomain) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 

From 256d3455cde014b2c75904948ed250f3583b0a73 Mon Sep 17 00:00:00 2001
From: mcowan
Date: Sat, 18 May 2024 05:48:19 +0000
Subject: [PATCH 3/8] remove unused import

---
 tests/cpp/test_sharding.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index 24babf113d4..9fe5b0595b5 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -9,7 +9,6 @@
 
 #include
 
-#include
 #include
 #include
 #include

From 5fc3683a8f79c05f6d558d0e131b5484a055c0ca Mon Sep 17 00:00:00 2001
From: mcowan
Date: Sat, 18 May 2024 05:55:55 +0000
Subject: [PATCH 4/8] fix error messages

---
 csrc/multidevice/utils.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 12057268279..42b4c3cf3ed 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -333,8 +333,8 @@ void propagateShardingsAndSetAllocationDomain(Fusion* fusion) {
       ir_utils::filterByType<TensorView>(fusion->inputs())) {
     NVF_ERROR(
         global_input_tv->hasDeviceMesh(),
-        "Currently global inputs must be sharded sharded ",
-        global_input_tv);
+        "Global inputs must be assigned a DeviceMesh ",
+        global_input_tv->toString());
     setShardedAllocationDomain(global_input_tv);
   }
 
@@ -345,7 +345,7 @@ void propagateShardingsAndSetAllocationDomain(Fusion* fusion) {
     for (auto tv : inputs) {
       NVF_ERROR(
           tv->hasDeviceMesh(),
-          "Currently require inputs are sharded ",
+          "Expression inputs should be assigned a DeviceMesh ",
          expr->toString());
       if (input_with_mesh == nullptr) {
         input_with_mesh = tv;

From 4231198e530f51e3bf34014bf296a37443e97733 Mon Sep 17 00:00:00 2001
From: mcowan
Date: Mon, 20 May 2024 20:57:43 +0000
Subject: [PATCH 5/8] only set allocation domain of resharding expr tvs and fix broken tests

---
 csrc/multidevice/communication.cpp      |  3 +-
 csrc/multidevice/executor.cpp           |  5 ++-
 csrc/multidevice/utils.cpp              | 46 ++++++++++++++-----------
 csrc/multidevice/utils.h                | 11 ++++--
 tests/cpp/test_multidevice_matmul.cpp   | 10 +++---
 tests/cpp/test_multidevice_sharding.cpp |  7 ++--
 tests/cpp/test_sharding.cpp             | 39 ++++++++++++++++++---
 7 files changed, 84 insertions(+), 37 deletions(-)

diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 4ad1cab4818..b7ffc3d0bfd 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -294,8 +294,7 @@ c10::intrusive_ptr<c10d::Work> postScatter(
       input_tensors.front().push_back(output_tensor);
       continue;
     }
-    input_tensors.front().push_back(
-        input_tensor.slice(0, j, j + 1).contiguous());
+    input_tensors.front().push_back(input_tensor.slice(0, j, j + 1));
     j++;
   }
 
diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp
index 51536d1ecc4..eecc6b4873c 100644
--- a/csrc/multidevice/executor.cpp
+++ b/csrc/multidevice/executor.cpp
@@ -50,9 +50,12 @@ MultiDeviceExecutor::MultiDeviceExecutor(
     Communicator& comm,
     MultiDeviceExecutorParams params)
     : comm_(comm), params_(params) {
-  propagateShardingsAndSetAllocationDomain(fusion.get());
+  // Sharding PreSegmenter passes.
+  // Note: passes run before PreSegmenter optimization passes.
+  propagateShardings(fusion.get());
   insertReshardings(fusion.get());
   insertShardedAxisReordering(fusion.get());
+  setShardedAllocationDomain(fusion.get());
   SegmentCandidateFinderOptions options{
       .run_translate_welford = false,
       .run_combine_reductions = false,
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 42b4c3cf3ed..7eb1d86ac83 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -319,31 +319,19 @@ void insertReshardingsAfter(Fusion* fusion) {
 
 void setShardedAllocationDomain(TensorView* tv) {
   if (!tv->hasAllocation()) {
-    tv->setAllocationDomain(tv->getLeafDomain(), tv->getContiguity());
+    tv->setAllocationDomain(tv->getLeafDomain(), true);
   }
 }
 
 } // namespace
 
-void propagateShardingsAndSetAllocationDomain(Fusion* fusion) {
-  // Set global input's allocation domain.
-  // Currently we only propagate shardings from consumer to producer,
-  // so a DeviceMesh is required.
-  for (auto global_input_tv :
-       ir_utils::filterByType<TensorView>(fusion->inputs())) {
-    NVF_ERROR(
-        global_input_tv->hasDeviceMesh(),
-        "Global inputs must be assigned a DeviceMesh ",
-        global_input_tv->toString());
-    setShardedAllocationDomain(global_input_tv);
-  }
-
+void propagateShardings(Fusion* fusion) {
   for (auto expr : fusion->exprs()) {
     auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
     auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
     TensorView* input_with_mesh = nullptr;
     for (auto tv : inputs) {
-      NVF_ERROR(
+      NVF_CHECK(
           tv->hasDeviceMesh(),
           "Expression inputs should be assigned a DeviceMesh ",
           expr->toString());
       if (input_with_mesh == nullptr) {
         input_with_mesh = tv;
@@ -359,10 +347,6 @@ void propagateShardingsAndSetAllocationDomain(Fusion* fusion) {
       }
     }
     shardAllLike(input_with_mesh, outputs_without_mesh);
-    // All outputs have a sharding, so set the allocation domain.
-    for (auto tv : outputs) {
-      setShardedAllocationDomain(tv);
-    }
   }
 }
 
@@ -397,7 +381,8 @@ void insertShardedAxisReordering(Fusion* fusion) {
     auto [shard_additions, shard_deletions] = getShardingChanges(expr);
     NVF_ERROR(
         shard_additions.size() + shard_deletions.size() <= 1,
-        "Resharding expr can only support one axis")
+        "Resharding expr can only support one axis ",
+        expr->toString())
 
     // For gather operations i.e. ID goes from sharded to unsharded
     // this will rematerialize a sharded axis.
@@ -486,6 +471,27 @@
   }
 }
 
+void setShardedAllocationDomain(Fusion* fusion) {
+  for (auto expr : fusion->exprs()) {
+    if (isResharding(expr)) {
+      for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
+        for (auto c : tv->getContiguity()) {
+          if (c.has_value()) {
+            NVF_CHECK(
+                c.value(),
+                "Resharding expression input must be contiguous ",
+                expr);
+          }
+        }
+        setShardedAllocationDomain(tv);
+      }
+      for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
+        setShardedAllocationDomain(tv);
+      }
+    }
+  }
+}
+
 int64_t requestedNumberOfDevices(Fusion* fusion) {
   DeviceIdxType max_index = 0;
   for (auto tv : ir_utils::allTvs(fusion)) {
diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h
index 7fae5b300bf..a07072404eb 100644
--- a/csrc/multidevice/utils.h
+++ b/csrc/multidevice/utils.h
@@ -90,9 +90,7 @@ void unshard(TensorView*);
 // This assumes that all global inputs are sharded.
 // This cannot be done when the Op is inserted into the fusion, because
 // the multidevice scheduling hasn't been applied.
-// After this step all TensorViews have a DeviceMesh, so the allocation
-// domain is set for all TensorViews as well if not explicitly set.
-void propagateShardingsAndSetAllocationDomain(Fusion* fusion);
+void propagateShardings(Fusion* fusion);
 
 // Runs through the fusion and inserts a resharding Set Op after
 // any resharding Expr that is not directly lowerable to a series of
@@ -107,6 +105,13 @@ void insertReshardings(Fusion* fusion);
 // to the front so that communication operations are contiguous.
 void insertShardedAxisReordering(Fusion* fusion);
 
+// Resharding expressions are mapped to collective libraries which expect
+// contiguous tensors and output contiguous buffers. This pass checks that
+// inputs are contiguous and sets the allocation domain of inputs and outputs of
+// all resharding expressions. This pass should run after all passes that add or
+// update resharding expressions.
+void setShardedAllocationDomain(Fusion* fusion);
+
 // Returns the index of the sharded axis, or -1 if none.
 // TODO: Assumes no merges/splits on sharded axis.
 int64_t getShardedAxis(TensorView*);
diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp
index 792acd8c0a8..e26ade0559c 100644
--- a/tests/cpp/test_multidevice_matmul.cpp
+++ b/tests/cpp/test_multidevice_matmul.cpp
@@ -23,8 +23,6 @@
 #include
 #include
 #include
-#include
-#include
 #include
 #include
 #include
@@ -36,8 +34,7 @@ namespace nvfuser {
 
 class DistributedMatmulTest : public MultiDeviceTest {
  protected:
-  DistributedMatmulTest()
-      : num_devices_(communicator->size()){
+  DistributedMatmulTest() : num_devices_(communicator->size()) {
     DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval);
   }
 
@@ -104,6 +101,11 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
   }
   b->setDeviceMesh(mesh);
 
+  // TODO: If c's allocation domain isn't set, it will fail validation at
+  // csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be
+  // from a contiguous inner most position.
+  c->setAllocationDomain(c->getLeafDomain(), true);
+
   auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K);
   in0 = in0.view({Mo, Mi, K});
   out = out.view({Mo, Mi, N});
diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp
index 14651df4f7d..465377fc38f 100644
--- a/tests/cpp/test_multidevice_sharding.cpp
+++ b/tests/cpp/test_multidevice_sharding.cpp
@@ -33,8 +33,9 @@ TEST_P(MultideviceShardingTest, UnshardedGlobalInput) {
   input_size[sharded_dim] = num_devices;
   input_size[sharded_output_dim] = num_devices;
 
-  TensorView* tv0 = creates_concrete_tensor ? makeConcreteTensor(input_size)
-                                            : makeSymbolicTensor(4);
+  TensorView* tv0 = creates_concrete_tensor
+      ? makeContigConcreteTensor(input_size)
+      : makeContigTensor(4);
   TensorView* tv1 = set(tv0);
   TensorView* tv2 = add(tv1, tv1);
   TensorView* tv3 = sum(tv2, {sharded_dim});
@@ -83,7 +84,7 @@ TEST_P(MultideviceShardingTest, ShardGlobalInput) {
 
   TensorView* tv0 = creates_concrete_tensor
       ? makeConcreteTensor(unsharded_input_size)
-      : makeSymbolicTensor(unsharded_input_size.size());
+      : makeContigTensor(unsharded_input_size.size());
   TensorView* tv1 = set(tv0);
   TensorView* tv2 = add(tv1, tv1);
   fusion->addInput(tv0);
diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index 9fe5b0595b5..c67ae74b45d 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -41,7 +41,7 @@ TEST_F(ShardingTest, IsSharded) {
   EXPECT_ANY_THROW(isSharded(c));
 }
 
-TEST_F(ShardingTest, PropagateShardingsAndSetAllocationDomain) {
+TEST_F(ShardingTest, PropagateShardings) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -59,12 +59,43 @@
   fusion.addOutput(c);
 
   // Expected behavior: a's shardings propagate to c.
-  propagateShardingsAndSetAllocationDomain(&fusion);
+  propagateShardings(&fusion);
   std::vector<TensorView*> tvs = {c};
   EXPECT_TRUE(getTvsWithDifferentSharding(a, tvs).empty());
+}
+
+TEST_F(ShardingTest, ShardedAllocationDomain) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* a = makeContigTensor(3);
+  TensorView* b = makeContigTensor(3);
+  TensorView* c = add(a, b);
+  TensorView* d = sum(c, {1});
 
-  for (auto tv : {a, b, c}) {
-    EXPECT_TRUE(tv->hasAllocation());
+  DeviceMesh mesh = DeviceMesh::createForNumDevices(3);
+  for (auto tv : {a, b, c, d}) {
+    tv->setDeviceMesh(mesh);
+  }
+  a->axis(1)->parallelize(ParallelType::DIDx);
+  c->axis(1)->parallelize(ParallelType::DIDx);
+  fusion.addInput(a);
+  fusion.addInput(b);
+  fusion.addOutput(d);
+
+  propagateShardings(&fusion);
+  insertReshardings(&fusion);
+  insertShardedAxisReordering(&fusion);
+  setShardedAllocationDomain(&fusion);
+  for (auto expr : fusion.exprs()) {
+    if (isResharding(expr)) {
+      for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
+        EXPECT_TRUE(tv->hasAllocation());
+      }
+      for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
+        EXPECT_TRUE(tv->hasAllocation());
+      }
+    }
   }
 }

From bdebf65de71aa6fcdf70e5a8768a889e7685ecc7 Mon Sep 17 00:00:00 2001
From: mcowan
Date: Mon, 20 May 2024 21:11:06 +0000
Subject: [PATCH 6/8] contig fix

---
 tests/cpp/test_multidevice_sharding.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp
index 465377fc38f..164dab45cd0 100644
--- a/tests/cpp/test_multidevice_sharding.cpp
+++ b/tests/cpp/test_multidevice_sharding.cpp
@@ -83,7 +83,7 @@ TEST_P(MultideviceShardingTest, ShardGlobalInput) {
   unsharded_input_size[sharded_dim] = num_devices;
 
   TensorView* tv0 = creates_concrete_tensor
-      ? makeConcreteTensor(unsharded_input_size)
+      ? makeContigConcreteTensor(unsharded_input_size)
       : makeContigTensor(unsharded_input_size.size());
   TensorView* tv1 = set(tv0);
   TensorView* tv2 = add(tv1, tv1);
   fusion->addInput(tv0);

From 4e34f1b810aea792b2a188eb2249f352caae32f8 Mon Sep 17 00:00:00 2001
From: Meghan Cowan
Date: Mon, 20 May 2024 14:19:51 -0700
Subject: [PATCH 7/8] Undo test name change

---
 tests/cpp/test_sharding.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index c67ae74b45d..79e10a702c2 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -41,7 +41,7 @@ TEST_F(ShardingTest, IsSharded) {
   EXPECT_ANY_THROW(isSharded(c));
 }
 
-TEST_F(ShardingTest, PropagateShardings) {
+TEST_F(ShardingTest, PropagateSharding) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 

From 373a344fd2ac07f06b85811b82b057952e36a485 Mon Sep 17 00:00:00 2001
From: mcowan
Date: Fri, 24 May 2024 17:37:09 +0000
Subject: [PATCH 8/8] feedback

---
 csrc/multidevice/utils.cpp  | 38 +++++++++++++++++++++------------------
 tests/cpp/test_sharding.cpp | 25 ++++++++++++++++++++----
 2 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 7eb1d86ac83..4d79378d927 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -333,8 +333,9 @@ void propagateShardings(Fusion* fusion) {
     for (auto tv : inputs) {
       NVF_CHECK(
           tv->hasDeviceMesh(),
-          "Expression inputs should be assigned a DeviceMesh ",
-          expr->toString());
+          "Tensor ",
+          tv->toString(),
+          " should be assigned a DeviceMesh");
       if (input_with_mesh == nullptr) {
         input_with_mesh = tv;
       }
@@ -366,7 +367,7 @@ void insertShardedAxisReordering(Fusion* fusion) {
     }
     NVF_ERROR(
         ir_utils::isTvOp(expr),
-        "Non-tv op is not supported : ",
+        "Non-tv op is not supported:",
         expr->toString());
     NVF_ERROR(
         expr->outputs().size() == 1,
@@ -381,7 +382,7 @@
     auto [shard_additions, shard_deletions] = getShardingChanges(expr);
     NVF_ERROR(
         shard_additions.size() + shard_deletions.size() <= 1,
-        "Resharding expr can only support one axis ",
+        "Resharding expr can only support one axis:",
         expr->toString())
 
     // For gather operations i.e. ID goes from sharded to unsharded
@@ -472,22 +473,23 @@
 }
 
 void setShardedAllocationDomain(Fusion* fusion) {
-  for (auto expr : fusion->exprs()) {
-    if (isResharding(expr)) {
-      for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
-        for (auto c : tv->getContiguity()) {
-          if (c.has_value()) {
-            NVF_CHECK(
-                c.value(),
-                "Resharding expression input must be contiguous ",
-                expr);
-          }
+  for (Expr* expr : fusion->exprs()) {
+    if (!isResharding(expr)) {
+      continue;
+    }
+    for (TensorView* tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
+      for (auto c : tv->getContiguity()) {
+        if (c.has_value()) {
+          NVF_CHECK(
+              c.value(),
+              "Resharding expression input must be contiguous: ",
+              expr);
         }
-        setShardedAllocationDomain(tv);
-      }
-      for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
-        setShardedAllocationDomain(tv);
       }
+      setShardedAllocationDomain(tv);
+    }
+    for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
+      setShardedAllocationDomain(tv);
     }
   }
 }
diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index 79e10a702c2..860fd5ded4a 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -64,6 +64,21 @@
   EXPECT_TRUE(getTvsWithDifferentSharding(a, tvs).empty());
 }
 
+void isContiguous(TensorView* tv) {
+  EXPECT_TRUE(tv->hasAllocation());
+  auto contiguity = tv->getContiguity();
+  auto alloc_domain = tv->getAllocationDomain();
+  for (auto i : c10::irange(contiguity.size())) {
+    // TODO: This should eventually check that DeviceDim domains also have no
+    // value.
+    if (alloc_domain[i]->isReduction() || alloc_domain[i]->isBroadcast()) {
+      EXPECT_FALSE(contiguity[i].has_value());
+    } else {
+      EXPECT_TRUE(contiguity[i].value());
+    }
+  }
+}
+
 TEST_F(ShardingTest, ShardedAllocationDomain) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -77,8 +92,10 @@
   for (auto tv : {a, b, c, d}) {
     tv->setDeviceMesh(mesh);
   }
-  a->axis(1)->parallelize(ParallelType::DIDx);
-  c->axis(1)->parallelize(ParallelType::DIDx);
+
+  int sharded_dim = 1;
+  a->axis(sharded_dim)->parallelize(ParallelType::DIDx);
+  c->axis(sharded_dim)->parallelize(ParallelType::DIDx);
   fusion.addInput(a);
   fusion.addInput(b);
   fusion.addOutput(d);
 
@@ -90,10 +107,10 @@
   for (auto expr : fusion.exprs()) {
     if (isResharding(expr)) {
       for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
-        EXPECT_TRUE(tv->hasAllocation());
+        isContiguous(tv);
       }
       for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
-        EXPECT_TRUE(tv->hasAllocation());
+        isContiguous(tv);
       }
     }
   }
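
Note for reviewers: below is a minimal end-to-end sketch of the preprocessing
pipeline this series converges on, assembled from the tests above. It is an
illustration, not part of the patches; it assumes the nvFuser test helper
makeContigTensor (from tests/cpp/utils.h) and the pass entry points declared in
csrc/multidevice/utils.h behave as exercised in ShardingTest.ShardedAllocationDomain.

  #include <fusion.h>
  #include <ir/utils.h>
  #include <multidevice/device_mesh.h>
  #include <multidevice/utils.h>
  #include <ops/all_ops.h>

  using namespace nvfuser;

  // Shard a pointwise + reduction fusion across 3 devices, then run the
  // multidevice passes in the order MultiDeviceExecutor applies them.
  void preprocessShardedFusion() {
    Fusion fusion;
    FusionGuard fg(&fusion);

    TensorView* a = makeContigTensor(3); // test helper, assumed available
    TensorView* b = makeContigTensor(3);
    TensorView* c = add(a, b);
    TensorView* d = sum(c, {1});
    fusion.addInput(a);
    fusion.addInput(b);
    fusion.addOutput(d);

    // Every expression input needs a DeviceMesh before propagation runs;
    // propagateShardings NVF_CHECKs this.
    DeviceMesh mesh = DeviceMesh::createForNumDevices(3);
    for (TensorView* tv : {a, b, c, d}) {
      tv->setDeviceMesh(mesh);
    }
    a->axis(1)->parallelize(ParallelType::DIDx);
    c->axis(1)->parallelize(ParallelType::DIDx);

    // Order matters: setShardedAllocationDomain must run after every pass
    // that adds or rewrites resharding expressions.
    propagateShardings(&fusion);
    insertReshardings(&fusion);
    insertShardedAxisReordering(&fusion);
    setShardedAllocationDomain(&fusion);

    // Inputs and outputs of each resharding expression now carry an explicit,
    // contiguous allocation domain, as the collective libraries require.
    for (Expr* expr : fusion.exprs()) {
      if (!isResharding(expr)) {
        continue;
      }
      for (TensorView* tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
        NVF_ERROR(tv->hasAllocation());
      }
      for (TensorView* tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
        NVF_ERROR(tv->hasAllocation());
      }
    }
  }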