From cf771310006bdc951cf75f4379c38ad1250ad31f Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 12 Feb 2025 14:56:55 -0800 Subject: [PATCH 01/70] lintrunner --- tests/cpp/test_multidevice_sharding.cpp | 39 ++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 2309dc4cd36..e75f4fad280 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -745,12 +745,48 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) { __FILE__); } -TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { +TEST_F(MultiDeviceTest, ReorderDIDToFront) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + + const int64_t b = 2, s = 4, h = 16; + TensorView* in = makeConcreteTensor({b, s, d * h}); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + for (auto* tv : {in, out}) { + tv->setDeviceMesh(mesh); + tv->split(-1, d, /*inner_split=*/false); + tv->axis(-2)->parallelize(ParallelType::DIDx); + reorderDIDToFront(tv); + tv->setAllocationDomain(tv->getLoopDomain(), true); + NVF_CHECK(tv->axis(0)->isDeviceDim()); + } + + at::Tensor in_tensor = at::randn({b, s, h}, tensor_options); + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; + + testValidate( + executor_cache.fusion(), + {out_tensor}, + {in_tensor}, + {in_tensor}, + __LINE__, + __FILE__); +} + +TEST_F(MultiDeviceTest, TransformPropagatorWithReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const int d = communicator_->size(); const int64_t b = 2, s = 2, h = 4, e = 3; + const int64_t b = 2, s = 2, h = 4, e = 3; TensorView* tv0 = makeContigConcreteTensor( {b, s, d * h * e}); // in: loop domain: {b, s, d*h*e} @@ -772,6 
+808,7 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator_c2p); // in: loop domain: {b, s, d*h, e} after transform propagation + // Loop split and parallelize input tv0->setDeviceMesh(mesh); tv1->setDeviceMesh(mesh); From dddfe64e62fd741c595ec2e82b2a4f0ea973b640 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 12 Feb 2025 23:14:43 -0800 Subject: [PATCH 02/70] rm duplicate test from rebase --- tests/cpp/test_multidevice_sharding.cpp | 35 ------------------------- 1 file changed, 35 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index e75f4fad280..c58e23980b3 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -745,41 +745,6 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) { __FILE__); } -TEST_F(MultiDeviceTest, ReorderDIDToFront) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const auto d = communicator_->size(); - auto mesh = DeviceMesh::createForNumDevices(d); - - const int64_t b = 2, s = 4, h = 16; - TensorView* in = makeConcreteTensor({b, s, d * h}); - TensorView* out = set(in); - fusion->addInput(in); - fusion->addOutput(out); - - for (auto* tv : {in, out}) { - tv->setDeviceMesh(mesh); - tv->split(-1, d, /*inner_split=*/false); - tv->axis(-2)->parallelize(ParallelType::DIDx); - reorderDIDToFront(tv); - tv->setAllocationDomain(tv->getLoopDomain(), true); - NVF_CHECK(tv->axis(0)->isDeviceDim()); - } - - at::Tensor in_tensor = at::randn({b, s, h}, tensor_options); - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; - - testValidate( - executor_cache.fusion(), - {out_tensor}, - {in_tensor}, - {in_tensor}, - __LINE__, - __FILE__); -} - TEST_F(MultiDeviceTest, TransformPropagatorWithReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From 
a66ab5561891ac261b93f0411fac11000a3fddfa Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 24 Feb 2025 00:10:18 -0800 Subject: [PATCH 03/70] split and merge reshape if ViewOp does not reshard --- csrc/scheduler/utils.cpp | 14 +-- tests/cpp/test_multidevice_sharding.cpp | 127 ++++++------------------ 2 files changed, 37 insertions(+), 104 deletions(-) diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index b4fba60331d..8b99432dc51 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2337,16 +2337,10 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { tv->reorder(old2new); //! Propagate current transformations on from_tv to all graphs transformPropagateToAllFrom(tv, (int64_t)old2new.size()); - - // Propgating the transforms will not replay the DIDx parallelization, so we - // need to do it manually here. - parallelizeAllLike( - tv, - /*pos=*/(int64_t)old2new.size(), - /*selected_tvs=*/{}, - /*selected_parallel_types=*/{ParallelType::DIDx}, - /*propagate_padding=*/false, - /*parallelize_inputs=*/true); + parallelizeAllLike(tv, (int64_t)old2new.size(), {}, {ParallelType::DIDx}); + } + for (auto tv : fusion->allTvs()) { + debug() << tv->toString() << std::endl; } } diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index c58e23980b3..35bf6f0d810 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -745,7 +745,7 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) { __FILE__); } -TEST_F(MultiDeviceTest, TransformPropagatorWithReshape) { +TEST_F(MultiDeviceTest, LoopShardingWithSplitReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -855,43 +855,36 @@ TEST_F(MultiDeviceTest, LoopShardedSplitReshapeIds) { __FILE__); } -TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { +TEST_F(MultiDeviceTest, TransformPropagator) { auto fusion = std::make_unique(); FusionGuard 
fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 4; + const int64_t m = 5, n = 7; - TensorView* tv0 = makeContigConcreteTensor({b, s, d * h, e}); - TensorView* tv1 = reshape(tv0, {b, s, d * h, e}, {b, s, d * h * e}); + TensorView* in = makeContigConcreteTensor({d * m * n}); + TensorView* out = reshape(in, {d * m * n}, {d * m, n}); + TensorView* add_out = add(out, IrBuilder::create(1.0)); - fusion->addInput(tv0); - fusion->addOutput(tv1); + fusion->addInput(in); + fusion->addOutput(add_out); auto mesh = DeviceMesh::createForNumDevices(d); - tv0->setDeviceMesh(mesh); - tv0->split(-2, d, /*inner_split=*/false); - tv0->axis(-3)->parallelize(ParallelType::DIDx); - - tv1->setDeviceMesh(mesh); - tv1->split(-1, d, /*inner_split=*/false); - tv1->axis(-2)->parallelize(ParallelType::DIDx); - - for (auto* tv : {tv0, tv1}) { - reorderDIDToFront(tv); + for (auto* tv : {in, out, add_out}) { + tv->setDeviceMesh(mesh); + tv->split(0, d, /*inner_split=*/false); + tv->axis(0)->parallelize(ParallelType::DIDx); tv->setAllocationDomain(tv->getLoopDomain(), true); } FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp = at::randn({b, s, d * h, e}, tensor_options); - at::Tensor sharded_inp = shardTensor(inp, -2, mesh); - at::Tensor nvf_out = - executor_cache.runFusionWithInputs({sharded_inp})[0].as(); + at::Tensor in_tensor = at::randn({m * n}, tensor_options); + at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; testValidate( executor_cache.fusion(), - {nvf_out}, - {sharded_inp}, - {sharded_inp.view({b, s, h * e})}, + {out_tensor}, + {in_tensor}, + {in_tensor.view({m, n}) + 1.0}, __LINE__, __FILE__); } @@ -968,87 +961,33 @@ TEST_F(MultiDeviceTest, TransformerFwd) { FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 16; - auto mesh = DeviceMesh::createForNumDevices(d); + const int64_t m = 5, n = 7; - std::vector in_shape = {b, s, d * h * e}; - 
std::vector out_shape = {b, s, d * h, e}; - - // The transformer block produces hq/hk/hv after slicing the MHA linear - // output. - TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); - - TensorView* q = reshape(hq, in_shape, out_shape); - TensorView* q_permuted = permute(q, {0, 2, 1, 3}); - TensorView* k = reshape(hk, in_shape, out_shape); - TensorView* k_permuted = permute(k, {0, 2, 1, 3}); - TensorView* v = reshape(hv, in_shape, out_shape); - TensorView* v_permuted = permute(v, {0, 2, 1, 3}); - - SdpfaFwdResult sdpa_out = sdpfa_fwd( - q_permuted, - k_permuted, - v_permuted, - /*dropout_p=*/IrBuilder::create(0.0), - /*is_causal=*/IrBuilder::create(false), - /*scale=*/nullptr); - - TensorView* attn = sdpa_out.output; - TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - TensorView* out = reshape(attn_permute, out_shape, in_shape); - - fusion->addInput(hq); - fusion->addInput(hk); - fusion->addInput(hv); + TensorView* in = makeContigConcreteTensor({d * m, n}); + TensorView* out = reshape(in, {d * m, n}, {d * m * n}); + // TensorView* add_out = add(out, IrBuilder::create(1.0)); + + fusion->addInput(in); fusion->addOutput(out); - // Shard input tensors - for (auto* tv : {hq, hk, hv}) { + auto mesh = DeviceMesh::createForNumDevices(d); + for (auto* tv : {in, out}) { tv->setDeviceMesh(mesh); - tv->split(-1, d, /*inner_split=*/false); - tv->axis(-2)->parallelize(ParallelType::DIDx); - reorderDIDToFront(tv); - } - propagateShardings(fusion.get(), d); - - for (auto tv : fusion->allTvs()) { + tv->split(0, d, /*inner_split=*/false); + tv->axis(0)->parallelize(ParallelType::DIDx); tv->setAllocationDomain(tv->getLoopDomain(), true); } FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor hq_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hk_tensor = at::randn({in_shape}, 
tensor_options.dtype(at::kHalf)); - at::Tensor hv_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - - at::Tensor sharded_hq = shardTensor(hq_tensor, -1, mesh); - at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); - at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - - auto nvf_out = - executor_cache - .runFusionWithInputs({sharded_hq, sharded_hk, sharded_hv})[0] - .as(); - - double scale = 1.0 / std::sqrt(e); - auto reference_out = at::_scaled_dot_product_flash_attention( - hq_tensor.view(out_shape).transpose(1, 2), - hk_tensor.view(out_shape).transpose(1, 2), - hv_tensor.view(out_shape).transpose(1, 2), - /*dropout_p=*/0.0, - /*is_causal=*/false, - /*return_debug_mask=*/false, - scale); - at::Tensor ref_attn = shardTensor( - std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); - + at::Tensor in_tensor = at::randn({m, n}, tensor_options); + at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; testValidate( executor_cache.fusion(), - {nvf_out}, - {sharded_hq, sharded_hk, sharded_hv}, - {ref_attn}, + {out_tensor}, + {in_tensor}, + {in_tensor.view({m * n})}, __LINE__, __FILE__); } + } // namespace nvfuser From 41c6bcb739665e84708432b06ec6993a73341e93 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 26 Feb 2025 13:46:29 -0800 Subject: [PATCH 04/70] reshape-permute-sdpa-reshape block --- csrc/scheduler/utils.cpp | 3 - tests/cpp/test_multidevice_sharding.cpp | 183 ++++++++++++++++++++++-- 2 files changed, 168 insertions(+), 18 deletions(-) diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 8b99432dc51..dd54f43a5d0 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2339,9 +2339,6 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { transformPropagateToAllFrom(tv, (int64_t)old2new.size()); parallelizeAllLike(tv, (int64_t)old2new.size(), {}, {ParallelType::DIDx}); } - for (auto tv : fusion->allTvs()) { - debug() << 
tv->toString() << std::endl; - } } bool isFastestDimReduction(TensorView* tv) { diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 35bf6f0d810..849c7a229a6 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace nvfuser { @@ -860,33 +861,44 @@ TEST_F(MultiDeviceTest, TransformPropagator) { FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t m = 5, n = 7; + const int64_t b = 2, s = 3, h = 8, e = 4; - TensorView* in = makeContigConcreteTensor({d * m * n}); - TensorView* out = reshape(in, {d * m * n}, {d * m, n}); - TensorView* add_out = add(out, IrBuilder::create(1.0)); + std::vector in_shape = {b, s, d * h * e}; + std::vector out_shape = {b, s, d * h, e}; + TensorView* in = makeContigConcreteTensor(in_shape); + TensorView* out = reshape(in, in_shape, out_shape); + // TensorView* add_out = add(out, IrBuilder::create(1.0)); fusion->addInput(in); - fusion->addOutput(add_out); + fusion->addOutput(out); auto mesh = DeviceMesh::createForNumDevices(d); - for (auto* tv : {in, out, add_out}) { - tv->setDeviceMesh(mesh); - tv->split(0, d, /*inner_split=*/false); - tv->axis(0)->parallelize(ParallelType::DIDx); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } + + in->setDeviceMesh(mesh); + in->split(-1, d, /*inner_split=*/false); + in->axis(-2)->parallelize(ParallelType::DIDx); + reorderDIDToFront(in); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->setDeviceMesh(mesh); + out->split(-2, d, /*inner_split=*/false); + out->axis(-3)->parallelize(ParallelType::DIDx); + reorderDIDToFront(out); + out->setAllocationDomain(out->getLoopDomain(), true); FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor in_tensor = at::randn({m * n}, tensor_options); - at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; + at::Tensor in_tensor = 
at::randn(in_shape, tensor_options); + at::Tensor sharded_in = shardTensor(in_tensor, -1, mesh); + + at::Tensor out_tensor = executor_cache.runFusionWithInputs({sharded_in})[0]; testValidate( executor_cache.fusion(), {out_tensor}, - {in_tensor}, - {in_tensor.view({m, n}) + 1.0}, + {sharded_in}, + {sharded_in.view({b, s, h, e})}, __LINE__, __FILE__); + } namespace { @@ -990,4 +1002,145 @@ TEST_F(MultiDeviceTest, TransformerFwd) { __FILE__); } +TEST_F(MultiDeviceTest, TransformerFwd) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 1, s = 3, h = 8, e = 16; + auto mesh = DeviceMesh::createForNumDevices(d); + + std::vector in_shape = {b, s, d*h*e}; + std::vector out_shape = {b, s, d*h, e}; + TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); + TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); + TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); + + TensorView* q = reshape(hq, in_shape, out_shape); + TensorView* q_permuted = permute(q, {0, 2, 1, 3}); + TensorView* k = reshape(hk, in_shape, out_shape); + TensorView* k_permuted = permute(k, {0, 2, 1, 3}); + TensorView* v = reshape(hv, in_shape, out_shape); + TensorView* v_permuted = permute(v, {0, 2, 1, 3}); + + SdpfaFwdResult sdpa_out = sdpfa_fwd( + q_permuted, + k_permuted, + v_permuted, + /*dropout_p=*/IrBuilder::create(0.0), + /*is_causal=*/IrBuilder::create(false), + /*scale=*/nullptr); + + TensorView* attn = sdpa_out.output; + TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); + TensorView* out = reshape(attn_permute, out_shape, in_shape); + + fusion->addInput(hq); + fusion->addInput(hk); + fusion->addInput(hv); + fusion->addOutput(out); + + // Shard input tensors + for (auto* tv : {hq, hk, hv}) { + tv->setDeviceMesh(mesh); + tv->split(-1, d, /*inner_split=*/false); + tv->axis(-2)->parallelize(ParallelType::DIDx); + reorderDIDToFront(tv); + } + + // Emulate what we will eventually do in the 
pre-segmentation pass + for (Expr* expr: fusion->exprs()) { + debug() << "expr: " << expr->toString() << std::endl; + debug() << "Before: " << std::endl; + for (auto out: expr->outputs()) { + auto tv = out->as(); + debug() << "output: " << tv->toString() << std::endl; + debug() << tv->domain()->toString(0, false) << std::endl; + } + + if (expr->isA()) { + // TransformPropagator cannot be directly used. It raises an error for conflicting transformations from root domain to logical domain. + // Instead, we manually find the reshaped iterdomain and outer split DID. + // This might have to be extended further. + TensorView* reshaped_tv = expr->as()->out(); + auto transform_exprs = StmtSort::getExprsBetween( + {reshaped_tv->getMaybeRootDomain().begin(), reshaped_tv->getMaybeRootDomain().end()}, + {reshaped_tv->getLogicalDomain().begin(), reshaped_tv->getLogicalDomain().end()}); + NVF_CHECK(transform_exprs.size() == 1); + NVF_CHECK(transform_exprs.at(0)->isA() || transform_exprs.at(0)->isA()); + + IterDomain* sharded_id = transform_exprs.at(0)->isA() ? transform_exprs.at(0)->as()->outer() : transform_exprs.at(0)->as()->out(); + + int64_t sharded_axis = -1; + for (const auto i : c10::irange(reshaped_tv->nDims())) { + if (reshaped_tv->axis(i) == sharded_id) { + sharded_axis = i; + break; + } + } + NVF_CHECK(sharded_axis != -1); + reshaped_tv->split(sharded_axis, d, /*inner_split=*/false); + reshaped_tv->axis(sharded_axis)->parallelize(ParallelType::DIDx); + reorderDIDToFront(reshaped_tv); + } + else { + std::vector output_tvs; + for (auto output : expr->outputs()) { + output_tvs.push_back(output->as()); + } + + TransformPropagator propagator_c2p(expr->input(0)->as()); + + // Note: We will finally propagate from each input iteratively. 
+ SetSelector selector(std::unordered_set(output_tvs.begin(), output_tvs.end())); + MaxLogicalDomainInfoSpanningTree(expr->input(0)->as(), &selector).traverse(&propagator_c2p); + scheduler_utils::parallelizeAllLike( + expr->input(0)->as(), + /*pos=*/-1, + /*selected_tv=*/output_tvs); + } + debug() << "After: " << std::endl; + for (auto out: expr->outputs()) { + auto tv = out->as(); + debug() << "output: " << tv->toString() << std::endl; + debug() << tv->domain()->toString(0, false) << std::endl; + } + + } + for (auto tv: fusion->allTvs()) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor hq_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); + at::Tensor hk_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); + at::Tensor hv_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); + + at::Tensor sharded_hq = shardTensor(hq_tensor, -1, mesh); + at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); + at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); + + at::Tensor nvf_out = executor_cache.runFusionWithInputs( + {sharded_hq, + sharded_hk, + sharded_hv})[0]; + + double scale = 1.0 / std::sqrt(e); + auto reference_out = at::_scaled_dot_product_flash_attention( + hq_tensor.view(out_shape).transpose(1, 2), + hk_tensor.view(out_shape).transpose(1, 2), + hv_tensor.view(out_shape).transpose(1, 2), + /*dropout_p=*/0.0, + /*is_causal=*/false, + /*return_debug_mask=*/false, + /*scale=*/scale); + at::Tensor ref_attn = shardTensor(std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); + + testValidate( + executor_cache.fusion(), + {nvf_out}, + {sharded_hq, sharded_hk, sharded_hv}, + {ref_attn}, + __LINE__, __FILE__); +} } // namespace nvfuser From a238bab9b63cd17267c56f55dac5d901c1a188cd Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 26 Feb 2025 15:58:04 -0800 Subject: [PATCH 05/70] clean test --- 
tests/cpp/test_multidevice_sharding.cpp | 207 ++++++++++++++---------- 1 file changed, 122 insertions(+), 85 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 849c7a229a6..772ea2d8a8b 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -746,7 +746,7 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) { __FILE__); } -TEST_F(MultiDeviceTest, LoopShardingWithSplitReshape) { +TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -856,35 +856,28 @@ TEST_F(MultiDeviceTest, LoopShardedSplitReshapeIds) { __FILE__); } -TEST_F(MultiDeviceTest, TransformPropagator) { +TEST_F(MultiDeviceTest, LoopShardedSplitReshapeIds) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const int d = communicator_->size(); const int64_t b = 2, s = 3, h = 8, e = 4; - std::vector in_shape = {b, s, d * h * e}; - std::vector out_shape = {b, s, d * h, e}; - TensorView* in = makeContigConcreteTensor(in_shape); - TensorView* out = reshape(in, in_shape, out_shape); - // TensorView* add_out = add(out, IrBuilder::create(1.0)); + TensorView* tv0 = makeContigConcreteTensor({b, s, d * h * e}); + TensorView* tv1 = reshape(tv0, {b, s, d * h * e}, {b, s, d * h, e}); - fusion->addInput(in); - fusion->addOutput(out); + fusion->addInput(tv0); + fusion->addOutput(tv1); auto mesh = DeviceMesh::createForNumDevices(d); - in->setDeviceMesh(mesh); - in->split(-1, d, /*inner_split=*/false); - in->axis(-2)->parallelize(ParallelType::DIDx); - reorderDIDToFront(in); - in->setAllocationDomain(in->getLoopDomain(), true); + tv0->setDeviceMesh(mesh); + tv0->split(-1, d, /*inner_split=*/false); + tv0->axis(-2)->parallelize(ParallelType::DIDx); - out->setDeviceMesh(mesh); - out->split(-2, d, /*inner_split=*/false); - out->axis(-3)->parallelize(ParallelType::DIDx); - reorderDIDToFront(out); - 
out->setAllocationDomain(out->getLoopDomain(), true); + tv1->setDeviceMesh(mesh); + tv1->split(-2, d, /*inner_split=*/false); + tv1->axis(-3)->parallelize(ParallelType::DIDx); FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor in_tensor = at::randn(in_shape, tensor_options); @@ -991,27 +984,129 @@ TEST_F(MultiDeviceTest, TransformerFwd) { } FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor in_tensor = at::randn({m, n}, tensor_options); - at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; + at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options); + at::Tensor sharded_inp= shardTensor(inp, tv0); + + at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0]; testValidate( executor_cache.fusion(), - {out_tensor}, - {in_tensor}, - {in_tensor.view({m * n})}, + {nvf_out}, + {sharded_inp}, + {sharded_inp.view({b, s, h, e})}, __LINE__, __FILE__); + +} + +TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 8, e = 4; + + TensorView* tv0 = makeContigConcreteTensor({b, s, d * h, e}); + TensorView* tv1 = reshape(tv0, {b, s, d * h, e}, {b, s, d * h * e}); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + + auto mesh = DeviceMesh::createForNumDevices(d); + tv0->setDeviceMesh(mesh); + tv0->split(-2, d, /*inner_split=*/false); + tv0->axis(-3)->parallelize(ParallelType::DIDx); + + tv1->setDeviceMesh(mesh); + tv1->split(-1, d, /*inner_split=*/false); + tv1->axis(-2)->parallelize(ParallelType::DIDx); + + for (auto* tv : {tv0, tv1}) { + reorderDIDToFront(tv); + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp = at::randn({b, s, d * h, e}, tensor_options); + at::Tensor sharded_inp = shardTensor(inp, tv0); + at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0]; 
+ testValidate( + executor_cache.fusion(), + {nvf_out}, + {sharded_inp}, + {inp.view({b, s, h * e})}, + __LINE__, + __FILE__); +} + +namespace { +// This is a simplified version of what we will eventually do in the pre-segmentation pass +void propagateShardings(Fusion* fusion, int64_t num_devices) { + for (Expr* expr: fusion->exprs()) { + if (expr->isA()) { + NVF_THROW("SliceOp is not currently supported"); + } + + if (expr->isA()) { + // TransformPropagator cannot be directly used. + // It raises an error for conflicting transformations from root domain to logical domain. + // Instead, we manually find the reshaped iterdomain and outer split DID. + // This might have to be extended further in the presegmentation pass. + TensorView* reshaped_tv = expr->as()->out(); + auto transform_exprs = StmtSort::getExprsBetween( + {reshaped_tv->getMaybeRootDomain().begin(), reshaped_tv->getMaybeRootDomain().end()}, + {reshaped_tv->getLogicalDomain().begin(), reshaped_tv->getLogicalDomain().end()}); + NVF_CHECK(transform_exprs.size() == 1); + auto transform = transform_exprs[0]; + NVF_CHECK(transform->isA() || transform->isA()); + + // Get the sharded ID and its axis position + IterDomain* sharded_id = transform->isA() ? 
+ transform->as()->outer() : + transform->as()->out(); + + auto sharded_it = std::find(reshaped_tv->getLoopDomain().begin(), reshaped_tv->getLoopDomain().end(), sharded_id); + int64_t sharded_axis = std::distance(reshaped_tv->getLoopDomain().begin(), sharded_it); + + // Apply sharding to the reshaped tensor + reshaped_tv->split(sharded_axis, num_devices, false); + reshaped_tv->axis(sharded_axis)->parallelize(ParallelType::DIDx); + reorderDIDToFront(reshaped_tv); + continue; + } + + // For other ops, propagate sharding from input to outputs + auto input_tv = expr->input(0)->as(); + std::vector output_tvs; + for (auto output : expr->outputs()) { + output_tvs.push_back(output->as()); + } + + TransformPropagator propagator(input_tv); + + // Note: We will finally propagate from each input iteratively. + SetSelector selector(std::unordered_set(output_tvs.begin(), output_tvs.end())); + MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator); + scheduler_utils::parallelizeAllLike( + input_tv, + /*pos=*/-1, + /*selected_tv=*/output_tvs); + } } + +} // namespace TEST_F(MultiDeviceTest, TransformerFwd) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 1, s = 3, h = 8, e = 16; + const int64_t b = 2, s = 3, h = 8, e = 16; auto mesh = DeviceMesh::createForNumDevices(d); std::vector in_shape = {b, s, d*h*e}; std::vector out_shape = {b, s, d*h, e}; + + // The transformer block produces hq/hk/hv after slicing the MHA linear output. 
TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); @@ -1047,66 +1142,8 @@ TEST_F(MultiDeviceTest, TransformerFwd) { tv->axis(-2)->parallelize(ParallelType::DIDx); reorderDIDToFront(tv); } + propagateShardings(fusion.get(), d); - // Emulate what we will eventually do in the pre-segmentation pass - for (Expr* expr: fusion->exprs()) { - debug() << "expr: " << expr->toString() << std::endl; - debug() << "Before: " << std::endl; - for (auto out: expr->outputs()) { - auto tv = out->as(); - debug() << "output: " << tv->toString() << std::endl; - debug() << tv->domain()->toString(0, false) << std::endl; - } - - if (expr->isA()) { - // TransformPropagator cannot be directly used. It raises an error for conflicting transformations from root domain to logical domain. - // Instead, we manually find the reshaped iterdomain and outer split DID. - // This might have to be extended further. - TensorView* reshaped_tv = expr->as()->out(); - auto transform_exprs = StmtSort::getExprsBetween( - {reshaped_tv->getMaybeRootDomain().begin(), reshaped_tv->getMaybeRootDomain().end()}, - {reshaped_tv->getLogicalDomain().begin(), reshaped_tv->getLogicalDomain().end()}); - NVF_CHECK(transform_exprs.size() == 1); - NVF_CHECK(transform_exprs.at(0)->isA() || transform_exprs.at(0)->isA()); - - IterDomain* sharded_id = transform_exprs.at(0)->isA() ? 
transform_exprs.at(0)->as()->outer() : transform_exprs.at(0)->as()->out(); - - int64_t sharded_axis = -1; - for (const auto i : c10::irange(reshaped_tv->nDims())) { - if (reshaped_tv->axis(i) == sharded_id) { - sharded_axis = i; - break; - } - } - NVF_CHECK(sharded_axis != -1); - reshaped_tv->split(sharded_axis, d, /*inner_split=*/false); - reshaped_tv->axis(sharded_axis)->parallelize(ParallelType::DIDx); - reorderDIDToFront(reshaped_tv); - } - else { - std::vector output_tvs; - for (auto output : expr->outputs()) { - output_tvs.push_back(output->as()); - } - - TransformPropagator propagator_c2p(expr->input(0)->as()); - - // Note: We will finally propagate from each input iteratively. - SetSelector selector(std::unordered_set(output_tvs.begin(), output_tvs.end())); - MaxLogicalDomainInfoSpanningTree(expr->input(0)->as(), &selector).traverse(&propagator_c2p); - scheduler_utils::parallelizeAllLike( - expr->input(0)->as(), - /*pos=*/-1, - /*selected_tv=*/output_tvs); - } - debug() << "After: " << std::endl; - for (auto out: expr->outputs()) { - auto tv = out->as(); - debug() << "output: " << tv->toString() << std::endl; - debug() << tv->domain()->toString(0, false) << std::endl; - } - - } for (auto tv: fusion->allTvs()) { tv->setAllocationDomain(tv->getLoopDomain(), true); } @@ -1143,4 +1180,4 @@ TEST_F(MultiDeviceTest, TransformerFwd) { {ref_attn}, __LINE__, __FILE__); } -} // namespace nvfuser +} // namespace From 02d32cf425289f5ffe80a994506616dddf990637 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 4 Mar 2025 14:27:25 -0800 Subject: [PATCH 06/70] add war to identify view op as not resharding, comment --- tests/cpp/test_multidevice_sharding.cpp | 103 +++++++++++++----------- 1 file changed, 57 insertions(+), 46 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 772ea2d8a8b..ccc3cd0e40e 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ 
b/tests/cpp/test_multidevice_sharding.cpp @@ -9,13 +9,13 @@ #include #include +#include #include #include #include #include #include #include -#include namespace nvfuser { @@ -985,17 +985,16 @@ TEST_F(MultiDeviceTest, TransformerFwd) { FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options); - at::Tensor sharded_inp= shardTensor(inp, tv0); + at::Tensor sharded_inp = shardTensor(inp, tv0); at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0]; testValidate( executor_cache.fusion(), {nvf_out}, - {sharded_inp}, + {sharded_inp}, {sharded_inp.view({b, s, h, e})}, __LINE__, __FILE__); - } TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { @@ -1039,41 +1038,51 @@ TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { } namespace { -// This is a simplified version of what we will eventually do in the pre-segmentation pass +// This is a simplified version of what we will eventually do in the +// pre-segmentation pass void propagateShardings(Fusion* fusion, int64_t num_devices) { - for (Expr* expr: fusion->exprs()) { + for (Expr* expr : fusion->exprs()) { if (expr->isA()) { NVF_THROW("SliceOp is not currently supported"); } if (expr->isA()) { - // TransformPropagator cannot be directly used. - // It raises an error for conflicting transformations from root domain to logical domain. - // Instead, we manually find the reshaped iterdomain and outer split DID. - // This might have to be extended further in the presegmentation pass. + // TransformPropagator cannot be directly used. + // It raises an error for conflicting transformations from root domain to + // logical domain. Instead, we manually find the reshaped iterdomain and + // outer split DID. This might have to be extended further in the + // presegmentation pass. + // Note: For simplicity, this assumes that the sharding is on reshaped IDs. 
It is possible that the non-reshaped IDs are sharded, in which case we can use the TransformPropagator. TensorView* reshaped_tv = expr->as()->out(); auto transform_exprs = StmtSort::getExprsBetween( - {reshaped_tv->getMaybeRootDomain().begin(), reshaped_tv->getMaybeRootDomain().end()}, - {reshaped_tv->getLogicalDomain().begin(), reshaped_tv->getLogicalDomain().end()}); + {reshaped_tv->getMaybeRootDomain().begin(), + reshaped_tv->getMaybeRootDomain().end()}, + {reshaped_tv->getLogicalDomain().begin(), + reshaped_tv->getLogicalDomain().end()}); NVF_CHECK(transform_exprs.size() == 1); auto transform = transform_exprs[0]; NVF_CHECK(transform->isA() || transform->isA()); - // Get the sharded ID and its axis position - IterDomain* sharded_id = transform->isA() ? - transform->as()->outer() : - transform->as()->out(); + // Get the reshaped ID (outer ID for split reshape). + // This is the ID that will be parallelized. + IterDomain* reshaped_id = transform->isA() + ? transform->as()->outer() + : transform->as()->out(); - auto sharded_it = std::find(reshaped_tv->getLoopDomain().begin(), reshaped_tv->getLoopDomain().end(), sharded_id); - int64_t sharded_axis = std::distance(reshaped_tv->getLoopDomain().begin(), sharded_it); + auto reshaped_it = std::find( + reshaped_tv->getLoopDomain().begin(), + reshaped_tv->getLoopDomain().end(), + reshaped_id); + int64_t reshaped_axis = + std::distance(reshaped_tv->getLoopDomain().begin(), reshaped_it); // Apply sharding to the reshaped tensor - reshaped_tv->split(sharded_axis, num_devices, false); - reshaped_tv->axis(sharded_axis)->parallelize(ParallelType::DIDx); + reshaped_tv->split(reshaped_axis, num_devices, false); + reshaped_tv->axis(reshaped_axis)->parallelize(ParallelType::DIDx); reorderDIDToFront(reshaped_tv); continue; } - + // For other ops, propagate sharding from input to outputs auto input_tv = expr->input(0)->as(); std::vector output_tvs; @@ -1082,9 +1091,10 @@ void propagateShardings(Fusion* fusion, int64_t num_devices) { 
} TransformPropagator propagator(input_tv); - + // Note: We will finally propagate from each input iteratively. - SetSelector selector(std::unordered_set(output_tvs.begin(), output_tvs.end())); + SetSelector selector( + std::unordered_set(output_tvs.begin(), output_tvs.end())); MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator); scheduler_utils::parallelizeAllLike( input_tv, @@ -1092,7 +1102,7 @@ void propagateShardings(Fusion* fusion, int64_t num_devices) { /*selected_tv=*/output_tvs); } } - + } // namespace TEST_F(MultiDeviceTest, TransformerFwd) { @@ -1103,10 +1113,11 @@ TEST_F(MultiDeviceTest, TransformerFwd) { const int64_t b = 2, s = 3, h = 8, e = 16; auto mesh = DeviceMesh::createForNumDevices(d); - std::vector in_shape = {b, s, d*h*e}; - std::vector out_shape = {b, s, d*h, e}; + std::vector in_shape = {b, s, d * h * e}; + std::vector out_shape = {b, s, d * h, e}; - // The transformer block produces hq/hk/hv after slicing the MHA linear output. + // The transformer block produces hq/hk/hv after slicing the MHA linear + // output. 
TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); @@ -1119,12 +1130,12 @@ TEST_F(MultiDeviceTest, TransformerFwd) { TensorView* v_permuted = permute(v, {0, 2, 1, 3}); SdpfaFwdResult sdpa_out = sdpfa_fwd( - q_permuted, - k_permuted, - v_permuted, - /*dropout_p=*/IrBuilder::create(0.0), - /*is_causal=*/IrBuilder::create(false), - /*scale=*/nullptr); + q_permuted, + k_permuted, + v_permuted, + /*dropout_p=*/IrBuilder::create(0.0), + /*is_causal=*/IrBuilder::create(false), + /*scale=*/nullptr); TensorView* attn = sdpa_out.output; TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); @@ -1144,7 +1155,7 @@ TEST_F(MultiDeviceTest, TransformerFwd) { } propagateShardings(fusion.get(), d); - for (auto tv: fusion->allTvs()) { + for (auto tv : fusion->allTvs()) { tv->setAllocationDomain(tv->getLoopDomain(), true); } @@ -1152,32 +1163,32 @@ TEST_F(MultiDeviceTest, TransformerFwd) { at::Tensor hq_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); at::Tensor hk_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); at::Tensor hv_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - + at::Tensor sharded_hq = shardTensor(hq_tensor, -1, mesh); at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - + at::Tensor nvf_out = executor_cache.runFusionWithInputs( - {sharded_hq, - sharded_hk, - sharded_hv})[0]; + {sharded_hq, sharded_hk, sharded_hv})[0]; double scale = 1.0 / std::sqrt(e); auto reference_out = at::_scaled_dot_product_flash_attention( hq_tensor.view(out_shape).transpose(1, 2), hk_tensor.view(out_shape).transpose(1, 2), hv_tensor.view(out_shape).transpose(1, 2), - /*dropout_p=*/0.0, - /*is_causal=*/false, - /*return_debug_mask=*/false, - /*scale=*/scale); - at::Tensor ref_attn = shardTensor(std::get<0>(reference_out).transpose(1, 
2).view(in_shape), -1, mesh); + /*dropout_p=*/0.0, + /*is_causal=*/false, + /*return_debug_mask=*/false, + /*scale=*/scale); + at::Tensor ref_attn = shardTensor( + std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); testValidate( executor_cache.fusion(), {nvf_out}, {sharded_hq, sharded_hk, sharded_hv}, {ref_attn}, - __LINE__, __FILE__); + __LINE__, + __FILE__); } -} // namespace +} // namespace nvfuser From f2aff84a87d71bdb1e3e2dc1beb41c01e6dae19c Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 4 Mar 2025 15:17:03 -0800 Subject: [PATCH 07/70] add parallelize input flag --- csrc/scheduler/utils.cpp | 10 +++++++++- tests/cpp/test_multidevice_sharding.cpp | 26 +++++++++++++------------ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index dd54f43a5d0..319694f510d 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2337,7 +2337,15 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { tv->reorder(old2new); //! Propagate current transformations on from_tv to all graphs transformPropagateToAllFrom(tv, (int64_t)old2new.size()); - parallelizeAllLike(tv, (int64_t)old2new.size(), {}, {ParallelType::DIDx}); + // Propgating the transforms will not replay the DIDx parallelization, so we + // need to do it manually here. 
+ parallelizeAllLike( + tv, + /*pos=*/(int64_t)old2new.size(), + /*selected_tvs=*/{}, + /*selected_parallel_types=*/{ParallelType::DIDx}, + /*propagate_padding=*/false, + /*parallelize_inputs=*/true); } } diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index ccc3cd0e40e..64e9b405fb7 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -798,12 +798,9 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { tv1->setAllocationDomain(tv1->getLoopDomain(), true); FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options); - at::Tensor sharded_inp = shardTensor(inp, tv0); - - at::Tensor nvf_out = - executor_cache.runFusionWithInputs({sharded_inp})[0].as(); - + at::Tensor in_tensor = at::randn({b, s, h * e}, tensor_options); + at::Tensor out_tensor = + executor_cache.runFusionWithInputs({in_tensor})[0].as(); testValidate( executor_cache.fusion(), {nvf_out}, @@ -1026,13 +1023,14 @@ TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor inp = at::randn({b, s, d * h, e}, tensor_options); - at::Tensor sharded_inp = shardTensor(inp, tv0); - at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0]; + at::Tensor sharded_inp = shardTensor(inp, -2, mesh); + at::Tensor nvf_out = + executor_cache.runFusionWithInputs({sharded_inp})[0].as(); testValidate( executor_cache.fusion(), {nvf_out}, {sharded_inp}, - {inp.view({b, s, h * e})}, + {sharded_inp.view({b, s, h * e})}, __LINE__, __FILE__); } @@ -1052,7 +1050,9 @@ void propagateShardings(Fusion* fusion, int64_t num_devices) { // logical domain. Instead, we manually find the reshaped iterdomain and // outer split DID. This might have to be extended further in the // presegmentation pass. - // Note: For simplicity, this assumes that the sharding is on reshaped IDs. 
It is possible that the non-reshaped IDs are sharded, in which case we can use the TransformPropagator. + // Note: For simplicity, this assumes that the sharding is on reshaped + // IDs. It is possible that the non-reshaped IDs are sharded, in which + // case we can use the TransformPropagator. TensorView* reshaped_tv = expr->as()->out(); auto transform_exprs = StmtSort::getExprsBetween( {reshaped_tv->getMaybeRootDomain().begin(), @@ -1168,8 +1168,10 @@ TEST_F(MultiDeviceTest, TransformerFwd) { at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - at::Tensor nvf_out = executor_cache.runFusionWithInputs( - {sharded_hq, sharded_hk, sharded_hv})[0]; + at::Tensor nvf_out = + executor_cache + .runFusionWithInputs({sharded_hq, sharded_hk, sharded_hv})[0] + .as(); double scale = 1.0 / std::sqrt(e); auto reference_out = at::_scaled_dot_product_flash_attention( From aa7aada2d072d96d9bf791606e6b9dc9059e07bb Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 4 Mar 2025 15:19:49 -0800 Subject: [PATCH 08/70] rm import --- tests/cpp/test_multidevice_sharding.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 64e9b405fb7..3ab316582a9 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -9,7 +9,6 @@ #include #include -#include #include #include #include From 92f6ff9cda7b2df85134cf568d7bc086708dc72d Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 5 Mar 2025 22:51:05 -0800 Subject: [PATCH 09/70] fix rebase --- tests/cpp/test_multidevice_sharding.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 3ab316582a9..e2e1b1197b8 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -751,7 
+751,6 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { const int d = communicator_->size(); const int64_t b = 2, s = 2, h = 4, e = 3; - const int64_t b = 2, s = 2, h = 4, e = 3; TensorView* tv0 = makeContigConcreteTensor( {b, s, d * h * e}); // in: loop domain: {b, s, d*h*e} @@ -773,7 +772,6 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator_c2p); // in: loop domain: {b, s, d*h, e} after transform propagation - // Loop split and parallelize input tv0->setDeviceMesh(mesh); tv1->setDeviceMesh(mesh); @@ -797,9 +795,9 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { tv1->setAllocationDomain(tv1->getLoopDomain(), true); FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor in_tensor = at::randn({b, s, h * e}, tensor_options); - at::Tensor out_tensor = - executor_cache.runFusionWithInputs({in_tensor})[0].as(); + at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options); + at::Tensor sharded_inp = shardTensor(inp, tv0); + testValidate( executor_cache.fusion(), {nvf_out}, From 3f7f3a5f0a7d24948bbccc32e1abc05f27b27d87 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 5 Mar 2025 22:53:44 -0800 Subject: [PATCH 10/70] fix rebase --- tests/cpp/test_multidevice_sharding.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index e2e1b1197b8..fce79e42865 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -798,6 +798,9 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options); at::Tensor sharded_inp = shardTensor(inp, tv0); + at::Tensor nvf_out = + executor_cache.runFusionWithInputs({sharded_inp})[0].as(); + testValidate( executor_cache.fusion(), {nvf_out}, From 02d2ce359f6605efe04ff9df4d2feb1e77442c0f Mon Sep 17 00:00:00 2001 From: root 
<26priya11@gmail.com> Date: Fri, 7 Mar 2025 12:17:11 -0800 Subject: [PATCH 11/70] return did pos from reorder --- csrc/multidevice/utils.cpp | 3 ++- csrc/multidevice/utils.h | 2 +- csrc/preseg_passes/propagate_shardings.cpp | 8 ++++++++ tests/python/test_multidevice.py | 4 ++-- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 4be6b63be65..fbe4ab1566e 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -712,7 +712,7 @@ std::set involvedDevices(Expr* expr) { return ret; } -void reorderDIDToFront(TensorView* tv) { +int64_t reorderDIDToFront(TensorView* tv) { // old position to new position std::unordered_map order_map; int64_t current_pos = 0; @@ -725,6 +725,7 @@ void reorderDIDToFront(TensorView* tv) { } tv->reorder(order_map); + return current_pos; } std::unordered_set getTvsWithDifferentSharding( diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 34c510ccb2e..9ae692e3589 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -114,7 +114,7 @@ at::Tensor shardTensor( DeviceIdxType device_id); // Reorders a TensorView so that the DID parallelized axis are in front. 
-void reorderDIDToFront(TensorView*); +int64_t reorderDIDToFront(TensorView*); // Given a TensorView and the shape of a sharded tensor of which certain // dimensions are partially allocated, returns the global shape that'll be used diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 69ba5983060..6edc508381b 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -13,6 +13,7 @@ #include #include #include +#include namespace nvfuser::preseg_passes { @@ -95,6 +96,13 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { outputs_without_mesh.push_back(tv); } } + + int64_t did_pos = reorderDIDToFront(ref_input); + TransformPropagator propagator(ref_input, did_pos); + SetSelector selector( + {outputs_without_mesh.begin(), outputs_without_mesh.end()}); + MaxLogicalDomainInfoSpanningTree(ref_input, &selector) + .traverse(&propagator); shardAllLike(ref_input, outputs_without_mesh); } diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py index 237aa9f6aa6..8a7f7dbcef7 100644 --- a/tests/python/test_multidevice.py +++ b/tests/python/test_multidevice.py @@ -1112,7 +1112,7 @@ def test_transformer_forward(multidevice_test, benchmark): # Benchmark and profile. The profile can be collected and displayed using # `nsys`. See instructions in test_transformer_engine.py. 
- benchmark.pedantic(benchmark_fn, rounds=5) + # benchmark.pedantic(benchmark_fn, rounds=5) # All tensors are replicated to all devices at this moment; future PRs will try @@ -1691,4 +1691,4 @@ def test_transformer_backward(multidevice_test, benchmark): _assert_shape_dtype(layernorm0_weight_grad, [e], torch.bfloat16) _assert_shape_dtype(inp_grad, [b, s, e], torch.bfloat16) - benchmark.pedantic(benchmark_fn, rounds=5) + # benchmark.pedantic(benchmark_fn, rounds=5) From db7ac1e26df81e3bdf0ba57dc2b13c80aaa63ff8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 11 Mar 2025 17:30:50 -0700 Subject: [PATCH 12/70] reshape sharding for transformer case --- csrc/multidevice/utils.h | 4 + csrc/preseg_passes/propagate_shardings.cpp | 94 ++++++++++++++++++++++ tests/cpp/test_multidevice_sharding.cpp | 4 +- 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 9ae692e3589..fb2777b1ba8 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -39,6 +39,10 @@ bool isSharded(const TensorView*); // Returns number of device dimensions in a TensorView's loop domain. 
int64_t numDeviceDims(const TensorView*); +std::vector getInputsInTargetDomain( + IterDomain* loop_id, + const std::vector& target_domain); + // Returns the subset of tvs which elements have the different multi-device // sharding as ref std::unordered_set getTvsWithDifferentSharding( diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 6edc508381b..4cda6caa083 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -47,6 +47,32 @@ void validateMeshes(Fusion* fusion) { tv_without_mesh, " not."); } + +std::pair, std::unordered_set> getReshapedIds(ViewOp* view_op, const std::unordered_map& c2p) { + std::unordered_set p_reshaped_ids; // Reshaped logical IDs + std::unordered_set c_reshaped_ids; // Reshaped root IDs + + TensorView* consumer = view_op->out(); + std::vector c_root_domain = consumer->getMaybeRootDomain(); + + for (auto id : consumer->getLogicalDomain()) { + if (id->isRFactorProduct() && id->definition() && + !id->definition()->isA()) { + auto root_ids = getInputsInTargetDomain(id, c_root_domain); + for (auto root_id : root_ids) { + c_reshaped_ids.insert(root_id); + } + } + } + // Get the logical iterdomains in the producer that are reshaped. + std::unordered_set p_reshape_ids; + for (auto id : c_reshaped_ids) { + if (auto p_id = c2p.find(id); p_id != c2p.end()) { + p_reshaped_ids.insert(p_id->second); + } + } + return std::make_pair(p_reshaped_ids, c_reshaped_ids); +} } // namespace void PropagateShardingsPass::runPass(Fusion* fusion) { @@ -97,12 +123,80 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } } + if (outputs_without_mesh.empty()) { + continue; + } + + // This restricts the transform propagation to the DID axis. int64_t did_pos = reorderDIDToFront(ref_input); + + if (ViewOp* view_op = dynamic_cast(expr)) { + // This implementation asserts that only one sharding is applied on the reshaped ids. + // Inner split is not supported. 
+ // The cases are: + // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in consumer. + // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in consumer. + // An improvement is to support mult-levels of sharding (not a real case in practice) if they are all outer splits. + // For example: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] + + TensorView* producer = view_op->in(); + TensorView* consumer = view_op->out(); + + const std::unordered_map& c2p = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + auto [p_reshaped_ids, c_reshaped_ids] = getReshapedIds(view_op, c2p); + + auto p_loop_domain = producer->getLoopDomain(); + auto c_loop_domain = consumer->getLoopDomain(); + + for (auto idx: c10::irange(did_pos)) { + auto p_transforms = DependencyCheck::getAllExprsBetween( + {p_reshaped_ids.begin(), p_reshaped_ids.end()}, {p_loop_domain.at(idx)}); + if (p_transforms.empty()) { + // Sharding is not on reshaped ids. We will use the TransformPropagator. + continue; + } + + NVF_ERROR(p_transforms.size() == 1 && p_transforms.back()->isA(), "Expected only a single DID split on reshaped ids."); + auto* p_did_split = p_transforms.front()->as(); + + auto reshape_transform = DependencyCheck::getAllExprsBetween( + {c_reshaped_ids.begin(), c_reshaped_ids.end()}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); + + NVF_ERROR((reshape_transform.size() == 1 && reshape_transform.front()->isOneOf()), "Expected a split or merge transform between root and logical reshaped ids."); + + if (reshape_transform.front()->isA()){ + // Check that the sharding is on the outer reshaped id. If it is on inner reshaped id (h/a for merge reshape), for non-resharding, the consumer should be inner split which is not supported. 
+ auto* outer_id = reshape_transform.front()->as()->outer(); + auto* producer_outer_id = c2p.at(outer_id); + NVF_ERROR(p_did_split->in() == producer_outer_id, "Expected the sharding to be on the outer reshaped id."); + } + + auto* c_sharded_id = reshape_transform.front()->isA() ? reshape_transform.front()->as()->outer() : reshape_transform.front()->as()->out(); + + int64_t sharded_axis = std::distance( + c_loop_domain.begin(), + std::find(c_loop_domain.begin(), + c_loop_domain.end(), + c_sharded_id)); + + Val* split_factor = p_did_split->factor(); + consumer->split(sharded_axis, split_factor, /*inner_split=*/false); + consumer->axis(sharded_axis)->parallelize(ParallelType::DIDx); + + // Move this did_pos to the end in producer to avoid using TransformPropagator on it. + // producer->reorder({{idx, -1}}); + } + did_pos = did_pos - 1; + } + + // Propagate the DID loop split to the outputs without mesh. TransformPropagator propagator(ref_input, did_pos); SetSelector selector( {outputs_without_mesh.begin(), outputs_without_mesh.end()}); MaxLogicalDomainInfoSpanningTree(ref_input, &selector) .traverse(&propagator); + + // Apply parallelization on the outputs without mesh. 
shardAllLike(ref_input, outputs_without_mesh); } diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index fce79e42865..de5fd04687b 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -1153,7 +1154,8 @@ TEST_F(MultiDeviceTest, TransformerFwd) { tv->axis(-2)->parallelize(ParallelType::DIDx); reorderDIDToFront(tv); } - propagateShardings(fusion.get(), d); + + preseg_passes::OptimizationPass::runPass(fusion.get()); for (auto tv : fusion->allTvs()) { tv->setAllocationDomain(tv->getLoopDomain(), true); From 4502436e176e27df2bbca15cce1bc4ade1c34ac8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 12 Mar 2025 18:33:40 -0700 Subject: [PATCH 13/70] add ordering of inputs, custom selector for directioned propagation, backward transform propagation --- csrc/multidevice/utils.cpp | 10 +- csrc/multidevice/utils.h | 2 +- csrc/preseg_passes/propagate_shardings.cpp | 289 ++++++++++++--------- tests/cpp/test_multidevice_sharding.cpp | 29 +++ 4 files changed, 211 insertions(+), 119 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index fbe4ab1566e..a2cfd3403d9 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -571,13 +571,19 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs) { +void shardAllLike(TensorView* ref, std::vector tvs, bool parallelize_inputs) { for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } + // TODO: If the tv already has a particular device parallel type, skip that. 
if (!tvs.empty()) { scheduler_utils::parallelizeAllLike( - ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); + ref, + /*pos=*/-1, + /*selected_tvs=*/tvs, + /*selected_parallel_types=*/{ParallelType::DIDx, ParallelType::Serial}, + /*propagate_padding=*/false, + /*parallelize_inputs=*/parallelize_inputs); } // parallelAllLke, tries to DID-parallelize diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index fb2777b1ba8..823eb29093e 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -62,7 +62,7 @@ bool haveDifferentShardings( bool isInnerResharding(Expr* expr); // Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs); +void shardAllLike(TensorView* ref, std::vector tvs, bool parallelize_inputs=false); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 4cda6caa083..a503284cdab 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -73,131 +73,174 @@ std::pair, std::unordered_set> getR } return std::make_pair(p_reshaped_ids, c_reshaped_ids); } -} // namespace -void PropagateShardingsPass::runPass(Fusion* fusion) { - auto num_device_parallel_dimensions = [](const TensorView* tv) -> int64_t { - return std::count_if( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - std::mem_fn(&IterDomain::isDeviceDim)); - }; +int64_t num_device_dims(TensorView* tv) { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + std::mem_fn(&IterDomain::isDeviceDim)); +} - const std::vector& exprs = fusion->exprs(); - for (Expr* expr : exprs) { - const auto& inputs = ir_utils::filterByType(expr->inputs()); - // Pick the "most parallel" input tensor as the reference. 
This is useful - // for propagating tensor parallelism from weights to MLP's intermediate - // tensors. For example, - // - // x: [b, s, h]; replicated. - // w0: [h, 4*h]; column-wise sharded. - // w1: [4*h, h]; row-wise sharded. - // y = matmul(x, w0) - // z = matmul(y, w1) - // - // With the above heuristic, `y` can be automatically sharded column-wise. - TensorView* ref_input = nullptr; - auto max_num_dids = std::numeric_limits::min(); - for (auto* input : inputs) { - if (!input->hasDeviceMesh()) { - continue; - } - int64_t num_dids = num_device_parallel_dimensions(input); - if (num_dids > max_num_dids) { - max_num_dids = num_dids; - ref_input = input; - } - } - if (ref_input == nullptr) { +// Order the inputs of the expression based on their priority. +// For linear op, we use weights and bias before input. +// For matmul op, we use weights before input. +// For other ops, we sort the inputs by the number of device dimensions in descending order. +std::vector getOrderedReferenceInputs(Expr* expr) { + const auto& inputs = ir_utils::filterByType(expr->inputs()); + if (LinearOp* linear_op = dynamic_cast(expr)) { + // Use weights and bias before input. + return {linear_op->inB(), linear_op->bias(), linear_op->inA()}; + } + + if (MatmulOp* matmul_op = dynamic_cast(expr)) { + // Use weights before input. 
+ return {matmul_op->inB(), matmul_op->inA()}; + } + + // Sort inputs by number of device dimensions in descending order + std::vector sorted_inputs(inputs.begin(), inputs.end()); + std::sort(sorted_inputs.begin(), sorted_inputs.end(), + [&](TensorView* a, TensorView* b) { + return num_device_dims(a) > num_device_dims(b); + }); + + return sorted_inputs; +} + +std::vector getOutputsWithoutMesh(Expr* expr) { + const auto& outputs = ir_utils::filterByType(expr->outputs()); + std::vector outputs_without_mesh; + std::copy_if( + outputs.begin(), + outputs.end(), + std::back_inserter(outputs_without_mesh), + [](TensorView* tv) { return !tv->hasDeviceMesh(); }); + return outputs_without_mesh; +} + +class PropagateShardingsSelector: public SetSelector { + private: + bool allow_c2p_; + bool allow_p2c_; + + public: + explicit PropagateShardingsSelector( + const std::unordered_set& selected_tvs, + bool allow_c2p = true, + bool allow_p2c = true) + : SetSelector(selected_tvs), + allow_c2p_(allow_c2p), + allow_p2c_(allow_p2c) {} + + bool allowC2P(TensorView* from, TensorView* to) override { + return allow_c2p_ && SetSelector::allowC2P(from, to); + } + + bool allowP2C(TensorView* from, TensorView* to) override { + return allow_p2c_ && SetSelector::allowP2C(from, to); + } +}; + +void handleViewOp(ViewOp* view_op, int64_t did_pos) { + // This implementation asserts that only one sharding is applied on the reshaped ids. + // Inner split is not supported. + // The cases are: + // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in consumer. + // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in consumer. + // An improvement is to support mult-levels of sharding (not a real case in practice) if they are all outer splits. 
+ // For example: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] + + TensorView* producer = view_op->in(); + TensorView* consumer = view_op->out(); + + const std::unordered_map& c2p = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + auto [p_reshaped_ids, c_reshaped_ids] = getReshapedIds(view_op, c2p); + + auto p_loop_domain = producer->getLoopDomain(); + auto c_loop_domain = consumer->getLoopDomain(); + + for (auto idx: c10::irange(did_pos)) { + auto p_transforms = DependencyCheck::getAllExprsBetween( + {p_reshaped_ids.begin(), p_reshaped_ids.end()}, {p_loop_domain.at(idx)}); + if (p_transforms.empty()) { + // Sharding is not on reshaped ids. We will use the TransformPropagator. continue; } - // Note: Tvs without a mesh are assumed to have no manual sharding - // annotation and are sharded like the first producer Tv. - const auto& outputs = ir_utils::filterByType(expr->outputs()); - std::vector outputs_without_mesh; - for (auto* tv : outputs) { - if (!tv->hasDeviceMesh()) { - outputs_without_mesh.push_back(tv); - } + NVF_ERROR(p_transforms.size() == 1 && p_transforms.back()->isA(), "Expected only a single DID split on reshaped ids."); + auto* p_did_split = p_transforms.front()->as(); + + auto reshape_transform = DependencyCheck::getAllExprsBetween( + {c_reshaped_ids.begin(), c_reshaped_ids.end()}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); + + NVF_ERROR((reshape_transform.size() == 1 && reshape_transform.front()->isOneOf()), "Expected a split or merge transform between root and logical reshaped ids."); + + if (reshape_transform.front()->isA()){ + // Check that the sharding is on the outer reshaped id. If it is on inner reshaped id (h/a for merge reshape), for non-resharding, the consumer should be inner split which is not supported. 
+ auto* outer_id = reshape_transform.front()->as()->outer(); + auto* producer_outer_id = c2p.at(outer_id); + NVF_ERROR(p_did_split->in() == producer_outer_id, "Expected the sharding to be on the outer reshaped id."); } + auto* c_sharded_id = reshape_transform.front()->isA() ? reshape_transform.front()->as()->outer() : reshape_transform.front()->as()->out(); + + int64_t sharded_axis = std::distance( + c_loop_domain.begin(), + std::find(c_loop_domain.begin(), + c_loop_domain.end(), + c_sharded_id)); + + Val* split_factor = p_did_split->factor(); + consumer->split(sharded_axis, split_factor, /*inner_split=*/false); + consumer->axis(sharded_axis)->parallelize(ParallelType::DIDx); + + // Move this did_pos to the end in producer to avoid using TransformPropagator on it. + producer->reorder({{idx, -1}}); + } +} + +} // namespace + +void PropagateShardingsPass::runPass(Fusion* fusion) { + const std::vector& exprs = fusion->exprs(); + + for (Expr* expr : exprs) { + // Note: Tvs without a mesh are assumed to have no manual sharding + // annotation and are sharded like the first producer Tv. + const auto& outputs_without_mesh = getOutputsWithoutMesh(expr); if (outputs_without_mesh.empty()) { continue; } - // This restricts the transform propagation to the DID axis. - int64_t did_pos = reorderDIDToFront(ref_input); - - if (ViewOp* view_op = dynamic_cast(expr)) { - // This implementation asserts that only one sharding is applied on the reshaped ids. - // Inner split is not supported. - // The cases are: - // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in consumer. - // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in consumer. - // An improvement is to support mult-levels of sharding (not a real case in practice) if they are all outer splits. 
- // For example: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] + const auto& reference_inputs = getOrderedReferenceInputs(expr); - TensorView* producer = view_op->in(); - TensorView* consumer = view_op->out(); + // Propagate shardings from reference inputs in order. + for (auto* ref_input : reference_inputs) { + // Skip if the input has no device dimensions or is nullptr. + if (ref_input == nullptr || num_device_dims(ref_input) == 0) { + continue; + } + + // This restricts the transform propagation to the DID axis. + int64_t did_pos = reorderDIDToFront(ref_input); - const std::unordered_map& c2p = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - auto [p_reshaped_ids, c_reshaped_ids] = getReshapedIds(view_op, c2p); - - auto p_loop_domain = producer->getLoopDomain(); - auto c_loop_domain = consumer->getLoopDomain(); - - for (auto idx: c10::irange(did_pos)) { - auto p_transforms = DependencyCheck::getAllExprsBetween( - {p_reshaped_ids.begin(), p_reshaped_ids.end()}, {p_loop_domain.at(idx)}); - if (p_transforms.empty()) { - // Sharding is not on reshaped ids. We will use the TransformPropagator. - continue; - } - - NVF_ERROR(p_transforms.size() == 1 && p_transforms.back()->isA(), "Expected only a single DID split on reshaped ids."); - auto* p_did_split = p_transforms.front()->as(); - - auto reshape_transform = DependencyCheck::getAllExprsBetween( - {c_reshaped_ids.begin(), c_reshaped_ids.end()}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); - - NVF_ERROR((reshape_transform.size() == 1 && reshape_transform.front()->isOneOf()), "Expected a split or merge transform between root and logical reshaped ids."); - - if (reshape_transform.front()->isA()){ - // Check that the sharding is on the outer reshaped id. If it is on inner reshaped id (h/a for merge reshape), for non-resharding, the consumer should be inner split which is not supported. 
- auto* outer_id = reshape_transform.front()->as()->outer(); - auto* producer_outer_id = c2p.at(outer_id); - NVF_ERROR(p_did_split->in() == producer_outer_id, "Expected the sharding to be on the outer reshaped id."); - } - - auto* c_sharded_id = reshape_transform.front()->isA() ? reshape_transform.front()->as()->outer() : reshape_transform.front()->as()->out(); - - int64_t sharded_axis = std::distance( - c_loop_domain.begin(), - std::find(c_loop_domain.begin(), - c_loop_domain.end(), - c_sharded_id)); - - Val* split_factor = p_did_split->factor(); - consumer->split(sharded_axis, split_factor, /*inner_split=*/false); - consumer->axis(sharded_axis)->parallelize(ParallelType::DIDx); - - // Move this did_pos to the end in producer to avoid using TransformPropagator on it. - // producer->reorder({{idx, -1}}); + if (ViewOp* view_op = dynamic_cast(expr)) { + handleViewOp(view_op, did_pos); + did_pos = did_pos - 1; } - did_pos = did_pos - 1; + + // Propagate the DID loop split to the outputs without mesh. + TransformPropagator propagator(ref_input, did_pos); + PropagateShardingsSelector selector( + {outputs_without_mesh.begin(), outputs_without_mesh.end()}, + /*allow_c2p=*/false, + /*allow_p2c=*/true); + MaxLogicalDomainInfoSpanningTree(ref_input, &selector) + .traverse(&propagator); + + // Apply parallelization on the outputs without mesh. + shardAllLike(ref_input, outputs_without_mesh); } - - // Propagate the DID loop split to the outputs without mesh. - TransformPropagator propagator(ref_input, did_pos); - SetSelector selector( - {outputs_without_mesh.begin(), outputs_without_mesh.end()}); - MaxLogicalDomainInfoSpanningTree(ref_input, &selector) - .traverse(&propagator); - - // Apply parallelization on the outputs without mesh. - shardAllLike(ref_input, outputs_without_mesh); } // Back-propagate device meshes. 
This makes sure all TensorViews have a mesh @@ -215,14 +258,28 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { if (i_output == outputs.end()) { continue; } - TensorView* output_with_mesh = *i_output; + + // All outputs of an expression are uniformly sharded so we pick the first one. + // TODO: Do we need to worry about the case where the outputs are not uniformly sharded? + // The relevant exprs are Welford and SDPA. + TensorView* ref_output = *i_output; + int64_t did_pos = reorderDIDToFront(ref_output); const auto& inputs = ir_utils::filterByType(expr->inputs()); - for (auto* tv : inputs) { - if (!tv->hasDeviceMesh()) { - tv->setDeviceMesh(output_with_mesh->getDeviceMesh()); - } - } + std::vector unsharded_inputs; + std::copy_if( + inputs.begin(), + inputs.end(), + std::back_inserter(unsharded_inputs), + [](TensorView* tv) { return !tv->hasDeviceMesh() || num_device_dims(tv) == 0; }); + + TransformPropagator propagator(ref_output, did_pos); + PropagateShardingsSelector selector( + {unsharded_inputs.begin(), unsharded_inputs.end()}, + /*allow_c2p=*/true, + /*allow_p2c=*/false); + MaxLogicalDomainInfoSpanningTree(ref_output, &selector).traverse(&propagator); + shardAllLike(ref_output, unsharded_inputs, /*parallelize_inputs=*/true); } validateMeshes(fusion); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index de5fd04687b..528bdcc1fd7 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -16,6 +16,7 @@ #include #include #include +#include "multidevice/utils.h" namespace nvfuser { @@ -1195,4 +1196,32 @@ TEST_F(MultiDeviceTest, TransformerFwd) { __LINE__, __FILE__); } + +TEST_F(MultiDeviceTest, ResidualAdd) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 8; + + TensorView* tv0 = makeContigConcreteTensor({b, d*s, h}); + TensorView* tv1 = makeContigConcreteTensor({b, d*s, 
h}); + TensorView* tv2 = add(tv0, tv1); + + auto mesh = DeviceMesh::createForNumDevices(d); + tv0->setDeviceMesh(mesh); + tv1->setDeviceMesh(mesh); + tv1->split(1, d, /*inner_split=*/false); + tv1->axis(1)->parallelize(ParallelType::DIDx); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + + preseg_passes::OptimizationPass::runPass(fusion.get()); + debug() << "tv0: " << tv0->toString() << std::endl; + debug() << "tv1: " << tv1->toString() << std::endl; + debug() << "tv2: " << tv2->toString() << std::endl; + NVF_CHECK(getShardedLoopAxis(tv0, ParallelType::DIDx) != -1); +} } // namespace nvfuser From af5ad27f7f2dfd69137b0d9b8690c64203f7bd59 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 13 Mar 2025 11:44:07 -0700 Subject: [PATCH 14/70] support multiple merges or splits in a reshape --- csrc/preseg_passes/propagate_shardings.cpp | 31 +++++++++---- tests/cpp/test_multidevice_sharding.cpp | 52 ++++++++++++++++++++-- 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index a503284cdab..4107b83f857 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -170,19 +170,32 @@ void handleViewOp(ViewOp* view_op, int64_t did_pos) { NVF_ERROR(p_transforms.size() == 1 && p_transforms.back()->isA(), "Expected only a single DID split on reshaped ids."); auto* p_did_split = p_transforms.front()->as(); - auto reshape_transform = DependencyCheck::getAllExprsBetween( + auto reshape_transforms = DependencyCheck::getAllExprsBetween( {c_reshaped_ids.begin(), c_reshaped_ids.end()}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); - NVF_ERROR((reshape_transform.size() == 1 && reshape_transform.front()->isOneOf()), "Expected a split or merge transform between root and logical reshaped ids."); - - if (reshape_transform.front()->isA()){ - // Check that the sharding 
is on the outer reshaped id. If it is on inner reshaped id (h/a for merge reshape), for non-resharding, the consumer should be inner split which is not supported. - auto* outer_id = reshape_transform.front()->as()->outer(); - auto* producer_outer_id = c2p.at(outer_id); - NVF_ERROR(p_did_split->in() == producer_outer_id, "Expected the sharding to be on the outer reshaped id."); + // Check if the producer is sharded on the outermost reshaped id. + IterDomain* reshaped_outer_id = nullptr; + for (auto transform_it = reshape_transforms.rbegin(); transform_it != reshape_transforms.rend(); transform_it++) { + auto* transform = *transform_it; + if (transform->isA()) { + reshaped_outer_id = transform->as()->outer(); + } + if (transform->isA()) { + reshaped_outer_id = transform->as()->in(); + } } + NVF_ERROR(c2p.find(reshaped_outer_id)!=c2p.end() && c2p.find(reshaped_outer_id)->second == p_did_split->in(), "Expected the sharding to be on the outer reshaped id."); - auto* c_sharded_id = reshape_transform.front()->isA() ? reshape_transform.front()->as()->outer() : reshape_transform.front()->as()->out(); + // Get the sharded id in the consumer. 
+ IterDomain* c_sharded_id = nullptr; + for (auto transform: reshape_transforms) { + if (transform->isA()) { + c_sharded_id = transform->as()->outer(); + } + if (transform->isA()) { + c_sharded_id = transform->as()->out(); + } + } int64_t sharded_axis = std::distance( c_loop_domain.begin(), diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 528bdcc1fd7..9ee270bf192 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -1219,9 +1219,55 @@ TEST_F(MultiDeviceTest, ResidualAdd) { fusion->addOutput(tv2); preseg_passes::OptimizationPass::runPass(fusion.get()); - debug() << "tv0: " << tv0->toString() << std::endl; - debug() << "tv1: " << tv1->toString() << std::endl; - debug() << "tv2: " << tv2->toString() << std::endl; NVF_CHECK(getShardedLoopAxis(tv0, ParallelType::DIDx) != -1); } + +TEST_F(MultiDeviceTest, MultipleMergeReshape) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 8; + + TensorView* tv0 = makeContigConcreteTensor({d*b, s, h}); + TensorView* tv1 = reshape(tv0, {d*b, s, h}, {d*b, s, h}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + + auto mesh = DeviceMesh::createForNumDevices(d); + tv0->setDeviceMesh(mesh); + tv0->split(0, d, /*inner_split=*/false); + tv0->axis(0)->parallelize(ParallelType::DIDx); + + auto transform_exprs = StmtSort::getExprsBetween( + {tv1->getMaybeRootDomain().begin(), tv1->getMaybeRootDomain().end()}, + {tv1->getLogicalDomain().begin(), tv1->getLogicalDomain().end()}); + + preseg_passes::OptimizationPass::runPass(fusion.get()); + + NVF_CHECK(getShardedLoopAxis(tv1, ParallelType::DIDx) != -1); +} + +TEST_F(MultiDeviceTest, MultipleSplitReshape) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 8; + + TensorView* tv0 = 
makeContigConcreteTensor({d*b*s*h}); + TensorView* tv1 = reshape(tv0, {d*b*s*h}, {d*b, s, h}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + + auto mesh = DeviceMesh::createForNumDevices(d); + tv0->setDeviceMesh(mesh); + tv0->split(0, d, /*inner_split=*/false); + tv0->axis(0)->parallelize(ParallelType::DIDx); + + preseg_passes::OptimizationPass::runPass(fusion.get()); + + NVF_CHECK(getShardedLoopAxis(tv1, ParallelType::DIDx) != -1); +} + } // namespace nvfuser From 0b9577c565d77d95d6ea783c6a1a18d670126653 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 13 Mar 2025 11:48:04 -0700 Subject: [PATCH 15/70] use producer's parallel type --- csrc/preseg_passes/propagate_shardings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 4107b83f857..48884f464bd 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -205,7 +205,7 @@ void handleViewOp(ViewOp* view_op, int64_t did_pos) { Val* split_factor = p_did_split->factor(); consumer->split(sharded_axis, split_factor, /*inner_split=*/false); - consumer->axis(sharded_axis)->parallelize(ParallelType::DIDx); + consumer->axis(sharded_axis)->parallelize(p_loop_domain.at(idx)->getParallelType()); // Move this did_pos to the end in producer to avoid using TransformPropagator on it. 
producer->reorder({{idx, -1}}); From e0c40266e1060248b821c2d8c7f769315fcb6a64 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 13 Mar 2025 11:49:36 -0700 Subject: [PATCH 16/70] return number of did shardings on reshape --- csrc/preseg_passes/propagate_shardings.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 48884f464bd..9bfc230d02a 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -141,7 +141,7 @@ class PropagateShardingsSelector: public SetSelector { } }; -void handleViewOp(ViewOp* view_op, int64_t did_pos) { +int64_t handleViewOp(ViewOp* view_op, int64_t did_pos) { // This implementation asserts that only one sharding is applied on the reshaped ids. // Inner split is not supported. // The cases are: @@ -158,7 +158,8 @@ void handleViewOp(ViewOp* view_op, int64_t did_pos) { auto p_loop_domain = producer->getLoopDomain(); auto c_loop_domain = consumer->getLoopDomain(); - + + int64_t num_reshape_shardings = 0; for (auto idx: c10::irange(did_pos)) { auto p_transforms = DependencyCheck::getAllExprsBetween( {p_reshaped_ids.begin(), p_reshaped_ids.end()}, {p_loop_domain.at(idx)}); @@ -166,7 +167,7 @@ void handleViewOp(ViewOp* view_op, int64_t did_pos) { // Sharding is not on reshaped ids. We will use the TransformPropagator. continue; } - + num_reshape_shardings++; NVF_ERROR(p_transforms.size() == 1 && p_transforms.back()->isA(), "Expected only a single DID split on reshaped ids."); auto* p_did_split = p_transforms.front()->as(); @@ -210,6 +211,7 @@ void handleViewOp(ViewOp* view_op, int64_t did_pos) { // Move this did_pos to the end in producer to avoid using TransformPropagator on it. 
producer->reorder({{idx, -1}}); } + return num_reshape_shardings; } } // namespace @@ -238,8 +240,8 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { int64_t did_pos = reorderDIDToFront(ref_input); if (ViewOp* view_op = dynamic_cast(expr)) { - handleViewOp(view_op, did_pos); - did_pos = did_pos - 1; + int64_t num_reshape_shardings = handleViewOp(view_op, did_pos); + did_pos = did_pos - num_reshape_shardings; } // Propagate the DID loop split to the outputs without mesh. From f798f1e3e5116fe800d75e6d91e1d084ef245d45 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 13 Mar 2025 17:36:39 -0700 Subject: [PATCH 17/70] update handling of view op --- csrc/preseg_passes/propagate_shardings.cpp | 94 ++++++++++++++-------- tests/cpp/test_multidevice_sharding.cpp | 33 ++++---- 2 files changed, 78 insertions(+), 49 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 9bfc230d02a..c5a3e60e5f9 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -8,6 +8,7 @@ #include #include +#include "type.h" #include #include @@ -64,8 +65,7 @@ std::pair, std::unordered_set> getR } } } - // Get the logical iterdomains in the producer that are reshaped. 
- std::unordered_set p_reshape_ids; + for (auto id : c_reshaped_ids) { if (auto p_id = c2p.find(id); p_id != c2p.end()) { p_reshaped_ids.insert(p_id->second); @@ -141,60 +141,86 @@ class PropagateShardingsSelector: public SetSelector { } }; -int64_t handleViewOp(ViewOp* view_op, int64_t did_pos) { +void splitLike(TensorView* tv, int64_t axis, Split* ref_split, bool allow_inner_split = false) { + auto split_factor = ref_split->factor(); + auto inner_split = ref_split->innerSplit(); + NVF_ERROR (!inner_split || allow_inner_split, "Inner split is not supported."); + tv->split(axis, split_factor, /*inner_split=*/inner_split); +} + +// Returns the number of DID axis on reshaped ids that were propagated to the consumer. +int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { // This implementation asserts that only one sharding is applied on the reshaped ids. // Inner split is not supported. // The cases are: // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in consumer. // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in consumer. + // 3. Multiple splits or merge reshapes: [x, y, z] -> [xyz]. Sharding on x and xyz. Similarly for the corresponding split reshape. + // 4. Independent splits or merge reshapes: [w, x, y, z] -> [wx, yz]. Sharding is on w and y. In the consumer, it is applied to wx and yz. // An improvement is to support mult-levels of sharding (not a real case in practice) if they are all outer splits. 
- // For example: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] + // For example: For the reshape [h] -> [a, h/a] where the h is sharded twice: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] TensorView* producer = view_op->in(); TensorView* consumer = view_op->out(); const std::unordered_map& c2p = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - auto [p_reshaped_ids, c_reshaped_ids] = getReshapedIds(view_op, c2p); + const std::unordered_map& p2c = PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer(); + auto [p_logical_reshaped_ids, c_root_reshaped_ids] = getReshapedIds(view_op, c2p); auto p_loop_domain = producer->getLoopDomain(); auto c_loop_domain = consumer->getLoopDomain(); + auto c_logical_domain = consumer->getLogicalDomain(); + // Track number of DID axis on reshaped ids that were propagated to the consumer. + // These will not be included in TransformPropagator. int64_t num_reshape_shardings = 0; - for (auto idx: c10::irange(did_pos)) { + + for (auto idx: c10::irange(num_device_dims)) { + IterDomain* p_did = p_loop_domain.at(idx); + NVF_ERROR(p_did->isDeviceDim()); + auto p_transforms = DependencyCheck::getAllExprsBetween( - {p_reshaped_ids.begin(), p_reshaped_ids.end()}, {p_loop_domain.at(idx)}); + {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()}, {p_loop_domain.at(idx)}); if (p_transforms.empty()) { - // Sharding is not on reshaped ids. We will use the TransformPropagator. + // This did axis is not on reshaped ids. We will use the TransformPropagator. 
continue; } - num_reshape_shardings++; - NVF_ERROR(p_transforms.size() == 1 && p_transforms.back()->isA(), "Expected only a single DID split on reshaped ids."); - auto* p_did_split = p_transforms.front()->as(); - auto reshape_transforms = DependencyCheck::getAllExprsBetween( - {c_reshaped_ids.begin(), c_reshaped_ids.end()}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); + NVF_ERROR(TensorDomain::sameAs(c_logical_domain, c_loop_domain), "Sharding on a previously transformed reshape is not supported."); - // Check if the producer is sharded on the outermost reshaped id. - IterDomain* reshaped_outer_id = nullptr; - for (auto transform_it = reshape_transforms.rbegin(); transform_it != reshape_transforms.rend(); transform_it++) { - auto* transform = *transform_it; - if (transform->isA()) { - reshaped_outer_id = transform->as()->outer(); - } + num_reshape_shardings++; + + // Find the producer logical id that is sharded. + // We expect the outermost reshaped id to be sharded and follow the outermost path traversing the transforms + IterDomain* p_logical_did = p_loop_domain.at(idx); + for (auto transform_it = p_transforms.rbegin(); transform_it != p_transforms.rend(); transform_it++) { + auto transform = *transform_it; if (transform->isA()) { - reshaped_outer_id = transform->as()->in(); + NVF_ERROR(p_logical_did == transform->as()->outer(), "Expected the sharding to be on the outer reshaped id."); + p_logical_did = transform->as()->in(); + } + if (transform->isA()) { + p_logical_did = transform->as()->outer(); } } - NVF_ERROR(c2p.find(reshaped_outer_id)!=c2p.end() && c2p.find(reshaped_outer_id)->second == p_did_split->in(), "Expected the sharding to be on the outer reshaped id."); - // Get the sharded id in the consumer. - IterDomain* c_sharded_id = nullptr; + // Find the mapping of the corresponding producer logical id in consumer root. 
+ IterDomain* c_root_did = p2c.at(p_logical_did); + + // Get the reshape transforms corresponding to this root id. + // We use the c_root_did to only find the reshape IDs related to this did. + auto reshape_transforms = DependencyCheck::getAllExprsBetween( + {c_root_did}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); + + // Obtain the logical axis sharded in the consumer. + IterDomain* c_logical_did = c_root_did; for (auto transform: reshape_transforms) { if (transform->isA()) { - c_sharded_id = transform->as()->outer(); + c_logical_did = transform->as()->outer(); } if (transform->isA()) { - c_sharded_id = transform->as()->out(); + NVF_ERROR(c_logical_did == transform->as()->outer(), "Expected the sharding to be on the outer reshaped id."); + c_logical_did = transform->as()->out(); } } @@ -202,11 +228,11 @@ int64_t handleViewOp(ViewOp* view_op, int64_t did_pos) { c_loop_domain.begin(), std::find(c_loop_domain.begin(), c_loop_domain.end(), - c_sharded_id)); + c_logical_did)); - Val* split_factor = p_did_split->factor(); - consumer->split(sharded_axis, split_factor, /*inner_split=*/false); - consumer->axis(sharded_axis)->parallelize(p_loop_domain.at(idx)->getParallelType()); + auto* p_did_split = p_did->definition()->as(); + splitLike(consumer, sharded_axis, p_did_split); + consumer->axis(sharded_axis)->parallelize(p_did->getParallelType()); // Move this did_pos to the end in producer to avoid using TransformPropagator on it. producer->reorder({{idx, -1}}); @@ -237,15 +263,15 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } // This restricts the transform propagation to the DID axis. 
- int64_t did_pos = reorderDIDToFront(ref_input); + int64_t num_device_dims = reorderDIDToFront(ref_input); if (ViewOp* view_op = dynamic_cast(expr)) { - int64_t num_reshape_shardings = handleViewOp(view_op, did_pos); - did_pos = did_pos - num_reshape_shardings; + int64_t num_reshape_shardings = handleViewOp(view_op, num_device_dims); + num_device_dims = num_device_dims - num_reshape_shardings; } // Propagate the DID loop split to the outputs without mesh. - TransformPropagator propagator(ref_input, did_pos); + TransformPropagator propagator(ref_input, num_device_dims); PropagateShardingsSelector selector( {outputs_without_mesh.begin(), outputs_without_mesh.end()}, /*allow_c2p=*/false, diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 9ee270bf192..8b8288a1b8f 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -781,17 +781,20 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { tv0->axis(-3)->parallelize(ParallelType::DIDx); // in: loop domain: {b, s, DIDx{d}, h, e} - // Propagate DID loop split to output - TransformPropagator propagator_p2c(tv0); - MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator_p2c); - // out: loop domain: {b, s, d, h, e} after transform propagation + // // Propagate DID loop split to output + // TransformPropagator propagator_p2c(tv0); + // MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator_p2c); + // // out: loop domain: {b, s, d, h, e} after transform propagation + + // // Parallelize output + // tv1->setDeviceMesh(mesh); + // scheduler_utils::parallelizeAllLike( + // tv0, + // /*pos=*/-1, + // /*selected_tv=*/{tv1}); + // // out: loop domain: {b, s, DIDx{d}, h, e} after parallelization - // Parallelize output - scheduler_utils::parallelizeAllLike( - tv0, - /*pos=*/-1, - /*selected_tv=*/{tv1}); - // out: loop domain: {b, s, DIDx{d}, h, e} after parallelization + preseg_passes::OptimizationPass::runPass(fusion.get()); 
tv0->setAllocationDomain(tv0->getLoopDomain(), true); tv1->setAllocationDomain(tv1->getLoopDomain(), true); @@ -1248,15 +1251,15 @@ TEST_F(MultiDeviceTest, MultipleMergeReshape) { NVF_CHECK(getShardedLoopAxis(tv1, ParallelType::DIDx) != -1); } -TEST_F(MultiDeviceTest, MultipleSplitReshape) { +TEST_F(MultiDeviceTest, MultipleTransformReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8; + const int64_t b = 2, s = 3, h = 8, e = 4; - TensorView* tv0 = makeContigConcreteTensor({d*b*s*h}); - TensorView* tv1 = reshape(tv0, {d*b*s*h}, {d*b, s, h}); + TensorView* tv0 = makeContigConcreteTensor({d*b, s, h*e}); + TensorView* tv1 = reshape(tv0, {d*b, s, h*e}, {d*b*s*h, e}); fusion->addInput(tv0); fusion->addOutput(tv1); @@ -1266,7 +1269,7 @@ TEST_F(MultiDeviceTest, MultipleSplitReshape) { tv0->axis(0)->parallelize(ParallelType::DIDx); preseg_passes::OptimizationPass::runPass(fusion.get()); - + NVF_CHECK(getShardedLoopAxis(tv1, ParallelType::DIDx) != -1); } From f90b4789c34053b60e73e70f4a3874d35e5ee469 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 11:03:57 -0700 Subject: [PATCH 18/70] move tests into separate files --- CMakeLists.txt | 1 + csrc/preseg_passes/propagate_shardings.cpp | 39 ++-- tests/cpp/test_multidevice_preseg_passes.cpp | 227 +++++++++++++++++++ tests/cpp/test_multidevice_sharding.cpp | 73 ------ 4 files changed, 254 insertions(+), 86 deletions(-) create mode 100644 tests/cpp/test_multidevice_preseg_passes.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 81411114bfd..78a40d032db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -685,6 +685,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_preseg_passes.cpp 
${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp ) add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "") diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index c5a3e60e5f9..c02c6314acc 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -181,28 +181,28 @@ int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { auto p_transforms = DependencyCheck::getAllExprsBetween( {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()}, {p_loop_domain.at(idx)}); + if (p_transforms.empty()) { // This did axis is not on reshaped ids. We will use the TransformPropagator. continue; } - NVF_ERROR(TensorDomain::sameAs(c_logical_domain, c_loop_domain), "Sharding on a previously transformed reshape is not supported."); + if (p_transforms.size() > 1) { + // This reshape has been transformed. + // We will attempt to use TransformPropagator for this did axis. + continue; + } + + NVF_ERROR(p_transforms.front()->isA(), "Expected a split transform producing the did axis."); + NVF_ERROR(TensorDomain::sameAs(c_logical_domain, c_loop_domain), + "Sharding a previously transformed reshape is not supported."); num_reshape_shardings++; // Find the producer logical id that is sharded. 
// We expect the outermost reshaped id to be sharded and follow the outermost path traversing the transforms - IterDomain* p_logical_did = p_loop_domain.at(idx); - for (auto transform_it = p_transforms.rbegin(); transform_it != p_transforms.rend(); transform_it++) { - auto transform = *transform_it; - if (transform->isA()) { - NVF_ERROR(p_logical_did == transform->as()->outer(), "Expected the sharding to be on the outer reshaped id."); - p_logical_did = transform->as()->in(); - } - if (transform->isA()) { - p_logical_did = transform->as()->outer(); - } - } + auto* p_did_split = p_did->definition()->as(); + IterDomain* p_logical_did = p_did_split->in(); // Find the mapping of the corresponding producer logical id in consumer root. IterDomain* c_root_did = p2c.at(p_logical_did); @@ -230,7 +230,7 @@ int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { c_loop_domain.end(), c_logical_did)); - auto* p_did_split = p_did->definition()->as(); + splitLike(consumer, sharded_axis, p_did_split); consumer->axis(sharded_axis)->parallelize(p_did->getParallelType()); @@ -255,6 +255,8 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { const auto& reference_inputs = getOrderedReferenceInputs(expr); + std::unordered_set output_parallel_types; + // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { // Skip if the input has no device dimensions or is nullptr. @@ -265,6 +267,14 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // This restricts the transform propagation to the DID axis. int64_t num_device_dims = reorderDIDToFront(ref_input); + for (auto idx: c10::irange(num_device_dims)) { + if (output_parallel_types.count(ref_input->axis(idx)->getParallelType())) { + // Do not propagate parallel types already seen on the output. 
+ ref_input->reorder({{idx, -1}}); + num_device_dims--; + } + } + if (ViewOp* view_op = dynamic_cast(expr)) { int64_t num_reshape_shardings = handleViewOp(view_op, num_device_dims); num_device_dims = num_device_dims - num_reshape_shardings; @@ -281,6 +291,9 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Apply parallelization on the outputs without mesh. shardAllLike(ref_input, outputs_without_mesh); + for (auto idx: c10::irange(num_device_dims)) { + output_parallel_types.insert(ref_input->axis(idx)->getParallelType()); + } } } diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp new file mode 100644 index 00000000000..fa9cea0f017 --- /dev/null +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -0,0 +1,227 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include "multidevice/utils.h" + +namespace nvfuser { + +using MultiDevicePresegPassesTest = MultiDeviceTest; + +TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { + // This is similar to the residual add after MHA dropout in the transformer. + // The output of linear following MHA is all-gathered and sharded on the sequence dim. + // This sharding can be propagated to the linear output through backpropagating the shardings + // from residual add. This information is not present during forward propagation. 
+ auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 8; + + TensorView* tv0 = makeContigConcreteTensor({b, d*s, h}); + TensorView* tv1 = makeContigConcreteTensor({b, d*s, h}); + TensorView* tv2 = add(tv0, tv1); + + auto mesh = DeviceMesh::createForNumDevices(d); + tv0->setDeviceMesh(mesh); + tv1->setDeviceMesh(mesh); + tv1->split(1, d, /*inner_split=*/false); + tv1->axis(1)->parallelize(ParallelType::DIDx); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + + preseg_passes::OptimizationPass::runPass(fusion.get()); + // Set the allocation domain explicitly until the preseg pass is fixed. + for (auto* tv : {tv0, tv1, tv2}) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + NVF_CHECK(getShardedLogicalAxis(tv0, ParallelType::DIDx) == 1); + at::Tensor inp0 = at::randn({b, d*s, h}, tensor_options); + at::Tensor inp1 = at::randn({b, d*s, h}, tensor_options); + at::Tensor sharded_inp0 = shardTensor(inp0, 1, mesh); + at::Tensor sharded_inp1 = shardTensor(inp1, 1, mesh); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp0, sharded_inp1})[0].as(); + testValidate( + executor_cache.fusion(), + {nvf_out}, + {sharded_inp0, sharded_inp1}, + {sharded_inp0 + sharded_inp1}, + __LINE__, + __FILE__); +} + +TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 8, e = 4; + + TensorView* tv0 = makeContigConcreteTensor({d*b, s, h*e}); + TensorView* tv1 = reshape(tv0, {d*b, s, h*e}, {d*b*s*h, e}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + + auto mesh = DeviceMesh::createForNumDevices(d); + tv0->setDeviceMesh(mesh); + tv0->split(0, d, /*inner_split=*/false); + tv0->axis(0)->parallelize(ParallelType::DIDx); + + 
preseg_passes::OptimizationPass::runPass(fusion.get()); + for (auto* tv : {tv0, tv1}) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + NVF_CHECK(getShardedLogicalAxis(tv1, ParallelType::DIDx) == 0); + at::Tensor inp = at::randn({d*b, s, h*e}, tensor_options); + at::Tensor sharded_inp = shardTensor(inp, 0, mesh); + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0].as(); + testValidate( + executor_cache.fusion(), + {nvf_out}, + {sharded_inp}, + {sharded_inp.view({b*s*h, e})}, + __LINE__, + __FILE__); +} + +TEST_F(MultiDevicePresegPassesTest, TransformerFwd) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, a = 8, h=128; + // const double kDropoutProb = 0.1; + + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* inp = makeConcreteTensor({b, d*s, h}, DataType::Half); + TensorView* mha_linear0_weight = makeConcreteTensor({3*d*h, h}, DataType::Half); + TensorView* mha_linear1_weight = makeConcreteTensor({h, d*h}, DataType::Half); + + // Layernorm + TensorView* ln_in = maybeCastOp(DataType::Float, inp); + TensorView* ln_out = layer_norm(ln_in, {h}, /*weight=*/nullptr, /*bias=*/nullptr, /*eps=*/IrBuilder::create(1e-5)).output; + TensorView* mha_in = maybeCastOp(inp->dtype(), ln_out); + // MHA Linear0 + TensorView* mha_linear0_out = linear(mha_in, mha_linear0_weight); + + // reshape -> slice -> permute + TensorView* qkv = reshape(mha_linear0_out, {b, s, 3*d*h}, {b, s, d*a, 3*h/a}); + TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d*a, h/a}); + TensorView* k = slice(qkv, {0, 0, 0, h/a}, {b, s, d*a, 2*h/a}); + TensorView* v = slice(qkv, {0, 0, 0, 2*h/a}, {b, s, d*a, 3*h/a}); + + TensorView* q_permuted = permute(q, {0, 2, 1, 3}); + TensorView* k_permuted = permute(k, {0, 2, 1, 3}); + TensorView* v_permuted = permute(v, {0, 2, 1, 3}); + + SdpfaFwdResult sdpa_out = sdpfa_fwd( + 
q_permuted, + k_permuted, + v_permuted, + /*dropout_p=*/IrBuilder::create(0.0), + /*is_causal=*/IrBuilder::create(false), + /*scale=*/nullptr); + + TensorView* attn = sdpa_out.output; + TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); + TensorView* attn_reshaped = reshape(attn_permute, {b, s, d*a, h/a}, {b, s, d*h}); + + // MHA Linear1: The reduction dimension is sharded and requires communication. + TensorView* mha_linear1_out = linear(attn_reshaped, mha_linear1_weight); + // Val* prob = IrBuilder::create(1.0 - kDropoutProb); + // Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + // TensorView* dropout_out = dropout(mha_linear1_out, prob, scale).output; + // TensorView* residual_add_out = add(dropout_out, inp); + + fusion->addInput(inp); + fusion->addInput(mha_linear0_weight); + fusion->addInput(mha_linear1_weight); + + fusion->addOutput(ln_out); + + // Rfactor mha_linear1_out for communication. + mha_linear1_out->split(-1, d, /*inner_split=*/false); + TensorView* local_mha_linear1_out = mha_linear1_out->rFactor({-1}); + + // Shard input tensors + for (auto* tv : {inp, mha_linear0_weight, mha_linear1_weight, mha_linear1_out, local_mha_linear1_out}) { + tv->setDeviceMesh(mesh); + } + inp->split(1, d, /*inner_split=*/false); + inp->axis(1)->parallelize(ParallelType::DIDx); + + mha_linear0_weight->split(0, d, /*inner_split=*/false); + mha_linear0_weight->axis(0)->parallelize(ParallelType::DIDx); + + mha_linear1_weight->split(1, d, /*inner_split=*/false); + mha_linear1_weight->axis(1)->parallelize(ParallelType::DIDx); + + // Parallelize MHA linear out: This will be done in insert_reshardings preseg pass. 
+ local_mha_linear1_out->axis(-2)->parallelize(ParallelType::DIDx); + + preseg_passes::OptimizationPass::runPass(fusion.get()); + for (auto tv : fusion->allTvs()) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp_tensor = at::randn({b, d*s, h}, tensor_options.dtype(at::kHalf)); + at::Tensor mha_linear0_weight_tensor = at::randn({3*d*h, h}, tensor_options.dtype(at::kHalf)); + at::Tensor mha_linear1_weight_tensor = at::randn({h, d*h}, tensor_options.dtype(at::kHalf)); + + at::Tensor sharded_inp = shardTensor(inp_tensor, 1, mesh); + at::Tensor sharded_mha_linear0_weight = shardTensor(mha_linear0_weight_tensor, 0, mesh); + at::Tensor sharded_mha_linear1_weight = shardTensor(mha_linear1_weight_tensor, 1, mesh); + + at::Tensor nvf_out = + executor_cache + .runFusionWithInputs({sharded_inp, sharded_mha_linear0_weight, sharded_mha_linear1_weight})[0] + .as(); + debug() << "nvf_out: " << nvf_out.sizes() << std::endl; + + // double scale = 1.0 / std::sqrt(e); + // auto reference_out = at::_scaled_dot_product_flash_attention( + // hq_tensor.view(out_shape).transpose(1, 2), + // hk_tensor.view(out_shape).transpose(1, 2), + // hv_tensor.view(out_shape).transpose(1, 2), + // /*dropout_p=*/0.0, + // /*is_causal=*/false, + // /*return_debug_mask=*/false, + // /*scale=*/scale); + // at::Tensor ref_attn = shardTensor( + // std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); + + // testValidate( + // executor_cache.fusion(), + // {nvf_out}, + // {sharded_hq, sharded_hk, sharded_hv}, + // {ref_attn}, + // __LINE__, + // __FILE__); +} + + +} // namespace nvfuser \ No newline at end of file diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 8b8288a1b8f..37b83520078 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -1200,77 +1200,4 @@ TEST_F(MultiDeviceTest, TransformerFwd) { 
__FILE__); } -TEST_F(MultiDeviceTest, ResidualAdd) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8; - - TensorView* tv0 = makeContigConcreteTensor({b, d*s, h}); - TensorView* tv1 = makeContigConcreteTensor({b, d*s, h}); - TensorView* tv2 = add(tv0, tv1); - - auto mesh = DeviceMesh::createForNumDevices(d); - tv0->setDeviceMesh(mesh); - tv1->setDeviceMesh(mesh); - tv1->split(1, d, /*inner_split=*/false); - tv1->axis(1)->parallelize(ParallelType::DIDx); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv2); - - preseg_passes::OptimizationPass::runPass(fusion.get()); - NVF_CHECK(getShardedLoopAxis(tv0, ParallelType::DIDx) != -1); -} - -TEST_F(MultiDeviceTest, MultipleMergeReshape) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8; - - TensorView* tv0 = makeContigConcreteTensor({d*b, s, h}); - TensorView* tv1 = reshape(tv0, {d*b, s, h}, {d*b, s, h}); - fusion->addInput(tv0); - fusion->addOutput(tv1); - - auto mesh = DeviceMesh::createForNumDevices(d); - tv0->setDeviceMesh(mesh); - tv0->split(0, d, /*inner_split=*/false); - tv0->axis(0)->parallelize(ParallelType::DIDx); - - auto transform_exprs = StmtSort::getExprsBetween( - {tv1->getMaybeRootDomain().begin(), tv1->getMaybeRootDomain().end()}, - {tv1->getLogicalDomain().begin(), tv1->getLogicalDomain().end()}); - - preseg_passes::OptimizationPass::runPass(fusion.get()); - - NVF_CHECK(getShardedLoopAxis(tv1, ParallelType::DIDx) != -1); -} - -TEST_F(MultiDeviceTest, MultipleTransformReshape) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 4; - - TensorView* tv0 = makeContigConcreteTensor({d*b, s, h*e}); - TensorView* tv1 = reshape(tv0, {d*b, s, h*e}, {d*b*s*h, e}); - fusion->addInput(tv0); - 
fusion->addOutput(tv1); - - auto mesh = DeviceMesh::createForNumDevices(d); - tv0->setDeviceMesh(mesh); - tv0->split(0, d, /*inner_split=*/false); - tv0->axis(0)->parallelize(ParallelType::DIDx); - - preseg_passes::OptimizationPass::runPass(fusion.get()); - - NVF_CHECK(getShardedLoopAxis(tv1, ParallelType::DIDx) != -1); -} - } // namespace nvfuser From 4b066d12b92a2ad220a1a088abb40df58ae71583 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 15:04:37 -0700 Subject: [PATCH 19/70] reorder back as logical, more tests --- csrc/host_ir/lower.cpp | 5 +- csrc/preseg_passes/propagate_shardings.cpp | 173 +++++++----- tests/cpp/test_multidevice_preseg_passes.cpp | 278 ++++++++++++------- 3 files changed, 296 insertions(+), 160 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 8be0e4214eb..51bde80e74b 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -435,10 +435,9 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) { // stream-parallelized on axis 0. 
auto* a = linear->inA()->as(); auto* b = linear->inB()->as(); - auto* bias = linear->bias()->as(); auto* out = linear->out()->as(); - return !isSharded(b) && !(linear->has_bias() && isSharded(bias)) && - !isSharded(out) && + return !isSharded(b) && + !(linear->has_bias() && isSharded(linear->bias())) && !isSharded(out) && a->axis(0)->getParallelType() == ParallelType::Serial && getShardedLogicalAxis(a, ParallelType::DIDx) == 1 && out->axis(0)->getParallelType() == ParallelType::Stream; diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index c02c6314acc..985184a77f4 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include namespace nvfuser::preseg_passes { @@ -49,13 +50,16 @@ void validateMeshes(Fusion* fusion) { " not."); } -std::pair, std::unordered_set> getReshapedIds(ViewOp* view_op, const std::unordered_map& c2p) { +std::pair, std::unordered_set> +getReshapedIds( + ViewOp* view_op, + const std::unordered_map& c2p) { std::unordered_set p_reshaped_ids; // Reshaped logical IDs std::unordered_set c_reshaped_ids; // Reshaped root IDs TensorView* consumer = view_op->out(); std::vector c_root_domain = consumer->getMaybeRootDomain(); - + for (auto id : consumer->getLogicalDomain()) { if (id->isRFactorProduct() && id->definition() && !id->definition()->isA()) { @@ -84,7 +88,8 @@ int64_t num_device_dims(TensorView* tv) { // Order the inputs of the expression based on their priority. // For linear op, we use weights and bias before input. // For matmul op, we use weights before input. -// For other ops, we sort the inputs by the number of device dimensions in descending order. +// For other ops, we sort the inputs by the number of device dimensions in +// descending order. 
std::vector getOrderedReferenceInputs(Expr* expr) { const auto& inputs = ir_utils::filterByType(expr->inputs()); if (LinearOp* linear_op = dynamic_cast(expr)) { @@ -99,10 +104,12 @@ std::vector getOrderedReferenceInputs(Expr* expr) { // Sort inputs by number of device dimensions in descending order std::vector sorted_inputs(inputs.begin(), inputs.end()); - std::sort(sorted_inputs.begin(), sorted_inputs.end(), - [&](TensorView* a, TensorView* b) { - return num_device_dims(a) > num_device_dims(b); - }); + std::sort( + sorted_inputs.begin(), + sorted_inputs.end(), + [&](TensorView* a, TensorView* b) { + return num_device_dims(a) > num_device_dims(b); + }); return sorted_inputs; } @@ -118,7 +125,7 @@ std::vector getOutputsWithoutMesh(Expr* expr) { return outputs_without_mesh; } -class PropagateShardingsSelector: public SetSelector { +class PropagateShardingsSelector : public SetSelector { private: bool allow_c2p_; bool allow_p2c_; @@ -141,49 +148,63 @@ class PropagateShardingsSelector: public SetSelector { } }; -void splitLike(TensorView* tv, int64_t axis, Split* ref_split, bool allow_inner_split = false) { +void splitLike( + TensorView* tv, + int64_t axis, + Split* ref_split, + bool allow_inner_split = false) { auto split_factor = ref_split->factor(); auto inner_split = ref_split->innerSplit(); - NVF_ERROR (!inner_split || allow_inner_split, "Inner split is not supported."); + NVF_ERROR(!inner_split || allow_inner_split, "Inner split is not supported."); tv->split(axis, split_factor, /*inner_split=*/inner_split); } -// Returns the number of DID axis on reshaped ids that were propagated to the consumer. +// Returns the number of DID axis on reshaped ids that were propagated to the +// consumer. int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { - // This implementation asserts that only one sharding is applied on the reshaped ids. - // Inner split is not supported. - // The cases are: - // 1. Split reshape: [h] -> [a, h/a]. 
Sharding on h is applied to a in consumer. - // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in consumer. - // 3. Multiple splits or merge reshapes: [x, y, z] -> [xyz]. Sharding on x and xyz. Similarly for the corresponding split reshape. - // 4. Independent splits or merge reshapes: [w, x, y, z] -> [wx, yz]. Sharding is on w and y. In the consumer, it is applied to wx and yz. - // An improvement is to support mult-levels of sharding (not a real case in practice) if they are all outer splits. - // For example: For the reshape [h] -> [a, h/a] where the h is sharded twice: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] + // This implementation asserts that only one sharding is applied on the + // reshaped ids. Inner split is not supported. The cases are: + // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in + // consumer. + // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in + // consumer. + // 3. Multiple splits or merge reshapes: [x, y, z] -> [xyz]. Sharding on x and + // xyz. Similarly for the corresponding split reshape. + // 4. Independent splits or merge reshapes: [w, x, y, z] -> [wx, yz]. Sharding + // is on w and y. In the consumer, it is applied to wx and yz. An improvement + // is to support multiple levels of sharding (not a real case in practice) if they + // are all outer splits.
For example: For the reshape [h] -> [a, h/a] where + // the h is sharded twice: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] TensorView* producer = view_op->in(); TensorView* consumer = view_op->out(); - const std::unordered_map& c2p = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - const std::unordered_map& p2c = PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer(); - auto [p_logical_reshaped_ids, c_root_reshaped_ids] = getReshapedIds(view_op, c2p); - + const std::unordered_map& c2p = + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + const std::unordered_map& p2c = + PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer(); + auto [p_logical_reshaped_ids, c_root_reshaped_ids] = + getReshapedIds(view_op, c2p); + auto p_loop_domain = producer->getLoopDomain(); auto c_loop_domain = consumer->getLoopDomain(); auto c_logical_domain = consumer->getLogicalDomain(); - - // Track number of DID axis on reshaped ids that were propagated to the consumer. - // These will not be included in TransformPropagator. + + // Track number of DID axis on reshaped ids that were propagated to the + // consumer. These will not be included in TransformPropagator. int64_t num_reshape_shardings = 0; - for (auto idx: c10::irange(num_device_dims)) { + for (auto idx : c10::irange(num_device_dims)) { IterDomain* p_did = p_loop_domain.at(idx); NVF_ERROR(p_did->isDeviceDim()); - + auto p_transforms = DependencyCheck::getAllExprsBetween( - {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()}, {p_loop_domain.at(idx)}); - + {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()}, + {p_loop_domain.at(idx)}); + if (p_transforms.empty()) { - // This did axis is not on reshaped ids. We will use the TransformPropagator. + // This did axis is not on reshaped ids. We will use the + // TransformPropagator. 
continue; } @@ -193,58 +214,71 @@ int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { continue; } - NVF_ERROR(p_transforms.front()->isA(), "Expected a split transform producing the did axis."); - NVF_ERROR(TensorDomain::sameAs(c_logical_domain, c_loop_domain), - "Sharding a previously transformed reshape is not supported."); + NVF_ERROR( + p_transforms.front()->isA(), + "Expected a split transform producing the did axis."); + NVF_ERROR( + TensorDomain::sameAs(c_logical_domain, c_loop_domain), + "Sharding a previously transformed reshape is not supported."); num_reshape_shardings++; // Find the producer logical id that is sharded. - // We expect the outermost reshaped id to be sharded and follow the outermost path traversing the transforms + // We expect the outermost reshaped id to be sharded and follow the + // outermost path traversing the transforms auto* p_did_split = p_did->definition()->as(); IterDomain* p_logical_did = p_did_split->in(); - // Find the mapping of the corresponding producer logical id in consumer root. + // Find the mapping of the corresponding producer logical id in consumer + // root. IterDomain* c_root_did = p2c.at(p_logical_did); - - // Get the reshape transforms corresponding to this root id. + + // Get the reshape transforms corresponding to this root id. // We use the c_root_did to only find the reshape IDs related to this did. auto reshape_transforms = DependencyCheck::getAllExprsBetween( - {c_root_did}, {consumer->getLogicalDomain().begin(), consumer->getLogicalDomain().end()}); + {c_root_did}, + {consumer->getLogicalDomain().begin(), + consumer->getLogicalDomain().end()}); // Obtain the logical axis sharded in the consumer. 
IterDomain* c_logical_did = c_root_did; - for (auto transform: reshape_transforms) { + for (auto transform : reshape_transforms) { if (transform->isA()) { c_logical_did = transform->as()->outer(); } if (transform->isA()) { - NVF_ERROR(c_logical_did == transform->as()->outer(), "Expected the sharding to be on the outer reshaped id."); + NVF_ERROR( + c_logical_did == transform->as()->outer(), + "Expected the sharding to be on the outer reshaped id."); c_logical_did = transform->as()->out(); } } int64_t sharded_axis = std::distance( - c_loop_domain.begin(), - std::find(c_loop_domain.begin(), - c_loop_domain.end(), - c_logical_did)); - - + c_loop_domain.begin(), + std::find(c_loop_domain.begin(), c_loop_domain.end(), c_logical_did)); + splitLike(consumer, sharded_axis, p_did_split); consumer->axis(sharded_axis)->parallelize(p_did->getParallelType()); - // Move this did_pos to the end in producer to avoid using TransformPropagator on it. + // Move this did_pos to the end in producer to avoid using + // TransformPropagator on it. producer->reorder({{idx, -1}}); } return num_reshape_shardings; } +void reorderAllAsLogicalMap(std::vector tvs) { + for (auto tv : tvs) { + tv->reorder(scheduler_utils::domainReorderAsLogicalMap(tv)); + } +} + } // namespace void PropagateShardingsPass::runPass(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); - + for (Expr* expr : exprs) { // Note: Tvs without a mesh are assumed to have no manual sharding // annotation and are sharded like the first producer Tv. @@ -256,19 +290,20 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { const auto& reference_inputs = getOrderedReferenceInputs(expr); std::unordered_set output_parallel_types; - + // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { // Skip if the input has no device dimensions or is nullptr. 
if (ref_input == nullptr || num_device_dims(ref_input) == 0) { continue; } - + // This restricts the transform propagation to the DID axis. int64_t num_device_dims = reorderDIDToFront(ref_input); - for (auto idx: c10::irange(num_device_dims)) { - if (output_parallel_types.count(ref_input->axis(idx)->getParallelType())) { + for (auto idx : c10::irange(num_device_dims)) { + if (output_parallel_types.count( + ref_input->axis(idx)->getParallelType())) { // Do not propagate parallel types already seen on the output. ref_input->reorder({{idx, -1}}); num_device_dims--; @@ -279,7 +314,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { int64_t num_reshape_shardings = handleViewOp(view_op, num_device_dims); num_device_dims = num_device_dims - num_reshape_shardings; } - + // Propagate the DID loop split to the outputs without mesh. TransformPropagator propagator(ref_input, num_device_dims); PropagateShardingsSelector selector( @@ -288,10 +323,17 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { /*allow_p2c=*/true); MaxLogicalDomainInfoSpanningTree(ref_input, &selector) .traverse(&propagator); - + // Apply parallelization on the outputs without mesh. shardAllLike(ref_input, outputs_without_mesh); - for (auto idx: c10::irange(num_device_dims)) { + // Reorder the loop as logical domain since the transform propagator may + // have reordered the iterdomains in loop domain. For example: Consider + // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After + // transformation, the loop domain of linear output is [DIDx(d), n/d, b, + // m, r{k}] Since we later set the allocation domain to be loop domain, we + // reorder the loop domain as logical domain. 
+ reorderAllAsLogicalMap(outputs_without_mesh); + for (auto idx : c10::irange(num_device_dims)) { output_parallel_types.insert(ref_input->axis(idx)->getParallelType()); } } @@ -313,9 +355,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { continue; } - // All outputs of an expression are uniformly sharded so we pick the first one. - // TODO: Do we need to worry about the case where the outputs are not uniformly sharded? - // The relevant exprs are Welford and SDPA. + // All outputs of an expression are uniformly sharded so we pick the first + // one. + // TODO: Do we need to worry about the case where the outputs are not + // uniformly sharded? The relevant exprs are Welford and SDPA. TensorView* ref_output = *i_output; int64_t did_pos = reorderDIDToFront(ref_output); @@ -325,14 +368,20 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { inputs.begin(), inputs.end(), std::back_inserter(unsharded_inputs), - [](TensorView* tv) { return !tv->hasDeviceMesh() || num_device_dims(tv) == 0; }); + [](TensorView* tv) { + return !tv->hasDeviceMesh() || num_device_dims(tv) == 0; + }); + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. 
TransformPropagator propagator(ref_output, did_pos); PropagateShardingsSelector selector( {unsharded_inputs.begin(), unsharded_inputs.end()}, /*allow_c2p=*/true, /*allow_p2c=*/false); - MaxLogicalDomainInfoSpanningTree(ref_output, &selector).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(ref_output, &selector) + .traverse(&propagator); shardAllLike(ref_output, unsharded_inputs, /*parallelize_inputs=*/true); } diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index fa9cea0f017..2eb97438190 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -20,21 +20,26 @@ namespace nvfuser { +constexpr int64_t b = 2, s = 3, h = 128, a = 8; +constexpr double dropout_p = 0.0; +constexpr bool is_causal = false; + using MultiDevicePresegPassesTest = MultiDeviceTest; TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { // This is similar to the residual add after MHA dropout in the transformer. - // The output of linear following MHA is all-gathered and sharded on the sequence dim. - // This sharding can be propagated to the linear output through backpropagating the shardings - // from residual add. This information is not present during forward propagation. + // The output of linear following MHA is all-gathered and sharded on the + // sequence dim. This sharding can be propagated to the linear output through + // backpropagating the shardings from residual add. This information is not + // present during forward propagation. 
auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const int d = communicator_->size(); const int64_t b = 2, s = 3, h = 8; - TensorView* tv0 = makeContigConcreteTensor({b, d*s, h}); - TensorView* tv1 = makeContigConcreteTensor({b, d*s, h}); + TensorView* tv0 = makeContigConcreteTensor({b, d * s, h}); + TensorView* tv1 = makeContigConcreteTensor({b, d * s, h}); TensorView* tv2 = add(tv0, tv1); auto mesh = DeviceMesh::createForNumDevices(d); @@ -47,20 +52,23 @@ TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { fusion->addInput(tv1); fusion->addOutput(tv2); - preseg_passes::OptimizationPass::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); // Set the allocation domain explicitly until the preseg pass is fixed. for (auto* tv : {tv0, tv1, tv2}) { tv->setAllocationDomain(tv->getLoopDomain(), true); } NVF_CHECK(getShardedLogicalAxis(tv0, ParallelType::DIDx) == 1); - at::Tensor inp0 = at::randn({b, d*s, h}, tensor_options); - at::Tensor inp1 = at::randn({b, d*s, h}, tensor_options); + at::Tensor inp0 = at::randn({b, d * s, h}, tensor_options); + at::Tensor inp1 = at::randn({b, d * s, h}, tensor_options); at::Tensor sharded_inp0 = shardTensor(inp0, 1, mesh); at::Tensor sharded_inp1 = shardTensor(inp1, 1, mesh); - + FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp0, sharded_inp1})[0].as(); + at::Tensor nvf_out = + executor_cache.runFusionWithInputs({sharded_inp0, sharded_inp1})[0] + .as(); testValidate( executor_cache.fusion(), {nvf_out}, @@ -77,8 +85,8 @@ TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { const int d = communicator_->size(); const int64_t b = 2, s = 3, h = 8, e = 4; - TensorView* tv0 = makeContigConcreteTensor({d*b, s, h*e}); - TensorView* tv1 = reshape(tv0, {d*b, s, h*e}, {d*b*s*h, e}); + TensorView* tv0 = makeContigConcreteTensor({d * b, s, h * e}); + TensorView* tv1 = 
reshape(tv0, {d * b, s, h * e}, {d * b * s * h, e}); fusion->addInput(tv0); fusion->addOutput(tv1); @@ -87,52 +95,165 @@ TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { tv0->split(0, d, /*inner_split=*/false); tv0->axis(0)->parallelize(ParallelType::DIDx); - preseg_passes::OptimizationPass::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); for (auto* tv : {tv0, tv1}) { tv->setAllocationDomain(tv->getLoopDomain(), true); } NVF_CHECK(getShardedLogicalAxis(tv1, ParallelType::DIDx) == 0); - at::Tensor inp = at::randn({d*b, s, h*e}, tensor_options); + at::Tensor inp = at::randn({d * b, s, h * e}, tensor_options); at::Tensor sharded_inp = shardTensor(inp, 0, mesh); FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0].as(); + at::Tensor nvf_out = + executor_cache.runFusionWithInputs({sharded_inp})[0].as(); testValidate( executor_cache.fusion(), {nvf_out}, {sharded_inp}, - {sharded_inp.view({b*s*h, e})}, + {sharded_inp.view({b * s * h, e})}, __LINE__, __FILE__); } -TEST_F(MultiDevicePresegPassesTest, TransformerFwd) { +TEST_F(MultiDevicePresegPassesTest, SliceReshapePermute) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + const int64_t b = 2, s = 3, h = 128, a = 8; + + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* tv0 = makeConcreteTensor({b, s, 3 * d * h}); + TensorView* tv1 = reshape(tv0, {b, s, 3 * d * h}, {b, s, d * a, 3 * h / a}); + TensorView* tv2 = slice(tv1, {0, 0, 0, 0}, {b, s, d * a, h / a}); + TensorView* tv3 = permute(tv2, {0, 2, 1, 3}); + + fusion->addInput(tv0); + fusion->addOutput(tv3); + + tv0->setDeviceMesh(mesh); + tv0->split(-1, d, /*inner_split=*/false); + tv0->axis(-2)->parallelize(ParallelType::DIDx); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + + 
for (auto* tv : fusion->allTvs()) { + reorderDIDToFront(tv); + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp = at::randn({b, s, 3 * d * h}, tensor_options); + at::Tensor sharded_inp = shardTensor(inp, -1, mesh); + at::Tensor nvf_out = + executor_cache.runFusionWithInputs({sharded_inp})[0].as(); + + at::Tensor reference_out = sharded_inp.view({b, s, a, 3 * h / a}) + .index( + {at::indexing::Slice(0), + at::indexing::Slice(0), + at::indexing::Slice(0), + at::indexing::Slice(0, h / a)}) + .transpose(1, 2); + + testValidate( + executor_cache.fusion(), + {nvf_out}, + {sharded_inp}, + {reference_out}, + __LINE__, + __FILE__); +} + +// TODO: Enable this test once the insert_reshardings preseg pass is fixed. +TEST_F(MultiDevicePresegPassesTest, DISABLED_MHALinear) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + const int64_t b = 2, s = 3, h = 128; //,a=8; + + TensorView* inp = makeConcreteTensor({b, d * s, h}, DataType::Half); + TensorView* weight = makeConcreteTensor({3 * d * h, h}, DataType::Half); + TensorView* out = linear(inp, weight); + + fusion->addInput(inp); + fusion->addInput(weight); + fusion->addOutput(out); + + inp->setDeviceMesh(mesh); + weight->setDeviceMesh(mesh); + inp->split(1, d, /*inner_split=*/false); + inp->axis(1)->parallelize(ParallelType::DIDx); + weight->split(0, d, /*inner_split=*/false); + weight->axis(0)->parallelize(ParallelType::DIDx); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + for (auto* tv : fusion->allTvs()) { + reorderDIDToFront(tv); + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + NVF_CHECK(getShardedLogicalAxis(out, ParallelType::DIDx) == 2); + at::Tensor inp_tensor = + at::randn({b, d * s, h}, tensor_options.dtype(at::kHalf)); + at::Tensor sharded_inp = 
shardTensor(inp_tensor, 1, mesh); + + at::Tensor weight_tensor = + at::randn({3 * d * h, h}, tensor_options.dtype(at::kHalf)); + at::Tensor sharded_weight = shardTensor(weight_tensor, 0, mesh); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor nvf_out = + executor_cache.runFusionWithInputs({sharded_inp, sharded_weight})[0] + .as(); +} + +namespace { +at::Tensor reference_mha(at::Tensor inp, at::Tensor weight) { + at::Tensor linear0_out = at::linear(inp, weight); + auto qkv = + linear0_out.view({b, s, a, 3 * h / a}).transpose(1, 2).split(h / a, -1); + double scale = 1.0 / std::sqrt(h / a); + auto sdpa_out = at::_scaled_dot_product_flash_attention( + qkv[0], + qkv[1], + qkv[2], + /*dropout_p=*/dropout_p, + /*is_causal=*/is_causal, + /*return_debug_mask=*/false, + scale); + auto attn = std::get<0>(sdpa_out); + return attn.transpose(1, 2).reshape({b, s, h}); +} +} // namespace + +TEST_F(MultiDevicePresegPassesTest, MHAFwd) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 2, s = 3, a = 8, h=128; - // const double kDropoutProb = 0.1; + const int64_t b = 2, s = 3, h = 128, a = 8; auto mesh = DeviceMesh::createForNumDevices(d); - TensorView* inp = makeConcreteTensor({b, d*s, h}, DataType::Half); - TensorView* mha_linear0_weight = makeConcreteTensor({3*d*h, h}, DataType::Half); - TensorView* mha_linear1_weight = makeConcreteTensor({h, d*h}, DataType::Half); + TensorView* inp = makeConcreteTensor({b, s, h}, DataType::Half); + TensorView* mha_w0 = makeConcreteTensor({3 * d * h, h}, DataType::Half); - // Layernorm - TensorView* ln_in = maybeCastOp(DataType::Float, inp); - TensorView* ln_out = layer_norm(ln_in, {h}, /*weight=*/nullptr, /*bias=*/nullptr, /*eps=*/IrBuilder::create(1e-5)).output; - TensorView* mha_in = maybeCastOp(inp->dtype(), ln_out); // MHA Linear0 - TensorView* mha_linear0_out = linear(mha_in, mha_linear0_weight); + TensorView* mha_linear0_out = linear(inp, mha_w0); 
// reshape -> slice -> permute - TensorView* qkv = reshape(mha_linear0_out, {b, s, 3*d*h}, {b, s, d*a, 3*h/a}); - TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d*a, h/a}); - TensorView* k = slice(qkv, {0, 0, 0, h/a}, {b, s, d*a, 2*h/a}); - TensorView* v = slice(qkv, {0, 0, 0, 2*h/a}, {b, s, d*a, 3*h/a}); - + TensorView* qkv = + reshape(mha_linear0_out, {b, s, 3 * d * h}, {b, s, d * a, 3 * h / a}); + TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d * a, h / a}); + TensorView* k = slice(qkv, {0, 0, 0, h / a}, {b, s, d * a, 2 * h / a}); + TensorView* v = slice(qkv, {0, 0, 0, 2 * h / a}, {b, s, d * a, 3 * h / a}); + TensorView* q_permuted = permute(q, {0, 2, 1, 3}); TensorView* k_permuted = permute(k, {0, 2, 1, 3}); TensorView* v_permuted = permute(v, {0, 2, 1, 3}); @@ -141,87 +262,54 @@ TEST_F(MultiDevicePresegPassesTest, TransformerFwd) { q_permuted, k_permuted, v_permuted, - /*dropout_p=*/IrBuilder::create(0.0), - /*is_causal=*/IrBuilder::create(false), + /*dropout_p=*/IrBuilder::create(dropout_p), + /*is_causal=*/IrBuilder::create(is_causal), /*scale=*/nullptr); TensorView* attn = sdpa_out.output; TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - TensorView* attn_reshaped = reshape(attn_permute, {b, s, d*a, h/a}, {b, s, d*h}); - - // MHA Linear1: The reduction dimension is sharded and requires communication. - TensorView* mha_linear1_out = linear(attn_reshaped, mha_linear1_weight); - // Val* prob = IrBuilder::create(1.0 - kDropoutProb); - // Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - // TensorView* dropout_out = dropout(mha_linear1_out, prob, scale).output; - // TensorView* residual_add_out = add(dropout_out, inp); + TensorView* attn_reshaped = + reshape(attn_permute, {b, s, d * a, h / a}, {b, s, d * h}); fusion->addInput(inp); - fusion->addInput(mha_linear0_weight); - fusion->addInput(mha_linear1_weight); - - fusion->addOutput(ln_out); - - // Rfactor mha_linear1_out for communication. 
- mha_linear1_out->split(-1, d, /*inner_split=*/false); - TensorView* local_mha_linear1_out = mha_linear1_out->rFactor({-1}); + fusion->addInput(mha_w0); + fusion->addOutput(attn_reshaped); // Shard input tensors - for (auto* tv : {inp, mha_linear0_weight, mha_linear1_weight, mha_linear1_out, local_mha_linear1_out}) { + for (auto* tv : {inp, mha_w0}) { tv->setDeviceMesh(mesh); } - inp->split(1, d, /*inner_split=*/false); - inp->axis(1)->parallelize(ParallelType::DIDx); - - mha_linear0_weight->split(0, d, /*inner_split=*/false); - mha_linear0_weight->axis(0)->parallelize(ParallelType::DIDx); - - mha_linear1_weight->split(1, d, /*inner_split=*/false); - mha_linear1_weight->axis(1)->parallelize(ParallelType::DIDx); - // Parallelize MHA linear out: This will be done in insert_reshardings preseg pass. - local_mha_linear1_out->axis(-2)->parallelize(ParallelType::DIDx); + mha_w0->split(0, d, /*inner_split=*/false); + mha_w0->axis(0)->parallelize(ParallelType::DIDx); - preseg_passes::OptimizationPass::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); for (auto tv : fusion->allTvs()) { + reorderDIDToFront(tv); tv->setAllocationDomain(tv->getLoopDomain(), true); } FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp_tensor = at::randn({b, d*s, h}, tensor_options.dtype(at::kHalf)); - at::Tensor mha_linear0_weight_tensor = at::randn({3*d*h, h}, tensor_options.dtype(at::kHalf)); - at::Tensor mha_linear1_weight_tensor = at::randn({h, d*h}, tensor_options.dtype(at::kHalf)); + at::Tensor inp_tensor = at::randn({b, s, h}, tensor_options.dtype(at::kHalf)); - at::Tensor sharded_inp = shardTensor(inp_tensor, 1, mesh); - at::Tensor sharded_mha_linear0_weight = shardTensor(mha_linear0_weight_tensor, 0, mesh); - at::Tensor sharded_mha_linear1_weight = shardTensor(mha_linear1_weight_tensor, 1, mesh); + at::Tensor mha_w0_tensor = + at::randn({3 * d * h, h}, tensor_options.dtype(at::kHalf)); + 
at::Tensor sharded_mha_w0 = shardTensor(mha_w0_tensor, 0, mesh); - at::Tensor nvf_out = - executor_cache - .runFusionWithInputs({sharded_inp, sharded_mha_linear0_weight, sharded_mha_linear1_weight})[0] - .as(); - debug() << "nvf_out: " << nvf_out.sizes() << std::endl; - - // double scale = 1.0 / std::sqrt(e); - // auto reference_out = at::_scaled_dot_product_flash_attention( - // hq_tensor.view(out_shape).transpose(1, 2), - // hk_tensor.view(out_shape).transpose(1, 2), - // hv_tensor.view(out_shape).transpose(1, 2), - // /*dropout_p=*/0.0, - // /*is_causal=*/false, - // /*return_debug_mask=*/false, - // /*scale=*/scale); - // at::Tensor ref_attn = shardTensor( - // std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); - - // testValidate( - // executor_cache.fusion(), - // {nvf_out}, - // {sharded_hq, sharded_hk, sharded_hv}, - // {ref_attn}, - // __LINE__, - // __FILE__); -} + KernelArgumentHolder args = {inp_tensor, sharded_mha_w0}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + + at::Tensor ref_out = reference_mha(inp_tensor, sharded_mha_w0); + testValidate( + executor_cache.fusion(), + {nvf_out}, + {inp_tensor, sharded_mha_w0}, + {ref_out}, + __LINE__, + __FILE__); +} -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser From 3b673abe22b201f1845df93620f4e6484daa8375 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 15:25:59 -0700 Subject: [PATCH 20/70] fix rebase --- csrc/multidevice/utils.cpp | 5 +++- csrc/scheduler/utils.cpp | 1 + tests/cpp/test_multidevice_sharding.cpp | 33 ++++++++++--------------- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index a2cfd3403d9..c9732c2bf16 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -571,7 +571,10 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, 
std::vector tvs, bool parallelize_inputs) { +void shardAllLike( + TensorView* ref, + std::vector tvs, + bool parallelize_inputs) { for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 319694f510d..fc0a00d6702 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2337,6 +2337,7 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { tv->reorder(old2new); //! Propagate current transformations on from_tv to all graphs transformPropagateToAllFrom(tv, (int64_t)old2new.size()); + // Propgating the transforms will not replay the DIDx parallelization, so we // need to do it manually here. parallelizeAllLike( diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 37b83520078..04a8524d204 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -12,11 +12,9 @@ #include #include #include -#include #include #include #include -#include "multidevice/utils.h" namespace nvfuser { @@ -781,20 +779,17 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { tv0->axis(-3)->parallelize(ParallelType::DIDx); // in: loop domain: {b, s, DIDx{d}, h, e} - // // Propagate DID loop split to output - // TransformPropagator propagator_p2c(tv0); - // MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator_p2c); - // // out: loop domain: {b, s, d, h, e} after transform propagation + // Propagate DID loop split to output + TransformPropagator propagator_p2c(tv0); + MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator_p2c); + // out: loop domain: {b, s, d, h, e} after transform propagation - // // Parallelize output - // tv1->setDeviceMesh(mesh); - // scheduler_utils::parallelizeAllLike( - // tv0, - // /*pos=*/-1, - // /*selected_tv=*/{tv1}); - // // out: loop domain: {b, s, DIDx{d}, h, e} after parallelization - - preseg_passes::OptimizationPass::runPass(fusion.get()); + 
// Parallelize output + scheduler_utils::parallelizeAllLike( + tv0, + /*pos=*/-1, + /*selected_tv=*/{tv1}); + // out: loop domain: {b, s, DIDx{d}, h, e} after parallelization tv0->setAllocationDomain(tv0->getLoopDomain(), true); tv1->setAllocationDomain(tv1->getLoopDomain(), true); @@ -1158,8 +1153,7 @@ TEST_F(MultiDeviceTest, TransformerFwd) { tv->axis(-2)->parallelize(ParallelType::DIDx); reorderDIDToFront(tv); } - - preseg_passes::OptimizationPass::runPass(fusion.get()); + propagateShardings(fusion.get(), d); for (auto tv : fusion->allTvs()) { tv->setAllocationDomain(tv->getLoopDomain(), true); @@ -1174,7 +1168,7 @@ TEST_F(MultiDeviceTest, TransformerFwd) { at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - at::Tensor nvf_out = + auto nvf_out = executor_cache .runFusionWithInputs({sharded_hq, sharded_hk, sharded_hv})[0] .as(); @@ -1187,7 +1181,7 @@ TEST_F(MultiDeviceTest, TransformerFwd) { /*dropout_p=*/0.0, /*is_causal=*/false, /*return_debug_mask=*/false, - /*scale=*/scale); + scale); at::Tensor ref_attn = shardTensor( std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); @@ -1199,5 +1193,4 @@ TEST_F(MultiDeviceTest, TransformerFwd) { __LINE__, __FILE__); } - } // namespace nvfuser From 748949a96b119176154d8a14282ca022b411ba04 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 15:27:34 -0700 Subject: [PATCH 21/70] fix rebase --- tests/python/test_multidevice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py index 8a7f7dbcef7..237aa9f6aa6 100644 --- a/tests/python/test_multidevice.py +++ b/tests/python/test_multidevice.py @@ -1112,7 +1112,7 @@ def test_transformer_forward(multidevice_test, benchmark): # Benchmark and profile. The profile can be collected and displayed using # `nsys`. See instructions in test_transformer_engine.py. 
- # benchmark.pedantic(benchmark_fn, rounds=5) + benchmark.pedantic(benchmark_fn, rounds=5) # All tensors are replicated to all devices at this moment; future PRs will try @@ -1691,4 +1691,4 @@ def test_transformer_backward(multidevice_test, benchmark): _assert_shape_dtype(layernorm0_weight_grad, [e], torch.bfloat16) _assert_shape_dtype(inp_grad, [b, s, e], torch.bfloat16) - # benchmark.pedantic(benchmark_fn, rounds=5) + benchmark.pedantic(benchmark_fn, rounds=5) From 39e03ee1396f1bc3c109d178470121146e10ce7e Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 15:28:37 -0700 Subject: [PATCH 22/70] clean --- csrc/multidevice/utils.cpp | 2 +- csrc/preseg_passes/propagate_shardings.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index c9732c2bf16..982a93a87b3 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -578,7 +578,7 @@ void shardAllLike( for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } - // TODO: If the tv already has a particular device parallel type, skip that. 
+ if (!tvs.empty()) { scheduler_utils::parallelizeAllLike( ref, diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 985184a77f4..6b54f3ef2d1 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -8,7 +8,6 @@ #include #include -#include "type.h" #include #include From ca80c001fd0012db0237e97f94e92e7cefa0b46e Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 15:43:22 -0700 Subject: [PATCH 23/70] lintrunner --- csrc/multidevice/utils.h | 5 ++++- csrc/scheduler/utils.cpp | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 823eb29093e..b2ed0dbbc85 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -62,7 +62,10 @@ bool haveDifferentShardings( bool isInnerResharding(Expr* expr); // Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs, bool parallelize_inputs=false); +void shardAllLike( + TensorView* ref, + std::vector tvs, + bool parallelize_inputs = false); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index fc0a00d6702..b4fba60331d 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2337,7 +2337,7 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { tv->reorder(old2new); //! Propagate current transformations on from_tv to all graphs transformPropagateToAllFrom(tv, (int64_t)old2new.size()); - + // Propgating the transforms will not replay the DIDx parallelization, so we // need to do it manually here. 
parallelizeAllLike( From 19bca7e73a03e9b76c3134361c1a90f46c96cdb8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 17:12:44 -0700 Subject: [PATCH 24/70] reorder did to front --- csrc/preseg_passes/propagate_shardings.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 6b54f3ef2d1..74a3f86ddb2 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -293,7 +293,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { // Skip if the input has no device dimensions or is nullptr. - if (ref_input == nullptr || num_device_dims(ref_input) == 0) { + if (ref_input == nullptr) { continue; } @@ -323,8 +323,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { MaxLogicalDomainInfoSpanningTree(ref_input, &selector) .traverse(&propagator); - // Apply parallelization on the outputs without mesh. - shardAllLike(ref_input, outputs_without_mesh); // Reorder the loop as logical domain since the transform propagator may // have reordered the iterdomains in loop domain. For example: Consider // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After @@ -332,9 +330,14 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // m, r{k}] Since we later set the allocation domain to be loop domain, we // reorder the loop domain as logical domain. reorderAllAsLogicalMap(outputs_without_mesh); + + // Apply parallelization on the outputs without mesh. 
+ shardAllLike(ref_input, outputs_without_mesh); + for (auto idx : c10::irange(num_device_dims)) { output_parallel_types.insert(ref_input->axis(idx)->getParallelType()); } + } } @@ -384,6 +387,11 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { shardAllLike(ref_output, unsharded_inputs, /*parallelize_inputs=*/true); } + // Reorder DID to front for all TensorViews. + // This will likely be subsumed/replace by other presegmentation passes once they are fixed. + for (auto tv : fusion->allTvs()) { + reorderDIDToFront(tv); + } validateMeshes(fusion); } From 5eabe968d19effe3cf543b956e01c2fe1882921a Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 18 Mar 2025 18:04:22 -0700 Subject: [PATCH 25/70] check if ref input has device mesh --- csrc/preseg_passes/propagate_shardings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 74a3f86ddb2..7daf4b1e472 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -292,8 +292,8 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { - // Skip if the input has no device dimensions or is nullptr. - if (ref_input == nullptr) { + // Skip if the input has no device mesh or is nullptr. 
+ if (ref_input == nullptr || !ref_input->hasDeviceMesh()) { continue; } From 99c35745ede22af3b38537cdeeef3f7b1162cc9d Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 20 Mar 2025 14:28:30 -0700 Subject: [PATCH 26/70] update preseg pass --- csrc/preseg_passes/propagate_shardings.cpp | 37 +++++++++-------- tests/cpp/test_multidevice_preseg_passes.cpp | 43 ++++++++++---------- 2 files changed, 42 insertions(+), 38 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 7daf4b1e472..cf6a8cb8fb3 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -348,6 +348,25 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // example. for (auto i_expr = exprs.rbegin(); i_expr != exprs.rend(); i_expr++) { Expr* expr = *i_expr; + + const auto& inputs = ir_utils::filterByType(expr->inputs()); + std::vector unsharded_inputs; + // Find all inputs that are not fusion inputs and have no device mesh or + // no device dimensions. Fusion inputs should already have device mesh set. + // We should not modify the shardings for the fusion inputs. 
+ std::copy_if( + inputs.begin(), + inputs.end(), + std::back_inserter(unsharded_inputs), + [](TensorView* tv) { + return !tv->isFusionInput() && (!tv->hasDeviceMesh() || + num_device_dims(tv) == 0); + }); + + if (unsharded_inputs.empty()) { + continue; + } + const auto& outputs = ir_utils::filterByType(expr->outputs()); auto i_output = std::find_if( outputs.begin(), @@ -364,16 +383,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { TensorView* ref_output = *i_output; int64_t did_pos = reorderDIDToFront(ref_output); - const auto& inputs = ir_utils::filterByType(expr->inputs()); - std::vector unsharded_inputs; - std::copy_if( - inputs.begin(), - inputs.end(), - std::back_inserter(unsharded_inputs), - [](TensorView* tv) { - return !tv->hasDeviceMesh() || num_device_dims(tv) == 0; - }); - // Note: We do not have to manually shard for reshape here. // TransformPropagator can handle reshapes when going from consumer to // producer. @@ -384,14 +393,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { /*allow_p2c=*/false); MaxLogicalDomainInfoSpanningTree(ref_output, &selector) .traverse(&propagator); - shardAllLike(ref_output, unsharded_inputs, /*parallelize_inputs=*/true); + reorderAllAsLogicalMap(unsharded_inputs); + shardAllLike(ref_output, unsharded_inputs); } - // Reorder DID to front for all TensorViews. - // This will likely be subsumed/replace by other presegmentation passes once they are fixed. 
- for (auto tv : fusion->allTvs()) { - reorderDIDToFront(tv); - } validateMeshes(fusion); } diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 2eb97438190..6a8e493e907 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -24,6 +24,7 @@ constexpr int64_t b = 2, s = 3, h = 128, a = 8; constexpr double dropout_p = 0.0; constexpr bool is_causal = false; +using testing::ElementsAre; using MultiDevicePresegPassesTest = MultiDeviceTest; TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { @@ -36,46 +37,43 @@ TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8; - TensorView* tv0 = makeContigConcreteTensor({b, d * s, h}); - TensorView* tv1 = makeContigConcreteTensor({b, d * s, h}); + TensorView* tv0 = makeContigConcreteTensor({d, 4}); + TensorView* tv1 = uniform( + shape(tv0), + fusion->zeroVal(DataType::Float), + fusion->oneVal(DataType::Float), + DataType::Float); TensorView* tv2 = add(tv0, tv1); auto mesh = DeviceMesh::createForNumDevices(d); tv0->setDeviceMesh(mesh); - tv1->setDeviceMesh(mesh); - tv1->split(1, d, /*inner_split=*/false); - tv1->axis(1)->parallelize(ParallelType::DIDx); + tv0->split(0, d, /*inner_split=*/false); + tv0->axis(0)->parallelize(ParallelType::DIDx); fusion->addInput(tv0); - fusion->addInput(tv1); + fusion->addOutput(tv1); fusion->addOutput(tv2); preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + NVF_CHECK(tv1->hasDeviceMesh()); + NVF_CHECK(getShardedLogicalAxis(tv1, ParallelType::DIDx) == getShardedLogicalAxis(tv0, ParallelType::DIDx), "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); // Set the allocation domain explicitly until the preseg pass is fixed. 
for (auto* tv : {tv0, tv1, tv2}) { + reorderDIDToFront(tv); tv->setAllocationDomain(tv->getLoopDomain(), true); } - NVF_CHECK(getShardedLogicalAxis(tv0, ParallelType::DIDx) == 1); - at::Tensor inp0 = at::randn({b, d * s, h}, tensor_options); - at::Tensor inp1 = at::randn({b, d * s, h}, tensor_options); - at::Tensor sharded_inp0 = shardTensor(inp0, 1, mesh); - at::Tensor sharded_inp1 = shardTensor(inp1, 1, mesh); + at::Tensor inp0 = at::randn({d, 4}, tensor_options); + at::Tensor sharded_inp0 = shardTensor(inp0, 0, mesh); FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor nvf_out = - executor_cache.runFusionWithInputs({sharded_inp0, sharded_inp1})[0] - .as(); - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_inp0, sharded_inp1}, - {sharded_inp0 + sharded_inp1}, - __LINE__, - __FILE__); + auto nvf_out = + executor_cache.runFusionWithInputs({sharded_inp0}); + for (auto& out : nvf_out) { + EXPECT_THAT(out.as().sizes(), ElementsAre(1, 4)); + } } TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { @@ -98,6 +96,7 @@ TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); for (auto* tv : {tv0, tv1}) { + reorderDIDToFront(tv); tv->setAllocationDomain(tv->getLoopDomain(), true); } From 1a20a7c86f245dc553b790994de326457d8f4b3e Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 20 Mar 2025 16:45:04 -0700 Subject: [PATCH 27/70] reorder to original in the interim --- csrc/preseg_passes/propagate_shardings.cpp | 72 +++++++--- tests/cpp/test_multidevice_sharding.cpp | 158 --------------------- 2 files changed, 53 insertions(+), 177 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index cf6a8cb8fb3..8466dba04bd 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -77,7 +77,7 @@ getReshapedIds( 
return std::make_pair(p_reshaped_ids, c_reshaped_ids); } -int64_t num_device_dims(TensorView* tv) { +int64_t numDeviceDims(TensorView* tv) { return std::count_if( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), @@ -107,7 +107,7 @@ std::vector getOrderedReferenceInputs(Expr* expr) { sorted_inputs.begin(), sorted_inputs.end(), [&](TensorView* a, TensorView* b) { - return num_device_dims(a) > num_device_dims(b); + return numDeviceDims(a) > numDeviceDims(b); }); return sorted_inputs; @@ -160,7 +160,7 @@ void splitLike( // Returns the number of DID axis on reshaped ids that were propagated to the // consumer. -int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { +int64_t shardViewOp(ViewOp* view_op, std::unordered_map& new2old) { // This implementation asserts that only one sharding is applied on the // reshaped ids. Inner split is not supported. The cases are: // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in @@ -175,6 +175,10 @@ int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { // are all outer splits. For example: For the reshape [h] -> [a, h/a] where // the h is sharded twice: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] + // A more general approach maybe to "undo" the reshape (reverse transforms + // from root to logical domain), followed by simplification of the consumer + // loop domain to move DID upwards. + TensorView* producer = view_op->in(); TensorView* consumer = view_op->out(); @@ -192,6 +196,7 @@ int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { // Track number of DID axis on reshaped ids that were propagated to the // consumer. These will not be included in TransformPropagator. 
int64_t num_reshape_shardings = 0; + int64_t num_device_dims = new2old.size(); for (auto idx : c10::irange(num_device_dims)) { IterDomain* p_did = p_loop_domain.at(idx); @@ -257,13 +262,17 @@ int64_t handleViewOp(ViewOp* view_op, int64_t num_device_dims) { c_loop_domain.begin(), std::find(c_loop_domain.begin(), c_loop_domain.end(), c_logical_did)); + // TODO: Check for divisibility of the consumer axis by the split factor. splitLike(consumer, sharded_axis, p_did_split); consumer->axis(sharded_axis)->parallelize(p_did->getParallelType()); - // Move this did_pos to the end in producer to avoid using + // Move this did_pos behind the non-propagated DID axis to avoid using // TransformPropagator on it. - producer->reorder({{idx, -1}}); + producer->reorder({{idx, num_device_dims - 1}}); + new2old[idx] = num_device_dims - 1; + num_device_dims--; } + return num_reshape_shardings; } @@ -273,6 +282,26 @@ void reorderAllAsLogicalMap(std::vector tvs) { } } +// Reorder the DID axis to the front only if it does not have a parallel type +// already seen on the output (existing_parallel_types). +// Returns a map from the new position to the old position to undo the reordering later. +std::unordered_map selectiveReorderDIDToFront(TensorView* tv, std::unordered_set existing_parallel_types) { + std::unordered_map old2new; + std::unordered_map new2old; + int64_t current_pos = 0; + + for (auto pos : c10::irange(tv->nDims())) { + if (tv->axis(pos)->isDeviceDim() && !existing_parallel_types.count(tv->axis(pos)->getParallelType())) { + old2new[pos] = current_pos; + new2old[current_pos] = pos; + current_pos++; + } + } + + tv->reorder(old2new); + return new2old; +} + } // namespace void PropagateShardingsPass::runPass(Fusion* fusion) { @@ -297,20 +326,17 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { continue; } - // This restricts the transform propagation to the DID axis. 
- int64_t num_device_dims = reorderDIDToFront(ref_input); + // Reorder the DID axis to the front only if it does not have a parallel type + // already seen on the output. + std::unordered_map new2old = selectiveReorderDIDToFront(ref_input, output_parallel_types); - for (auto idx : c10::irange(num_device_dims)) { - if (output_parallel_types.count( - ref_input->axis(idx)->getParallelType())) { - // Do not propagate parallel types already seen on the output. - ref_input->reorder({{idx, -1}}); - num_device_dims--; - } - } + // This restricts the transform propagation to the DID axis. + int64_t num_device_dims = new2old.size(); if (ViewOp* view_op = dynamic_cast(expr)) { - int64_t num_reshape_shardings = handleViewOp(view_op, num_device_dims); + // Propagation of reshape will return how many DID axis were propagated. + // They are reordered behind non-propagated DID axis and the new2old map is updated. + int64_t num_reshape_shardings = shardViewOp(view_op, new2old); num_device_dims = num_device_dims - num_reshape_shardings; } @@ -337,7 +363,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { for (auto idx : c10::irange(num_device_dims)) { output_parallel_types.insert(ref_input->axis(idx)->getParallelType()); } - + // Moving the DID to the end can break tests using logical domain split for linear/matmul. + // Undo the reordering. This is only needed temporarily while we fix the other preseg passes. + // TODO: Remove this once the other preseg passes are fixed and reorder to front. + ref_input->reorder(new2old); } } @@ -360,7 +389,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { std::back_inserter(unsharded_inputs), [](TensorView* tv) { return !tv->isFusionInput() && (!tv->hasDeviceMesh() || - num_device_dims(tv) == 0); + numDeviceDims(tv) == 0); }); if (unsharded_inputs.empty()) { @@ -381,7 +410,8 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // TODO: Do we need to worry about the case where the outputs are not // uniformly sharded? 
The relevant exprs are Welford and SDPA. TensorView* ref_output = *i_output; - int64_t did_pos = reorderDIDToFront(ref_output); + std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, {}); + int64_t did_pos = new2old.size(); // Note: We do not have to manually shard for reshape here. // TransformPropagator can handle reshapes when going from consumer to @@ -395,6 +425,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { .traverse(&propagator); reorderAllAsLogicalMap(unsharded_inputs); shardAllLike(ref_output, unsharded_inputs); + + // Temporarily undo the reordering. This is only needed temporarily while we fix the other preseg passes. + // TODO: Remove this once the other preseg passes are fixed and reorder to front. + ref_output->reorder(new2old); } validateMeshes(fusion); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 04a8524d204..bd45c57116a 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -1035,162 +1035,4 @@ TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { __FILE__); } -namespace { -// This is a simplified version of what we will eventually do in the -// pre-segmentation pass -void propagateShardings(Fusion* fusion, int64_t num_devices) { - for (Expr* expr : fusion->exprs()) { - if (expr->isA()) { - NVF_THROW("SliceOp is not currently supported"); - } - - if (expr->isA()) { - // TransformPropagator cannot be directly used. - // It raises an error for conflicting transformations from root domain to - // logical domain. Instead, we manually find the reshaped iterdomain and - // outer split DID. This might have to be extended further in the - // presegmentation pass. - // Note: For simplicity, this assumes that the sharding is on reshaped - // IDs. It is possible that the non-reshaped IDs are sharded, in which - // case we can use the TransformPropagator. 
- TensorView* reshaped_tv = expr->as()->out(); - auto transform_exprs = StmtSort::getExprsBetween( - {reshaped_tv->getMaybeRootDomain().begin(), - reshaped_tv->getMaybeRootDomain().end()}, - {reshaped_tv->getLogicalDomain().begin(), - reshaped_tv->getLogicalDomain().end()}); - NVF_CHECK(transform_exprs.size() == 1); - auto transform = transform_exprs[0]; - NVF_CHECK(transform->isA() || transform->isA()); - - // Get the reshaped ID (outer ID for split reshape). - // This is the ID that will be parallelized. - IterDomain* reshaped_id = transform->isA() - ? transform->as()->outer() - : transform->as()->out(); - - auto reshaped_it = std::find( - reshaped_tv->getLoopDomain().begin(), - reshaped_tv->getLoopDomain().end(), - reshaped_id); - int64_t reshaped_axis = - std::distance(reshaped_tv->getLoopDomain().begin(), reshaped_it); - - // Apply sharding to the reshaped tensor - reshaped_tv->split(reshaped_axis, num_devices, false); - reshaped_tv->axis(reshaped_axis)->parallelize(ParallelType::DIDx); - reorderDIDToFront(reshaped_tv); - continue; - } - - // For other ops, propagate sharding from input to outputs - auto input_tv = expr->input(0)->as(); - std::vector output_tvs; - for (auto output : expr->outputs()) { - output_tvs.push_back(output->as()); - } - - TransformPropagator propagator(input_tv); - - // Note: We will finally propagate from each input iteratively. 
- SetSelector selector( - std::unordered_set(output_tvs.begin(), output_tvs.end())); - MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator); - scheduler_utils::parallelizeAllLike( - input_tv, - /*pos=*/-1, - /*selected_tv=*/output_tvs); - } -} - -} // namespace - -TEST_F(MultiDeviceTest, TransformerFwd) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 16; - auto mesh = DeviceMesh::createForNumDevices(d); - - std::vector in_shape = {b, s, d * h * e}; - std::vector out_shape = {b, s, d * h, e}; - - // The transformer block produces hq/hk/hv after slicing the MHA linear - // output. - TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); - - TensorView* q = reshape(hq, in_shape, out_shape); - TensorView* q_permuted = permute(q, {0, 2, 1, 3}); - TensorView* k = reshape(hk, in_shape, out_shape); - TensorView* k_permuted = permute(k, {0, 2, 1, 3}); - TensorView* v = reshape(hv, in_shape, out_shape); - TensorView* v_permuted = permute(v, {0, 2, 1, 3}); - - SdpfaFwdResult sdpa_out = sdpfa_fwd( - q_permuted, - k_permuted, - v_permuted, - /*dropout_p=*/IrBuilder::create(0.0), - /*is_causal=*/IrBuilder::create(false), - /*scale=*/nullptr); - - TensorView* attn = sdpa_out.output; - TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - TensorView* out = reshape(attn_permute, out_shape, in_shape); - - fusion->addInput(hq); - fusion->addInput(hk); - fusion->addInput(hv); - fusion->addOutput(out); - - // Shard input tensors - for (auto* tv : {hq, hk, hv}) { - tv->setDeviceMesh(mesh); - tv->split(-1, d, /*inner_split=*/false); - tv->axis(-2)->parallelize(ParallelType::DIDx); - reorderDIDToFront(tv); - } - propagateShardings(fusion.get(), d); - - for (auto tv : fusion->allTvs()) { - 
tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor hq_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hk_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hv_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - - at::Tensor sharded_hq = shardTensor(hq_tensor, -1, mesh); - at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); - at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - - auto nvf_out = - executor_cache - .runFusionWithInputs({sharded_hq, sharded_hk, sharded_hv})[0] - .as(); - - double scale = 1.0 / std::sqrt(e); - auto reference_out = at::_scaled_dot_product_flash_attention( - hq_tensor.view(out_shape).transpose(1, 2), - hk_tensor.view(out_shape).transpose(1, 2), - hv_tensor.view(out_shape).transpose(1, 2), - /*dropout_p=*/0.0, - /*is_causal=*/false, - /*return_debug_mask=*/false, - scale); - at::Tensor ref_attn = shardTensor( - std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); - - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_hq, sharded_hk, sharded_hv}, - {ref_attn}, - __LINE__, - __FILE__); -} } // namespace nvfuser From de91943d25d3714e959fbebdc6f390fe68d9b979 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 21 Mar 2025 18:04:41 -0700 Subject: [PATCH 28/70] allocation domain util fn --- csrc/preseg_passes/propagate_shardings.cpp | 42 ++++-- csrc/scheduler/utils.cpp | 137 ++++++++++++------- csrc/scheduler/utils.h | 2 + tests/cpp/test_multidevice_preseg_passes.cpp | 34 +++++ tests/cpp/test_multidevice_sharding.cpp | 1 - 5 files changed, 152 insertions(+), 64 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 8466dba04bd..32621b7801a 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -276,9 +276,13 @@ int64_t 
shardViewOp(ViewOp* view_op, std::unordered_map& new2o return num_reshape_shardings; } -void reorderAllAsLogicalMap(std::vector tvs) { +void reorderLoopDomainAsAllocationMap(std::vector tvs) { for (auto tv : tvs) { - tv->reorder(scheduler_utils::domainReorderAsLogicalMap(tv)); + auto reorder_map = scheduler_utils::domainReorderAsAllocationMap(tv); + if (reorder_map.empty()) { + continue; + } + tv->reorder(reorder_map); } } @@ -302,6 +306,13 @@ std::unordered_map selectiveReorderDIDToFront(TensorView* tv, return new2old; } +void propagateAllocationDomain(std::vector tvs) { + // TODO: Propagate/fully allocate ParallelType::Stream based on inputs. + for (auto tv : tvs) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } +} + } // namespace void PropagateShardingsPass::runPass(Fusion* fusion) { @@ -349,25 +360,30 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { MaxLogicalDomainInfoSpanningTree(ref_input, &selector) .traverse(&propagator); - // Reorder the loop as logical domain since the transform propagator may - // have reordered the iterdomains in loop domain. For example: Consider - // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After - // transformation, the loop domain of linear output is [DIDx(d), n/d, b, - // m, r{k}] Since we later set the allocation domain to be loop domain, we - // reorder the loop domain as logical domain. - reorderAllAsLogicalMap(outputs_without_mesh); - // Apply parallelization on the outputs without mesh. shardAllLike(ref_input, outputs_without_mesh); for (auto idx : c10::irange(num_device_dims)) { output_parallel_types.insert(ref_input->axis(idx)->getParallelType()); } - // Moving the DID to the end can break tests using logical domain split for linear/matmul. - // Undo the reordering. This is only needed temporarily while we fix the other preseg passes. - // TODO: Remove this once the other preseg passes are fixed and reorder to front. 
+ // Undo the reordering of the DID axis so it is in the correct order again. ref_input->reorder(new2old); } + + // Reorder the loop as logical domain since the transform propagator may + // have reordered the iterdomains in loop domain. For example: Consider + // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After + // transformation, the loop domain of linear output is [DIDx(d), n/d, b, + // m, r{k}]. + // Due to the restriction that the allocation domain is the same as the loop domain, + // we reorder it as allocation domain in the interim. Ideally, this should follow logical domain + // and DIDx axis at the front. + reorderLoopDomainAsAllocationMap(outputs_without_mesh); + // TODO: Do we reorder to front here or reorder_sharded_axis? + + // Currently sets it as loop domain. In general, it will be different from loop domain + // for the case of ParallelType::Stream when fully allocated. + propagateAllocationDomain(expr->outputs()); } // Back-propagate device meshes. This makes sure all TensorViews have a mesh diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index b4fba60331d..597ef9aa454 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2124,67 +2124,95 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos) { return true; } -std::unordered_map domainReorderAsLogicalMap(TensorView* tv) { +namespace { + +void applySplitTransform(Split* split, std::vector& ids) { + auto find_it = + std::find(ids.begin(), ids.end(), split->in()); + if (find_it == ids.end()) { + // Transformations before rfactor, ignore those. 
+ return; + } + auto pos = std::distance(ids.begin(), find_it); + ids[pos] = split->inner(); + ids.insert(ids.begin() + pos, split->outer()); +} + +void applyMergeTransform(Merge* merge, std::vector& ids) { + auto find_it_0 = + std::find(ids.begin(), ids.end(), merge->outer()); + auto find_it_1 = + std::find(ids.begin(), ids.end(), merge->inner()); + if (find_it_0 == ids.end() && + find_it_1 == ids.end()) { + // Transformations before rfactor, ignore those. + return; + } + NVF_ERROR( + find_it_0 != ids.end() && find_it_1 != ids.end(), + "Error in transformations of ", + tv->toString(), + "\nTransformations before rfactor should not mix with transformations after rfactor."); + auto pos0 = std::distance(ids.begin(), find_it_0); + auto pos1 = std::distance(ids.begin(), find_it_1); + if (pos0 > pos1) { + std::swap(pos0, pos1); + } + // Should be impossible. + NVF_ERROR( + pos0 != pos1, + "Didn't expect merge inputs to be the same iteration domain:\n", + merge->toString()); + + ids.erase(ids.begin() + pos0); + ids[--pos1] = merge->out(); +} + +void applyResizeTransform(Resize* resize, std::vector& ids) { + auto find_it = + std::find(ids.begin(), ids.end(), resize->in()); + if (find_it == ids.end()) { + // Transformations before rfactor, ignore those. 
+ return; + } + *find_it = resize->out(); +} + +std::unordered_map createReorderMap(std::vector& orig_domain, std::vector& new_domain) { + std::unordered_map old2new; + for (auto idx : c10::irange((int64_t)orig_domain.size())) { + auto orig_id = orig_domain.at(idx); + auto find_it = std::find(new_domain.begin(), new_domain.end(), orig_id); + NVF_ERROR( + find_it != new_domain.end(), + "Reordering map creation failed, uninitialized iterdomain,", + " likely something is wrong with the transformations between the logical and loop domain."); + int64_t new_pos = (int64_t)std::distance(new_domain.begin(), find_it); + old2new[idx] = new_pos; + } + return old2new; +} + +// simply update this vector of id's as progressing through the transformation +// expressions. We'll always insert the result of split in the location of the +// input, and insert the merge result in the position of the inner dimension. +void reorderDomain(TensorView* tv, std::vector& ids) { FusionGuard fg(tv->fusion()); auto transform_exprs = StmtSort::getExprsTo( {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - // simply update this vector of id's as progressing through the transformation - // expressions. We'll always insert the result of split in the location of the - // input, and insert the merge result in the position of the inner dimension. - - auto reordered_ids = tv->getLogicalDomain(); for (const auto* expr : transform_exprs) { if (const Split* split = dynamic_cast(expr)) { - auto find_it = - std::find(reordered_ids.begin(), reordered_ids.end(), split->in()); - if (find_it == reordered_ids.end()) { - // Transformations before rfactor, ignore those. 
- continue; - } - auto pos = std::distance(reordered_ids.begin(), find_it); - reordered_ids[pos] = split->inner(); - reordered_ids.insert(reordered_ids.begin() + pos, split->outer()); + applySplitTransform(split, ids); } else if (const Merge* merge = dynamic_cast(expr)) { - auto find_it_0 = - std::find(reordered_ids.begin(), reordered_ids.end(), merge->outer()); - auto find_it_1 = - std::find(reordered_ids.begin(), reordered_ids.end(), merge->inner()); - if (find_it_0 == reordered_ids.end() && - find_it_1 == reordered_ids.end()) { - // Transformations before rfactor, ignore those. - continue; - } - NVF_ERROR( - find_it_0 != reordered_ids.end() && find_it_1 != reordered_ids.end(), - "Error in transformations of ", - tv->toString(), - "\nTransformations before rfactor should not mix with transformations after rfactor."); - auto pos0 = std::distance(reordered_ids.begin(), find_it_0); - auto pos1 = std::distance(reordered_ids.begin(), find_it_1); - if (pos0 > pos1) { - std::swap(pos0, pos1); - } - // Should be impossible. - NVF_ERROR( - pos0 != pos1, - "Didn't expect merge inputs to be the same iteration domain:\n", - merge->toString()); - - reordered_ids.erase(reordered_ids.begin() + pos0); - reordered_ids[--pos1] = merge->out(); + applyMergeTransform(merge, ids); } else if (const Resize* resize = dynamic_cast(expr)) { - auto find_it = - std::find(reordered_ids.begin(), reordered_ids.end(), resize->in()); - if (find_it == reordered_ids.end()) { - // Transformations before rfactor, ignore those. 
- continue; - } - *find_it = resize->out(); + applyResizeTransform(resize, ids); } else { NVF_ERROR(expr != nullptr); NVF_THROW("Unexpected expression: ", expr->toString()); } } +} std::unordered_map old2new; for (auto id_i : arange((int64_t)tv->getLoopDomain().size())) { @@ -2199,7 +2227,16 @@ std::unordered_map domainReorderAsLogicalMap(TensorView* tv) { int64_t old_pos = id_i; old2new[old_pos] = new_pos; } - return old2new; + auto reordered_ids = tv->getAllocationDomain(); + reorderDomain(tv, reordered_ids); + return createReorderMap(tv->getLoopDomain(), reordered_ids); +} + +// Returns a map reordering the loop domain of the tensor view as the logical domain +std::unordered_map domainReorderAsLogicalMap(TensorView* tv) { + auto reordered_ids = tv->getLoopDomain(); + reorderDomain(tv, reordered_ids); + return createReorderMap(tv->getLoopDomain(), reordered_ids); } std::unordered_map maybeReorderAsAllocationMap( diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 2c1b41c8346..c73a4e7f330 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -657,6 +657,8 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos); // This is somewhat similar to orderTiledConcreteIdAsRoot std::unordered_map domainReorderAsLogicalMap(TensorView* tv); +std::unordered_map domainReorderAsAllocationMap(TensorView* tv); + // Generates an old to new map to reorder tv's loop domain as its allocation // order. This only handles the simple case where allocation is a permutation of // loop domain, otherwise, the function returns an empty container. 
diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 6a8e493e907..93af6bbcf27 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -311,4 +311,38 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { __FILE__); } +TEST_F(MultiDevicePresegPassesTest, ReplayAllocation) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const int d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* tv0 = makeConcreteTensor({d*b, s}); + TensorView* tv1 = makeConcreteTensor({d*b, s}); + TensorView* tv2 = add(tv0, tv1); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + + for (auto* tv: {tv0, tv1, tv2}) { + tv->setAllocationDomain({tv->axis(1), tv->axis(0)}, true); + tv->setDeviceMesh(mesh); + tv->split(0, d, /*inner_split=*/false); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + debug() << tv2->domain()->toString(0, false) << std::endl; + auto transforms = DependencyCheck::getAllExprsBetween({tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, {tv2->getAllocationDomain().begin(), tv2->getAllocationDomain().end()}); + debug() << "transforms: " << transforms.size() << std::endl; + tv2->setAllocationDomain(tv2->getLoopDomain(), true); + auto transforms_updated = DependencyCheck::getAllExprsBetween({tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, {tv2->getAllocationDomain().begin(), tv2->getAllocationDomain().end()}); + debug() << "transforms_updated: " << transforms_updated.size() << std::endl; + for (auto* expr: transforms_updated) { + debug() << expr->toString() << std::endl; + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index bd45c57116a..288afbc5fd4 100644 --- 
a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -774,7 +774,6 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { // Loop split and parallelize input tv0->setDeviceMesh(mesh); - tv1->setDeviceMesh(mesh); tv0->split(-2, d, /*inner_split=*/false); tv0->axis(-3)->parallelize(ParallelType::DIDx); // in: loop domain: {b, s, DIDx{d}, h, e} From d56a7424bdaae7c54324f7bc20d6ff1e9c563995 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 21 Mar 2025 20:12:41 -0700 Subject: [PATCH 29/70] rm allocation domain reorder --- csrc/preseg_passes/propagate_shardings.cpp | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 32621b7801a..3e7ac9d9421 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -276,12 +276,9 @@ int64_t shardViewOp(ViewOp* view_op, std::unordered_map& new2o return num_reshape_shardings; } -void reorderLoopDomainAsAllocationMap(std::vector tvs) { +void reorderLoopDomainAsLogicalMap(std::vector tvs) { for (auto tv : tvs) { - auto reorder_map = scheduler_utils::domainReorderAsAllocationMap(tv); - if (reorder_map.empty()) { - continue; - } + auto reorder_map = scheduler_utils::domainReorderAsLogicalMap(tv); tv->reorder(reorder_map); } } @@ -306,13 +303,6 @@ std::unordered_map selectiveReorderDIDToFront(TensorView* tv, return new2old; } -void propagateAllocationDomain(std::vector tvs) { - // TODO: Propagate/fully allocate ParallelType::Stream based on inputs. 
- for (auto tv : tvs) { - tv->setAllocationDomain(tv->getLoopDomain(), true); - } -} - } // namespace void PropagateShardingsPass::runPass(Fusion* fusion) { @@ -378,12 +368,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Due to the restriction that the allocation domain is the same as the loop domain, // we reorder it as allocation domain in the interim. Ideally, this should follow logical domain // and DIDx axis at the front. - reorderLoopDomainAsAllocationMap(outputs_without_mesh); + reorderLoopDomainAsLogicalMap(outputs_without_mesh); // TODO: Do we reorder to front here or reorder_sharded_axis? - // Currently sets it as loop domain. In general, it will be different from loop domain - // for the case of ParallelType::Stream when fully allocated. - propagateAllocationDomain(expr->outputs()); + // TODO: Propagate AllocationDomain. } // Back-propagate device meshes. This makes sure all TensorViews have a mesh From 7be563d9a0c9e510eb8b030f9d51daafaee1af36 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 4 Apr 2025 17:30:47 -0700 Subject: [PATCH 30/70] rebase, reorder as alloc --- csrc/multidevice/utils.cpp | 4 +- csrc/preseg_passes/propagate_shardings.cpp | 55 +++++--- csrc/scheduler/utils.cpp | 87 +++++------- csrc/scheduler/utils.h | 9 +- tests/cpp/test_multidevice_preseg_passes.cpp | 45 ------ tests/cpp/test_multidevice_sharding.cpp | 141 ------------------- 6 files changed, 82 insertions(+), 259 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 982a93a87b3..845dda31011 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -331,6 +331,8 @@ std::pair computeLoopIndex( return id_to_index.at(id); } +} // namespace + std::vector getInputsInTargetDomain( IterDomain* loop_id, const std::vector& target_domain) { @@ -347,8 +349,6 @@ std::vector getInputsInTargetDomain( return inputs_as_iter_domains; } -} // namespace - bool haveDifferentShardings( const TensorView* producer, 
const TensorView* consumer) { diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 3e7ac9d9421..381fad463dd 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -276,9 +276,18 @@ int64_t shardViewOp(ViewOp* view_op, std::unordered_map& new2o return num_reshape_shardings; } -void reorderLoopDomainAsLogicalMap(std::vector tvs) { +void reorderLoopAsAllocation(std::vector tvs) { + // Use maybeAllocationDomain to transform + // Transform using exprs between logical and loop and get the map for (auto tv : tvs) { - auto reorder_map = scheduler_utils::domainReorderAsLogicalMap(tv); + auto alloc_dom = tv->getMaybeAllocationDomain(); + std::vector transform_exprs = DependencyCheck::getAllExprsBetween( + {alloc_dom.begin(), alloc_dom.end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + auto reorder_map = scheduler_utils::createReorderMapUnderTransforms( + /*ids_to_reorder=*/tv->getLoopDomain(), + /*ids_to_transform=*/alloc_dom, + /*transform_exprs=*/transform_exprs); tv->reorder(reorder_map); } } @@ -303,8 +312,23 @@ std::unordered_map selectiveReorderDIDToFront(TensorView* tv, return new2old; } +// Updates the set of parallel types seen on the output. +void updateOutputParallelTypes(TensorView* tv, std::unordered_set& output_parallel_types) { + for (auto id: tv->getLoopDomain()) { + if (id->isDeviceDim()) { + output_parallel_types.insert(id->getParallelType()); + } + } +} + } // namespace + +// This presegmentation pass propagates shardings from fusion inputs to downstream tensorviews. +// 1. Forward propagating DID loop splits and parallelization from inputs to outputs that don't have a mesh using TransformPropagator +// 2. Reshape is handled manually since the DID loop split transforms conflict with the reshape root-to-logical transforms if using TransformPropagator +// 3. 
Back-propagating device meshes to ensure all TensorViews have consistent meshes. This also splits and parallelizes unsharded inputs based on outputs. See `MultiDevicePresegPassesTest.ResidualAdd` for an example. +// 4. Reorders the loop domain as the allocation order. Ideally, loop domain should follow logical domain and allocation domain should follow any stride order specified/inferred. However, we currently require loop domain to be the same as allocation domain. void PropagateShardingsPass::runPass(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); @@ -353,25 +377,20 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Apply parallelization on the outputs without mesh. shardAllLike(ref_input, outputs_without_mesh); - for (auto idx : c10::irange(num_device_dims)) { - output_parallel_types.insert(ref_input->axis(idx)->getParallelType()); - } + updateOutputParallelTypes(ref_input, output_parallel_types); + // Undo the reordering of the DID axis so it is in the correct order again. ref_input->reorder(new2old); } - // Reorder the loop as logical domain since the transform propagator may + // Reorder the loop domain since the transform propagator may // have reordered the iterdomains in loop domain. For example: Consider // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After // transformation, the loop domain of linear output is [DIDx(d), n/d, b, - // m, r{k}]. - // Due to the restriction that the allocation domain is the same as the loop domain, - // we reorder it as allocation domain in the interim. Ideally, this should follow logical domain - // and DIDx axis at the front. - reorderLoopDomainAsLogicalMap(outputs_without_mesh); - // TODO: Do we reorder to front here or reorder_sharded_axis? - - // TODO: Propagate AllocationDomain. + // m, r{k}]. Since, we set allocation to be the same as loop, we reorder it as allocation domain in the interim. + // Ideally, this should follow logical domain and DIDx axis at the front. 
+ // The allocation domain should follow any stride order specified/inferred. + reorderLoopAsAllocation(outputs_without_mesh); } // Back-propagate device meshes. This makes sure all TensorViews have a mesh @@ -410,9 +429,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } // All outputs of an expression are uniformly sharded so we pick the first - // one. - // TODO: Do we need to worry about the case where the outputs are not - // uniformly sharded? The relevant exprs are Welford and SDPA. + // one. Multi-output expressions are Welford and SDPA. TensorView* ref_output = *i_output; std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, {}); int64_t did_pos = new2old.size(); @@ -427,12 +444,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { /*allow_p2c=*/false); MaxLogicalDomainInfoSpanningTree(ref_output, &selector) .traverse(&propagator); - reorderAllAsLogicalMap(unsharded_inputs); shardAllLike(ref_output, unsharded_inputs); - // Temporarily undo the reordering. This is only needed temporarily while we fix the other preseg passes. - // TODO: Remove this once the other preseg passes are fixed and reorder to front. 
ref_output->reorder(new2old); + reorderLoopAsAllocation(unsharded_inputs); } validateMeshes(fusion); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 597ef9aa454..70b5cc60849 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2126,7 +2126,7 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos) { namespace { -void applySplitTransform(Split* split, std::vector& ids) { +void applySplitTransform(const Split* split, std::vector& ids) { auto find_it = std::find(ids.begin(), ids.end(), split->in()); if (find_it == ids.end()) { @@ -2138,7 +2138,7 @@ void applySplitTransform(Split* split, std::vector& ids) { ids.insert(ids.begin() + pos, split->outer()); } -void applyMergeTransform(Merge* merge, std::vector& ids) { +void applyMergeTransform(const Merge* merge, std::vector& ids) { auto find_it_0 = std::find(ids.begin(), ids.end(), merge->outer()); auto find_it_1 = @@ -2151,24 +2151,24 @@ void applyMergeTransform(Merge* merge, std::vector& ids) { NVF_ERROR( find_it_0 != ids.end() && find_it_1 != ids.end(), "Error in transformations of ", - tv->toString(), + ids, "\nTransformations before rfactor should not mix with transformations after rfactor."); auto pos0 = std::distance(ids.begin(), find_it_0); auto pos1 = std::distance(ids.begin(), find_it_1); - if (pos0 > pos1) { - std::swap(pos0, pos1); - } - // Should be impossible. - NVF_ERROR( - pos0 != pos1, - "Didn't expect merge inputs to be the same iteration domain:\n", - merge->toString()); + if (pos0 > pos1) { + std::swap(pos0, pos1); + } + // Should be impossible. 
+ NVF_ERROR( + pos0 != pos1, + "Didn't expect merge inputs to be the same iteration domain:\n", + merge->toString()); ids.erase(ids.begin() + pos0); ids[--pos1] = merge->out(); } -void applyResizeTransform(Resize* resize, std::vector& ids) { +void applyResizeTransform(const Resize* resize, std::vector& ids) { auto find_it = std::find(ids.begin(), ids.end(), resize->in()); if (find_it == ids.end()) { @@ -2178,65 +2178,52 @@ void applyResizeTransform(Resize* resize, std::vector& ids) { *find_it = resize->out(); } -std::unordered_map createReorderMap(std::vector& orig_domain, std::vector& new_domain) { - std::unordered_map old2new; - for (auto idx : c10::irange((int64_t)orig_domain.size())) { - auto orig_id = orig_domain.at(idx); - auto find_it = std::find(new_domain.begin(), new_domain.end(), orig_id); - NVF_ERROR( - find_it != new_domain.end(), - "Reordering map creation failed, uninitialized iterdomain,", - " likely something is wrong with the transformations between the logical and loop domain."); - int64_t new_pos = (int64_t)std::distance(new_domain.begin(), find_it); - old2new[idx] = new_pos; - } - return old2new; -} +} // namespace -// simply update this vector of id's as progressing through the transformation +// Update the vector of ids_to_transform as progressing through the transformation // expressions. We'll always insert the result of split in the location of the // input, and insert the merge result in the position of the inner dimension. -void reorderDomain(TensorView* tv, std::vector& ids) { - FusionGuard fg(tv->fusion()); - auto transform_exprs = StmtSort::getExprsTo( - {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); +// After transformations, ids_to_reorder should be a permutation of ids_to_transform. +// Returns a reorder map from ids_to_reorder to ids_to_transform. 
+std::unordered_map createReorderMapUnderTransforms( + const std::vector& ids_to_reorder, + std::vector& ids_to_transform, + const std::vector& transform_exprs) { for (const auto* expr : transform_exprs) { if (const Split* split = dynamic_cast(expr)) { - applySplitTransform(split, ids); + applySplitTransform(split, ids_to_transform); } else if (const Merge* merge = dynamic_cast(expr)) { - applyMergeTransform(merge, ids); + applyMergeTransform(merge, ids_to_transform); } else if (const Resize* resize = dynamic_cast(expr)) { - applyResizeTransform(resize, ids); + applyResizeTransform(resize, ids_to_transform); } else { NVF_ERROR(expr != nullptr); NVF_THROW("Unexpected expression: ", expr->toString()); } } -} std::unordered_map old2new; - for (auto id_i : arange((int64_t)tv->getLoopDomain().size())) { - auto loop_id = tv->axis(id_i); - auto find_it = - std::find(reordered_ids.begin(), reordered_ids.end(), loop_id); + for (auto idx : c10::irange((int64_t)ids_to_reorder.size())) { + auto orig_id = ids_to_reorder.at(idx); + auto find_it = std::find(ids_to_transform.begin(), ids_to_transform.end(), orig_id); NVF_ERROR( - find_it != reordered_ids.end(), - "Reordering map creation failed, uninitialized iterdomain,", + find_it != ids_to_transform.end(), + "Reordering map creation failed, uninitialized iterdomain, ", + orig_id->toString(), " likely something is wrong with the transformations between the logical and loop domain."); - int64_t new_pos = (int64_t)std::distance(reordered_ids.begin(), find_it); - int64_t old_pos = id_i; - old2new[old_pos] = new_pos; + int64_t new_pos = (int64_t)std::distance(ids_to_transform.begin(), find_it); + old2new[idx] = new_pos; } - auto reordered_ids = tv->getAllocationDomain(); - reorderDomain(tv, reordered_ids); - return createReorderMap(tv->getLoopDomain(), reordered_ids); + return old2new; } // Returns a map reordering the loop domain of the tensor view as the logical domain std::unordered_map domainReorderAsLogicalMap(TensorView* tv) { 
- auto reordered_ids = tv->getLoopDomain(); - reorderDomain(tv, reordered_ids); - return createReorderMap(tv->getLoopDomain(), reordered_ids); + FusionGuard fg(tv->fusion()); + auto transform_exprs = StmtSort::getExprsTo( + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + auto ids_to_transform = tv->getLogicalDomain(); + return createReorderMapUnderTransforms(tv->getLoopDomain(), ids_to_transform, transform_exprs); } std::unordered_map maybeReorderAsAllocationMap( diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index c73a4e7f330..e882324e034 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -649,6 +649,14 @@ DisjointSets disjointLogicalSets(Fusion* fusion); // [1, 0, 0] pos 1 would return true bool breakIsDisjoint(std::vector group_ids, int64_t pos); +// Transform the ids_to_transform as progressing through the transform_exprs +// and return a reorder map from ids_to_reorder to the transformed ids. +// This is used to reorder the loop domain as the logical or the allocation order. +std::unordered_map createReorderMapUnderTransforms( + const std::vector& ids_to_reorder, + std::vector& ids_to_transform, + const std::vector& transform_exprs); + // Generates an old to new map to reorder tv's domain as the logical order. // Priority is given to inner most dimensions for example: // logical [i0, i1, i2] @@ -657,7 +665,6 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos); // This is somewhat similar to orderTiledConcreteIdAsRoot std::unordered_map domainReorderAsLogicalMap(TensorView* tv); -std::unordered_map domainReorderAsAllocationMap(TensorView* tv); // Generates an old to new map to reorder tv's loop domain as its allocation // order. 
This only handles the simple case where allocation is a permutation of diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 93af6bbcf27..4d2bf4df000 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -167,51 +167,6 @@ TEST_F(MultiDevicePresegPassesTest, SliceReshapePermute) { __FILE__); } -// TODO: Enable this test once the insert_reshardings preseg pass is fixed. -TEST_F(MultiDevicePresegPassesTest, DISABLED_MHALinear) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - auto mesh = DeviceMesh::createForNumDevices(d); - const int64_t b = 2, s = 3, h = 128; //,a=8; - - TensorView* inp = makeConcreteTensor({b, d * s, h}, DataType::Half); - TensorView* weight = makeConcreteTensor({3 * d * h, h}, DataType::Half); - TensorView* out = linear(inp, weight); - - fusion->addInput(inp); - fusion->addInput(weight); - fusion->addOutput(out); - - inp->setDeviceMesh(mesh); - weight->setDeviceMesh(mesh); - inp->split(1, d, /*inner_split=*/false); - inp->axis(1)->parallelize(ParallelType::DIDx); - weight->split(0, d, /*inner_split=*/false); - weight->axis(0)->parallelize(ParallelType::DIDx); - - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - for (auto* tv : fusion->allTvs()) { - reorderDIDToFront(tv); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - NVF_CHECK(getShardedLogicalAxis(out, ParallelType::DIDx) == 2); - at::Tensor inp_tensor = - at::randn({b, d * s, h}, tensor_options.dtype(at::kHalf)); - at::Tensor sharded_inp = shardTensor(inp_tensor, 1, mesh); - - at::Tensor weight_tensor = - at::randn({3 * d * h, h}, tensor_options.dtype(at::kHalf)); - at::Tensor sharded_weight = shardTensor(weight_tensor, 0, mesh); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor nvf_out = - 
executor_cache.runFusionWithInputs({sharded_inp, sharded_weight})[0] - .as(); -} - namespace { at::Tensor reference_mha(at::Tensor inp, at::Tensor weight) { at::Tensor linear0_out = at::linear(inp, weight); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 288afbc5fd4..bdda0192b39 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -852,147 +852,6 @@ TEST_F(MultiDeviceTest, LoopShardedSplitReshapeIds) { __FILE__); } -TEST_F(MultiDeviceTest, LoopShardedSplitReshapeIds) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 4; - - TensorView* tv0 = makeContigConcreteTensor({b, s, d * h * e}); - TensorView* tv1 = reshape(tv0, {b, s, d * h * e}, {b, s, d * h, e}); - - fusion->addInput(tv0); - fusion->addOutput(tv1); - - auto mesh = DeviceMesh::createForNumDevices(d); - - tv0->setDeviceMesh(mesh); - tv0->split(-1, d, /*inner_split=*/false); - tv0->axis(-2)->parallelize(ParallelType::DIDx); - - tv1->setDeviceMesh(mesh); - tv1->split(-2, d, /*inner_split=*/false); - tv1->axis(-3)->parallelize(ParallelType::DIDx); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor in_tensor = at::randn(in_shape, tensor_options); - at::Tensor sharded_in = shardTensor(in_tensor, -1, mesh); - - at::Tensor out_tensor = executor_cache.runFusionWithInputs({sharded_in})[0]; - testValidate( - executor_cache.fusion(), - {out_tensor}, - {sharded_in}, - {sharded_in.view({b, s, h, e})}, - __LINE__, - __FILE__); - -} - -namespace { -// This is a simplified version of what we will eventually do in the -// pre-segmentation pass -void propagateShardings(Fusion* fusion, int64_t num_devices) { - for (Expr* expr : fusion->exprs()) { - if (expr->isA()) { - NVF_THROW("SliceOp is not currently supported"); - } - - if (expr->isA()) { - // TransformPropagator cannot be directly used. 
- // It raises an error for conflicting transformations from root domain to - // logical domain. Instead, we manually find the reshaped iterdomain and - // outer split DID. This might have to be extended further in the - // presegmentation pass. - // Note: For simplicity, this assumes that the sharding is on reshaped - // IDs. It is possible that the non-reshaped IDs are sharded, in which - // case we can use the TransformPropagator. - TensorView* reshaped_tv = expr->as()->out(); - auto transform_exprs = StmtSort::getExprsBetween( - {reshaped_tv->getMaybeRootDomain().begin(), - reshaped_tv->getMaybeRootDomain().end()}, - {reshaped_tv->getLogicalDomain().begin(), - reshaped_tv->getLogicalDomain().end()}); - NVF_CHECK(transform_exprs.size() == 1); - auto transform = transform_exprs[0]; - NVF_CHECK(transform->isA() || transform->isA()); - - // Get the reshaped ID (outer ID for split reshape). - // This is the ID that will be parallelized. - IterDomain* reshaped_id = transform->isA() - ? transform->as()->outer() - : transform->as()->out(); - - auto reshaped_it = std::find( - reshaped_tv->getLoopDomain().begin(), - reshaped_tv->getLoopDomain().end(), - reshaped_id); - int64_t reshaped_axis = - std::distance(reshaped_tv->getLoopDomain().begin(), reshaped_it); - - // Apply sharding to the reshaped tensor - reshaped_tv->split(reshaped_axis, num_devices, false); - reshaped_tv->axis(reshaped_axis)->parallelize(ParallelType::DIDx); - reorderDIDToFront(reshaped_tv); - continue; - } - - // For other ops, propagate sharding from input to outputs - auto input_tv = expr->input(0)->as(); - std::vector output_tvs; - for (auto output : expr->outputs()) { - output_tvs.push_back(output->as()); - } - - TransformPropagator propagator(input_tv); - - // Note: We will finally propagate from each input iteratively. 
- SetSelector selector( - std::unordered_set(output_tvs.begin(), output_tvs.end())); - MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator); - shardAllLike(input_tv, output_tvs); - } -} - -} // namespace - -TEST_F(MultiDeviceTest, TransformerFwd) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t m = 5, n = 7; - - TensorView* in = makeContigConcreteTensor({d * m, n}); - TensorView* out = reshape(in, {d * m, n}, {d * m * n}); - // TensorView* add_out = add(out, IrBuilder::create(1.0)); - - fusion->addInput(in); - fusion->addOutput(out); - - auto mesh = DeviceMesh::createForNumDevices(d); - for (auto* tv : {in, out}) { - tv->setDeviceMesh(mesh); - tv->split(0, d, /*inner_split=*/false); - tv->axis(0)->parallelize(ParallelType::DIDx); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options); - at::Tensor sharded_inp = shardTensor(inp, tv0); - - at::Tensor nvf_out = executor_cache.runFusionWithInputs({sharded_inp})[0]; - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_inp}, - {sharded_inp.view({b, s, h, e})}, - __LINE__, - __FILE__); -} - TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From fc80957d1fc4398a216f5bf2df7adbcc0ba66c9f Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 7 Apr 2025 13:36:29 -0700 Subject: [PATCH 31/70] set device mesh on fusion inputs --- csrc/preseg_passes/propagate_shardings.cpp | 88 +++++++++++++------- tests/cpp/test_multidevice_preseg_passes.cpp | 34 -------- 2 files changed, 57 insertions(+), 65 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 381fad463dd..b1fbdb79c63 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ 
b/csrc/preseg_passes/propagate_shardings.cpp @@ -84,6 +84,35 @@ int64_t numDeviceDims(TensorView* tv) { std::mem_fn(&IterDomain::isDeviceDim)); } + +// Sort the given tvs by the number of device dimensions in descending order. +// Break ties by the total number of dimensions. +// Only includes TensorViews that have a device mesh. +template +std::vector sortTvsByDeviceDims(const Range& tvs) { + // Filter out TVs without a device mesh + std::vector tvs_with_mesh; + std::copy_if( + tvs.begin(), + tvs.end(), + std::back_inserter(tvs_with_mesh), + std::mem_fn(&TensorView::hasDeviceMesh)); + + // Then sort the filtered TVs + std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), + [](auto a, auto b) { + int64_t a_device_dims = numDeviceDims(a); + int64_t b_device_dims = numDeviceDims(b); + if (a_device_dims != b_device_dims) { + return a_device_dims > b_device_dims; + } + // Break ties by the total number of dimensions + return a->nDims() > b->nDims(); + }); + + return tvs_with_mesh; +} + // Order the inputs of the expression based on their priority. // For linear op, we use weights and bias before input. // For matmul op, we use weights before input. @@ -102,13 +131,7 @@ std::vector getOrderedReferenceInputs(Expr* expr) { } // Sort inputs by number of device dimensions in descending order - std::vector sorted_inputs(inputs.begin(), inputs.end()); - std::sort( - sorted_inputs.begin(), - sorted_inputs.end(), - [&](TensorView* a, TensorView* b) { - return numDeviceDims(a) > numDeviceDims(b); - }); + std::vector sorted_inputs = sortTvsByDeviceDims(inputs); return sorted_inputs; } @@ -397,40 +420,43 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // if any of them has one. This is needed in addition to the forward // propagation for ops that don't take any TensorView operands, e.g., // `uniform` used in dropout. See MultiDeviceTest.BackpropMeshes for an - // example. + // example. For non-fusion inputs, we also propagate shardings from outputs to inputs. 
+ // See MultiDevicePresegPassesTest.ResidualAdd for an example. for (auto i_expr = exprs.rbegin(); i_expr != exprs.rend(); i_expr++) { Expr* expr = *i_expr; - const auto& inputs = ir_utils::filterByType(expr->inputs()); - std::vector unsharded_inputs; - // Find all inputs that are not fusion inputs and have no device mesh or - // no device dimensions. Fusion inputs should already have device mesh set. - // We should not modify the shardings for the fusion inputs. - std::copy_if( - inputs.begin(), - inputs.end(), - std::back_inserter(unsharded_inputs), - [](TensorView* tv) { - return !tv->isFusionInput() && (!tv->hasDeviceMesh() || - numDeviceDims(tv) == 0); - }); + const auto& outputs = ir_utils::filterByType(expr->outputs()); + std::vector sorted_outputs = sortTvsByDeviceDims(outputs); + // All outputs of an expression (Welford, SDPA) should be uniformly sharded. + // We pick the most parallel output as the reference. + // This is to avoid picking seed/offset tvs in SDPA. - if (unsharded_inputs.empty()) { + if (sorted_outputs.empty()) { continue; } - const auto& outputs = ir_utils::filterByType(expr->outputs()); - auto i_output = std::find_if( - outputs.begin(), - outputs.end(), - std::mem_fn(&TensorView::hasDeviceMesh)); - if (i_output == outputs.end()) { + TensorView* ref_output = sorted_outputs.front(); + + // For fusion inputs, only check if they have a device mesh. We do not modify their sharding. + // For non-fusion inputs, check both device mesh and device dims. + const auto& inputs = ir_utils::filterByType(expr->inputs()); + std::vector unsharded_inputs; + for (auto* tv : inputs) { + if (tv->isFusionInput()) { + if (!tv->hasDeviceMesh()) { + tv->setDeviceMesh(ref_output->getDeviceMesh()); + } + continue; + } + if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { + unsharded_inputs.push_back(tv); + } + } + + if (unsharded_inputs.empty()) { continue; } - // All outputs of an expression are uniformly sharded so we pick the first - // one. 
Multi-output expressions are Welford and SDPA. - TensorView* ref_output = *i_output; std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, {}); int64_t did_pos = new2old.size(); diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 4d2bf4df000..e634717467c 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -265,39 +265,5 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { __LINE__, __FILE__); } - -TEST_F(MultiDevicePresegPassesTest, ReplayAllocation) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - const int d = communicator_->size(); - auto mesh = DeviceMesh::createForNumDevices(d); - - TensorView* tv0 = makeConcreteTensor({d*b, s}); - TensorView* tv1 = makeConcreteTensor({d*b, s}); - TensorView* tv2 = add(tv0, tv1); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv2); - - for (auto* tv: {tv0, tv1, tv2}) { - tv->setAllocationDomain({tv->axis(1), tv->axis(0)}, true); - tv->setDeviceMesh(mesh); - tv->split(0, d, /*inner_split=*/false); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - debug() << tv2->domain()->toString(0, false) << std::endl; - auto transforms = DependencyCheck::getAllExprsBetween({tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, {tv2->getAllocationDomain().begin(), tv2->getAllocationDomain().end()}); - debug() << "transforms: " << transforms.size() << std::endl; - tv2->setAllocationDomain(tv2->getLoopDomain(), true); - auto transforms_updated = DependencyCheck::getAllExprsBetween({tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, {tv2->getAllocationDomain().begin(), tv2->getAllocationDomain().end()}); - debug() << "transforms_updated: " << transforms_updated.size() << std::endl; - for (auto* expr: transforms_updated) { - debug() << 
expr->toString() << std::endl; - } -} } // namespace nvfuser From 443aef4c951b4e32190ff99175abc32285593d4b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 7 Apr 2025 14:21:07 -0700 Subject: [PATCH 32/70] early return if ref_inputs do not have mesh --- csrc/multidevice/utils.cpp | 16 +++--------- csrc/multidevice/utils.h | 7 ++---- csrc/preseg_passes/propagate_shardings.cpp | 29 +++++++++++++++------- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 845dda31011..593dd8793cc 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -571,22 +571,13 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike( - TensorView* ref, - std::vector tvs, - bool parallelize_inputs) { +void shardAllLike(TensorView* ref, std::vector tvs) { for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } - if (!tvs.empty()) { scheduler_utils::parallelizeAllLike( - ref, - /*pos=*/-1, - /*selected_tvs=*/tvs, - /*selected_parallel_types=*/{ParallelType::DIDx, ParallelType::Serial}, - /*propagate_padding=*/false, - /*parallelize_inputs=*/parallelize_inputs); + ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); } // parallelAllLke, tries to DID-parallelize @@ -721,7 +712,7 @@ std::set involvedDevices(Expr* expr) { return ret; } -int64_t reorderDIDToFront(TensorView* tv) { +void reorderDIDToFront(TensorView* tv) { // old position to new position std::unordered_map order_map; int64_t current_pos = 0; @@ -734,7 +725,6 @@ int64_t reorderDIDToFront(TensorView* tv) { } tv->reorder(order_map); - return current_pos; } std::unordered_set getTvsWithDifferentSharding( diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index b2ed0dbbc85..db560097230 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -62,10 +62,7 @@ bool haveDifferentShardings( bool isInnerResharding(Expr* expr); // Shards all tensors in tvs like reference -void 
shardAllLike( - TensorView* ref, - std::vector tvs, - bool parallelize_inputs = false); +void shardAllLike(TensorView* ref, std::vector tvs); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a @@ -121,7 +118,7 @@ at::Tensor shardTensor( DeviceIdxType device_id); // Reorders a TensorView so that the DID parallelized axis are in front. -int64_t reorderDIDToFront(TensorView*); +void reorderDIDToFront(TensorView*); // Given a TensorView and the shape of a sharded tensor of which certain // dimensions are partially allocated, returns the global shape that'll be used diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index b1fbdb79c63..9b4a6c2ec0d 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -88,15 +88,22 @@ int64_t numDeviceDims(TensorView* tv) { // Sort the given tvs by the number of device dimensions in descending order. // Break ties by the total number of dimensions. // Only includes TensorViews that have a device mesh. 
+ template -std::vector sortTvsByDeviceDims(const Range& tvs) { - // Filter out TVs without a device mesh +std::vector filterTvsWithMesh(const Range& tvs) { std::vector tvs_with_mesh; std::copy_if( tvs.begin(), tvs.end(), std::back_inserter(tvs_with_mesh), - std::mem_fn(&TensorView::hasDeviceMesh)); + [](TensorView* tv) { return tv != nullptr && tv->hasDeviceMesh(); }); + return tvs_with_mesh; +} + +template +std::vector sortTvsByDeviceDims(const Range& tvs) { + // Filter out TVs without a device mesh + std::vector tvs_with_mesh = filterTvsWithMesh(tvs); // Then sort the filtered TVs std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), @@ -122,12 +129,12 @@ std::vector getOrderedReferenceInputs(Expr* expr) { const auto& inputs = ir_utils::filterByType(expr->inputs()); if (LinearOp* linear_op = dynamic_cast(expr)) { // Use weights and bias before input. - return {linear_op->inB(), linear_op->bias(), linear_op->inA()}; + return filterTvsWithMesh(std::vector({linear_op->inB(), linear_op->bias(), linear_op->inA()})); } if (MatmulOp* matmul_op = dynamic_cast(expr)) { // Use weights before input. - return {matmul_op->inB(), matmul_op->inA()}; + return filterTvsWithMesh(std::vector({matmul_op->inB(), matmul_op->inA()})); } // Sort inputs by number of device dimensions in descending order @@ -301,12 +308,13 @@ int64_t shardViewOp(ViewOp* view_op, std::unordered_map& new2o void reorderLoopAsAllocation(std::vector tvs) { // Use maybeAllocationDomain to transform - // Transform using exprs between logical and loop and get the map + // Transform using exprs between logical and loop and get the map. 
for (auto tv : tvs) { auto alloc_dom = tv->getMaybeAllocationDomain(); std::vector transform_exprs = DependencyCheck::getAllExprsBetween( {alloc_dom.begin(), alloc_dom.end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + NVF_ERROR(std::all_of(transform_exprs.begin(), transform_exprs.end(), [](Expr* expr) { return expr->isA(); }), "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); auto reorder_map = scheduler_utils::createReorderMapUnderTransforms( /*ids_to_reorder=*/tv->getLoopDomain(), /*ids_to_transform=*/alloc_dom, @@ -365,14 +373,16 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { const auto& reference_inputs = getOrderedReferenceInputs(expr); + if (reference_inputs.empty()) { + continue; + } + std::unordered_set output_parallel_types; // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { // Skip if the input has no device mesh or is nullptr. - if (ref_input == nullptr || !ref_input->hasDeviceMesh()) { - continue; - } + NVF_ERROR(ref_input != nullptr && ref_input->hasDeviceMesh(), "Reference input ", ref_input, " has no device mesh."); // Reorder the DID axis to the front only if it does not have a parallel type // already seen on the output. @@ -436,6 +446,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } TensorView* ref_output = sorted_outputs.front(); + NVF_ERROR(ref_output != nullptr && ref_output->hasDeviceMesh(), "Reference output ", ref_output, " has no device mesh."); // For fusion inputs, only check if they have a device mesh. We do not modify their sharding. // For non-fusion inputs, check both device mesh and device dims. 
From 9b86c5b02a7037b7c93b9e2c7440623bdffab4b5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 7 Apr 2025 15:55:49 -0700 Subject: [PATCH 33/70] lintrunner --- csrc/preseg_passes/propagate_shardings.cpp | 126 +++++++++++++-------- 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 9b4a6c2ec0d..fe4caf50ebb 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -84,7 +84,6 @@ int64_t numDeviceDims(TensorView* tv) { std::mem_fn(&IterDomain::isDeviceDim)); } - // Sort the given tvs by the number of device dimensions in descending order. // Break ties by the total number of dimensions. // Only includes TensorViews that have a device mesh. @@ -101,22 +100,21 @@ std::vector filterTvsWithMesh(const Range& tvs) { } template -std::vector sortTvsByDeviceDims(const Range& tvs) { +std::vector sortTvsByDeviceDims(const Range& tvs) { // Filter out TVs without a device mesh std::vector tvs_with_mesh = filterTvsWithMesh(tvs); - + // Then sort the filtered TVs - std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), - [](auto a, auto b) { - int64_t a_device_dims = numDeviceDims(a); - int64_t b_device_dims = numDeviceDims(b); - if (a_device_dims != b_device_dims) { - return a_device_dims > b_device_dims; - } - // Break ties by the total number of dimensions - return a->nDims() > b->nDims(); - }); - + std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { + int64_t a_device_dims = numDeviceDims(a); + int64_t b_device_dims = numDeviceDims(b); + if (a_device_dims != b_device_dims) { + return a_device_dims >= b_device_dims; + } + // Break ties by the total number of dimensions + return a->nDims() >= b->nDims(); + }); + return tvs_with_mesh; } @@ -129,12 +127,14 @@ std::vector getOrderedReferenceInputs(Expr* expr) { const auto& inputs = ir_utils::filterByType(expr->inputs()); if (LinearOp* 
linear_op = dynamic_cast(expr)) { // Use weights and bias before input. - return filterTvsWithMesh(std::vector({linear_op->inB(), linear_op->bias(), linear_op->inA()})); + return filterTvsWithMesh(std::vector( + {linear_op->inB(), linear_op->bias(), linear_op->inA()})); } if (MatmulOp* matmul_op = dynamic_cast(expr)) { // Use weights before input. - return filterTvsWithMesh(std::vector({matmul_op->inB(), matmul_op->inA()})); + return filterTvsWithMesh( + std::vector({matmul_op->inB(), matmul_op->inA()})); } // Sort inputs by number of device dimensions in descending order @@ -190,7 +190,9 @@ void splitLike( // Returns the number of DID axis on reshaped ids that were propagated to the // consumer. -int64_t shardViewOp(ViewOp* view_op, std::unordered_map& new2old) { +int64_t shardViewOp( + ViewOp* view_op, + std::unordered_map& new2old) { // This implementation asserts that only one sharding is applied on the // reshaped ids. Inner split is not supported. The cases are: // 1. Split reshape: [h] -> [a, h/a]. 
Sharding on h is applied to a in @@ -314,7 +316,12 @@ void reorderLoopAsAllocation(std::vector tvs) { std::vector transform_exprs = DependencyCheck::getAllExprsBetween( {alloc_dom.begin(), alloc_dom.end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - NVF_ERROR(std::all_of(transform_exprs.begin(), transform_exprs.end(), [](Expr* expr) { return expr->isA(); }), "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); + NVF_ERROR( + std::all_of( + transform_exprs.begin(), + transform_exprs.end(), + [](Expr* expr) { return expr->isA(); }), + "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); auto reorder_map = scheduler_utils::createReorderMapUnderTransforms( /*ids_to_reorder=*/tv->getLoopDomain(), /*ids_to_transform=*/alloc_dom, @@ -325,14 +332,18 @@ void reorderLoopAsAllocation(std::vector tvs) { // Reorder the DID axis to the front only if it does not have a parallel type // already seen on the output (existing_parallel_types). -// Returns a map from the new position to the old position to undo the reordering later. -std::unordered_map selectiveReorderDIDToFront(TensorView* tv, std::unordered_set existing_parallel_types) { +// Returns a map from the new position to the old position to undo the +// reordering later. 
+std::unordered_map selectiveReorderDIDToFront( + TensorView* tv, + std::unordered_set existing_parallel_types) { std::unordered_map old2new; std::unordered_map new2old; int64_t current_pos = 0; for (auto pos : c10::irange(tv->nDims())) { - if (tv->axis(pos)->isDeviceDim() && !existing_parallel_types.count(tv->axis(pos)->getParallelType())) { + if (tv->axis(pos)->isDeviceDim() && + !existing_parallel_types.count(tv->axis(pos)->getParallelType())) { old2new[pos] = current_pos; new2old[current_pos] = pos; current_pos++; @@ -344,8 +355,10 @@ std::unordered_map selectiveReorderDIDToFront(TensorView* tv, } // Updates the set of parallel types seen on the output. -void updateOutputParallelTypes(TensorView* tv, std::unordered_set& output_parallel_types) { - for (auto id: tv->getLoopDomain()) { +void updateOutputParallelTypes( + TensorView* tv, + std::unordered_set& output_parallel_types) { + for (auto id : tv->getLoopDomain()) { if (id->isDeviceDim()) { output_parallel_types.insert(id->getParallelType()); } @@ -354,12 +367,19 @@ void updateOutputParallelTypes(TensorView* tv, std::unordered_set& } // namespace - -// This presegmentation pass propagates shardings from fusion inputs to downstream tensorviews. -// 1. Forward propagating DID loop splits and parallelization from inputs to outputs that don't have a mesh using TransformPropagator -// 2. Reshape is handled manually since the DID loop split transforms conflict with the reshape root-to-logical transforms if using TransformPropagator -// 3. Back-propagating device meshes to ensure all TensorViews have consistent meshes. This also splits and parallelizes unsharded inputs based on outputs. See `MultiDevicePresegPassesTest.ResidualAdd` for an example. -// 4. Reorders the loop domain as the allocation order. Ideally, loop domain should follow logical domain and allocation domain should follow any stride order specified/inferred. However, we currently require loop domain to be the same as allocation domain. 
+// This presegmentation pass propagates shardings from fusion inputs to +// downstream tensorviews. +// 1. Forward propagating DID loop splits and parallelization from inputs to +// outputs that don't have a mesh using TransformPropagator +// 2. Reshape is handled manually since the DID loop split transforms conflict +// with the reshape root-to-logical transforms if using TransformPropagator +// 3. Back-propagating device meshes to ensure all TensorViews have consistent +// meshes. This also splits and parallelizes unsharded inputs based on outputs. +// See `MultiDevicePresegPassesTest.ResidualAdd` for an example. +// 4. Reorders the loop domain as the allocation order. Ideally, loop domain +// should follow logical domain and allocation domain should follow any stride +// order specified/inferred. However, we currently require loop domain to be the +// same as allocation domain. void PropagateShardingsPass::runPass(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); @@ -382,18 +402,24 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { // Skip if the input has no device mesh or is nullptr. - NVF_ERROR(ref_input != nullptr && ref_input->hasDeviceMesh(), "Reference input ", ref_input, " has no device mesh."); + NVF_ERROR( + ref_input != nullptr && ref_input->hasDeviceMesh(), + "Reference input ", + ref_input, + " has no device mesh."); - // Reorder the DID axis to the front only if it does not have a parallel type - // already seen on the output. - std::unordered_map new2old = selectiveReorderDIDToFront(ref_input, output_parallel_types); + // Reorder the DID axis to the front only if it does not have a parallel + // type already seen on the output. + std::unordered_map new2old = + selectiveReorderDIDToFront(ref_input, output_parallel_types); // This restricts the transform propagation to the DID axis. 
int64_t num_device_dims = new2old.size(); if (ViewOp* view_op = dynamic_cast(expr)) { // Propagation of reshape will return how many DID axis were propagated. - // They are reordered behind non-propagated DID axis and the new2old map is updated. + // They are reordered behind non-propagated DID axis and the new2old map + // is updated. int64_t num_reshape_shardings = shardViewOp(view_op, new2old); num_device_dims = num_device_dims - num_reshape_shardings; } @@ -412,7 +438,8 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { updateOutputParallelTypes(ref_input, output_parallel_types); - // Undo the reordering of the DID axis so it is in the correct order again. + // Undo the reordering of the DID axis so it is in the correct order + // again. ref_input->reorder(new2old); } @@ -420,9 +447,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // have reordered the iterdomains in loop domain. For example: Consider // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After // transformation, the loop domain of linear output is [DIDx(d), n/d, b, - // m, r{k}]. Since, we set allocation to be the same as loop, we reorder it as allocation domain in the interim. - // Ideally, this should follow logical domain and DIDx axis at the front. - // The allocation domain should follow any stride order specified/inferred. + // m, r{k}]. Since, we set allocation to be the same as loop, we reorder it + // as allocation domain in the interim. Ideally, this should follow logical + // domain and DIDx axis at the front. The allocation domain should follow + // any stride order specified/inferred. reorderLoopAsAllocation(outputs_without_mesh); } @@ -430,15 +458,15 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // if any of them has one. This is needed in addition to the forward // propagation for ops that don't take any TensorView operands, e.g., // `uniform` used in dropout. See MultiDeviceTest.BackpropMeshes for an - // example. 
For non-fusion inputs, we also propagate shardings from outputs to inputs. - // See MultiDevicePresegPassesTest.ResidualAdd for an example. + // example. For non-fusion inputs, we also propagate shardings from outputs to + // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. for (auto i_expr = exprs.rbegin(); i_expr != exprs.rend(); i_expr++) { Expr* expr = *i_expr; const auto& outputs = ir_utils::filterByType(expr->outputs()); std::vector sorted_outputs = sortTvsByDeviceDims(outputs); // All outputs of an expression (Welford, SDPA) should be uniformly sharded. - // We pick the most parallel output as the reference. + // We pick the most parallel output as the reference. // This is to avoid picking seed/offset tvs in SDPA. if (sorted_outputs.empty()) { @@ -446,10 +474,15 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } TensorView* ref_output = sorted_outputs.front(); - NVF_ERROR(ref_output != nullptr && ref_output->hasDeviceMesh(), "Reference output ", ref_output, " has no device mesh."); - - // For fusion inputs, only check if they have a device mesh. We do not modify their sharding. - // For non-fusion inputs, check both device mesh and device dims. + NVF_ERROR( + ref_output != nullptr && ref_output->hasDeviceMesh(), + "Reference output ", + ref_output, + " has no device mesh."); + + // For fusion inputs, only check if they have a device mesh. We do not + // modify their sharding. For non-fusion inputs, check both device mesh and + // device dims. const auto& inputs = ir_utils::filterByType(expr->inputs()); std::vector unsharded_inputs; for (auto* tv : inputs) { @@ -468,7 +501,8 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { continue; } - std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, {}); + std::unordered_map new2old = + selectiveReorderDIDToFront(ref_output, {}); int64_t did_pos = new2old.size(); // Note: We do not have to manually shard for reshape here. 
From 83f7133b51307b4d00fc1e08930988a5d07b9908 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 10 Apr 2025 17:05:15 -0700 Subject: [PATCH 34/70] remove view op specific changes --- csrc/preseg_passes/propagate_shardings.cpp | 260 ++++--------------- tests/cpp/test_multidevice_preseg_passes.cpp | 51 ++-- 2 files changed, 64 insertions(+), 247 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index fe4caf50ebb..9f66943a479 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -49,34 +49,6 @@ void validateMeshes(Fusion* fusion) { " not."); } -std::pair, std::unordered_set> -getReshapedIds( - ViewOp* view_op, - const std::unordered_map& c2p) { - std::unordered_set p_reshaped_ids; // Reshaped logical IDs - std::unordered_set c_reshaped_ids; // Reshaped root IDs - - TensorView* consumer = view_op->out(); - std::vector c_root_domain = consumer->getMaybeRootDomain(); - - for (auto id : consumer->getLogicalDomain()) { - if (id->isRFactorProduct() && id->definition() && - !id->definition()->isA()) { - auto root_ids = getInputsInTargetDomain(id, c_root_domain); - for (auto root_id : root_ids) { - c_reshaped_ids.insert(root_id); - } - } - } - - for (auto id : c_reshaped_ids) { - if (auto p_id = c2p.find(id); p_id != c2p.end()) { - p_reshaped_ids.insert(p_id->second); - } - } - return std::make_pair(p_reshaped_ids, c_reshaped_ids); -} - int64_t numDeviceDims(TensorView* tv) { return std::count_if( tv->getLoopDomain().begin(), @@ -154,6 +126,7 @@ std::vector getOutputsWithoutMesh(Expr* expr) { return outputs_without_mesh; } +// Custom selector to specify direction of transform propagation. 
class PropagateShardingsSelector : public SetSelector { private: bool allow_c2p_; @@ -177,137 +150,6 @@ class PropagateShardingsSelector : public SetSelector { } }; -void splitLike( - TensorView* tv, - int64_t axis, - Split* ref_split, - bool allow_inner_split = false) { - auto split_factor = ref_split->factor(); - auto inner_split = ref_split->innerSplit(); - NVF_ERROR(!inner_split || allow_inner_split, "Inner split is not supported."); - tv->split(axis, split_factor, /*inner_split=*/inner_split); -} - -// Returns the number of DID axis on reshaped ids that were propagated to the -// consumer. -int64_t shardViewOp( - ViewOp* view_op, - std::unordered_map& new2old) { - // This implementation asserts that only one sharding is applied on the - // reshaped ids. Inner split is not supported. The cases are: - // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in - // consumer. - // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in - // consumer. - // 3. Multiple splits or merge reshapes: [x, y, z] -> [xyz]. Sharding on x and - // xyz. Similarly for the corresponding split reshape. - // 4. Independent splits or merge reshapes: [w, x, y, z] -> [wx, yz]. Sharding - // is on w and y. In the consumer, it is applied to wx and yz. An improvement - // is to support mult-levels of sharding (not a real case in practice) if they - // are all outer splits. For example: For the reshape [h] -> [a, h/a] where - // the h is sharded twice: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)] - - // A more general approach maybe to "undo" the reshape (reverse transforms - // from root to logical domain), followed by simplification of the consumer - // loop domain to move DID upwards. 
- - TensorView* producer = view_op->in(); - TensorView* consumer = view_op->out(); - - const std::unordered_map& c2p = - PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - const std::unordered_map& p2c = - PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer(); - auto [p_logical_reshaped_ids, c_root_reshaped_ids] = - getReshapedIds(view_op, c2p); - - auto p_loop_domain = producer->getLoopDomain(); - auto c_loop_domain = consumer->getLoopDomain(); - auto c_logical_domain = consumer->getLogicalDomain(); - - // Track number of DID axis on reshaped ids that were propagated to the - // consumer. These will not be included in TransformPropagator. - int64_t num_reshape_shardings = 0; - int64_t num_device_dims = new2old.size(); - - for (auto idx : c10::irange(num_device_dims)) { - IterDomain* p_did = p_loop_domain.at(idx); - NVF_ERROR(p_did->isDeviceDim()); - - auto p_transforms = DependencyCheck::getAllExprsBetween( - {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()}, - {p_loop_domain.at(idx)}); - - if (p_transforms.empty()) { - // This did axis is not on reshaped ids. We will use the - // TransformPropagator. - continue; - } - - if (p_transforms.size() > 1) { - // This reshape has been transformed. - // We will attempt to use TransformPropagator for this did axis. - continue; - } - - NVF_ERROR( - p_transforms.front()->isA(), - "Expected a split transform producing the did axis."); - NVF_ERROR( - TensorDomain::sameAs(c_logical_domain, c_loop_domain), - "Sharding a previously transformed reshape is not supported."); - - num_reshape_shardings++; - - // Find the producer logical id that is sharded. - // We expect the outermost reshaped id to be sharded and follow the - // outermost path traversing the transforms - auto* p_did_split = p_did->definition()->as(); - IterDomain* p_logical_did = p_did_split->in(); - - // Find the mapping of the corresponding producer logical id in consumer - // root. 
- IterDomain* c_root_did = p2c.at(p_logical_did); - - // Get the reshape transforms corresponding to this root id. - // We use the c_root_did to only find the reshape IDs related to this did. - auto reshape_transforms = DependencyCheck::getAllExprsBetween( - {c_root_did}, - {consumer->getLogicalDomain().begin(), - consumer->getLogicalDomain().end()}); - - // Obtain the logical axis sharded in the consumer. - IterDomain* c_logical_did = c_root_did; - for (auto transform : reshape_transforms) { - if (transform->isA()) { - c_logical_did = transform->as()->outer(); - } - if (transform->isA()) { - NVF_ERROR( - c_logical_did == transform->as()->outer(), - "Expected the sharding to be on the outer reshaped id."); - c_logical_did = transform->as()->out(); - } - } - - int64_t sharded_axis = std::distance( - c_loop_domain.begin(), - std::find(c_loop_domain.begin(), c_loop_domain.end(), c_logical_did)); - - // TODO: Check for divisibility of the consumer axis by the split factor. - splitLike(consumer, sharded_axis, p_did_split); - consumer->axis(sharded_axis)->parallelize(p_did->getParallelType()); - - // Move this did_pos behind the non-propagated DID axis to avoid using - // TransformPropagator on it. - producer->reorder({{idx, num_device_dims - 1}}); - new2old[idx] = num_device_dims - 1; - num_device_dims--; - } - - return num_reshape_shardings; -} - void reorderLoopAsAllocation(std::vector tvs) { // Use maybeAllocationDomain to transform // Transform using exprs between logical and loop and get the map. @@ -354,15 +196,27 @@ std::unordered_map selectiveReorderDIDToFront( return new2old; } -// Updates the set of parallel types seen on the output. -void updateOutputParallelTypes( - TensorView* tv, - std::unordered_set& output_parallel_types) { - for (auto id : tv->getLoopDomain()) { - if (id->isDeviceDim()) { - output_parallel_types.insert(id->getParallelType()); +// Returns the set of parallel types seen on the loop domain of the given tvs. 
+std::unordered_set getTvParallelTypes(std::vector tvs) { + std::unordered_set parallel_types; + for (auto tv : tvs) { + for (auto id : tv->getLoopDomain()) { + if (id->isDeviceDim()) { + parallel_types.insert(id->getParallelType()); + } } } + return parallel_types; +} + +void propagateDIDTransform(TensorView* ref, std::vector tvs, int64_t did_pos, bool allow_c2p, bool allow_p2c) { + TransformPropagator propagator(ref, did_pos); + PropagateShardingsSelector selector( + {tvs.begin(), tvs.end()}, + allow_c2p, + allow_p2c); + MaxLogicalDomainInfoSpanningTree(ref, &selector) + .traverse(&propagator); } } // namespace @@ -396,9 +250,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { if (reference_inputs.empty()) { continue; } - - std::unordered_set output_parallel_types; - // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { // Skip if the input has no device mesh or is nullptr. @@ -410,35 +261,20 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Reorder the DID axis to the front only if it does not have a parallel // type already seen on the output. + std::unordered_set existing_parallel_types = getTvParallelTypes(outputs_without_mesh); std::unordered_map new2old = - selectiveReorderDIDToFront(ref_input, output_parallel_types); + selectiveReorderDIDToFront(ref_input, existing_parallel_types); // This restricts the transform propagation to the DID axis. int64_t num_device_dims = new2old.size(); - if (ViewOp* view_op = dynamic_cast(expr)) { - // Propagation of reshape will return how many DID axis were propagated. - // They are reordered behind non-propagated DID axis and the new2old map - // is updated. - int64_t num_reshape_shardings = shardViewOp(view_op, new2old); - num_device_dims = num_device_dims - num_reshape_shardings; - } - // Propagate the DID loop split to the outputs without mesh. 
- TransformPropagator propagator(ref_input, num_device_dims); - PropagateShardingsSelector selector( - {outputs_without_mesh.begin(), outputs_without_mesh.end()}, - /*allow_c2p=*/false, - /*allow_p2c=*/true); - MaxLogicalDomainInfoSpanningTree(ref_input, &selector) - .traverse(&propagator); + propagateDIDTransform(/*ref=*/ref_input, /*tvs=*/outputs_without_mesh, /*did_pos=*/num_device_dims, /*allow_c2p=*/false, /*allow_p2c=*/true); // Apply parallelization on the outputs without mesh. shardAllLike(ref_input, outputs_without_mesh); - updateOutputParallelTypes(ref_input, output_parallel_types); - - // Undo the reordering of the DID axis so it is in the correct order + // Undo the reordering of the DID axis in ref_input so it is in the correct order // again. ref_input->reorder(new2old); } @@ -470,6 +306,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // This is to avoid picking seed/offset tvs in SDPA. if (sorted_outputs.empty()) { + // No output with a device mesh. continue; } @@ -481,10 +318,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { " has no device mesh."); // For fusion inputs, only check if they have a device mesh. We do not - // modify their sharding. For non-fusion inputs, check both device mesh and - // device dims. + // modify their sharding. For non-fusion inputs, we try to propagate shardings + // from the reference output for parallel types that are not already present. 
const auto& inputs = ir_utils::filterByType(expr->inputs()); - std::vector unsharded_inputs; + std::vector inputs_to_shard; for (auto* tv : inputs) { if (tv->isFusionInput()) { if (!tv->hasDeviceMesh()) { @@ -492,33 +329,32 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } continue; } - if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { - unsharded_inputs.push_back(tv); - } + inputs_to_shard.push_back(tv); } - if (unsharded_inputs.empty()) { + if (inputs_to_shard.empty()) { continue; } - std::unordered_map new2old = - selectiveReorderDIDToFront(ref_output, {}); - int64_t did_pos = new2old.size(); - - // Note: We do not have to manually shard for reshape here. - // TransformPropagator can handle reshapes when going from consumer to - // producer. - TransformPropagator propagator(ref_output, did_pos); - PropagateShardingsSelector selector( - {unsharded_inputs.begin(), unsharded_inputs.end()}, - /*allow_c2p=*/true, + // Each input can have different shardings, so attempt to propagate independently. + for (auto tv : inputs_to_shard) { + std::unordered_set existing_parallel_types = getTvParallelTypes({tv}); + std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, existing_parallel_types); + int64_t did_pos = new2old.size(); + + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. 
+ propagateDIDTransform( + /*ref=*/ref_output, + /*tvs=*/{tv}, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, /*allow_p2c=*/false); - MaxLogicalDomainInfoSpanningTree(ref_output, &selector) - .traverse(&propagator); - shardAllLike(ref_output, unsharded_inputs); - - ref_output->reorder(new2old); - reorderLoopAsAllocation(unsharded_inputs); + ref_output->reorder(new2old); + } + shardAllLike(ref_output, inputs_to_shard); + reorderLoopAsAllocation(inputs_to_shard); } validateMeshes(fusion); diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index e634717467c..40286dd1aa5 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -76,7 +76,7 @@ TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { } } -TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { +TEST_F(MultiDevicePresegPassesTest, DISABLED_MultipleTransformReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -115,7 +115,7 @@ TEST_F(MultiDevicePresegPassesTest, MultipleTransformReshape) { __FILE__); } -TEST_F(MultiDevicePresegPassesTest, SliceReshapePermute) { +TEST_F(MultiDevicePresegPassesTest, DISABLED_SliceReshapePermute) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -168,10 +168,9 @@ TEST_F(MultiDevicePresegPassesTest, SliceReshapePermute) { } namespace { -at::Tensor reference_mha(at::Tensor inp, at::Tensor weight) { - at::Tensor linear0_out = at::linear(inp, weight); +at::Tensor reference_mha(at::Tensor inp) { auto qkv = - linear0_out.view({b, s, a, 3 * h / a}).transpose(1, 2).split(h / a, -1); + inp.view({b, s, a, 3 * h / a}).transpose(1, 2).split(h / a, -1); double scale = 1.0 / std::sqrt(h / a); auto sdpa_out = at::_scaled_dot_product_flash_attention( qkv[0], @@ -182,7 +181,7 @@ at::Tensor reference_mha(at::Tensor inp, at::Tensor weight) { /*return_debug_mask=*/false, scale); auto attn = std::get<0>(sdpa_out); - return attn.transpose(1, 
2).reshape({b, s, h}); + return attn.transpose(1, 2); } } // namespace @@ -195,15 +194,7 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { auto mesh = DeviceMesh::createForNumDevices(d); - TensorView* inp = makeConcreteTensor({b, s, h}, DataType::Half); - TensorView* mha_w0 = makeConcreteTensor({3 * d * h, h}, DataType::Half); - - // MHA Linear0 - TensorView* mha_linear0_out = linear(inp, mha_w0); - - // reshape -> slice -> permute - TensorView* qkv = - reshape(mha_linear0_out, {b, s, 3 * d * h}, {b, s, d * a, 3 * h / a}); + TensorView* qkv = makeConcreteTensor({b, s, d * a, 3 * h / a}, DataType::Half); TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d * a, h / a}); TensorView* k = slice(qkv, {0, 0, 0, h / a}, {b, s, d * a, 2 * h / a}); TensorView* v = slice(qkv, {0, 0, 0, 2 * h / a}, {b, s, d * a, 3 * h / a}); @@ -222,20 +213,13 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { TensorView* attn = sdpa_out.output; TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - TensorView* attn_reshaped = - reshape(attn_permute, {b, s, d * a, h / a}, {b, s, d * h}); - fusion->addInput(inp); - fusion->addInput(mha_w0); - fusion->addOutput(attn_reshaped); + fusion->addInput(qkv); + fusion->addOutput(attn_permute); - // Shard input tensors - for (auto* tv : {inp, mha_w0}) { - tv->setDeviceMesh(mesh); - } - - mha_w0->split(0, d, /*inner_split=*/false); - mha_w0->axis(0)->parallelize(ParallelType::DIDx); + qkv->setDeviceMesh(mesh); + qkv->outer_split(2, d); + qkv->axis(2)->parallelize(ParallelType::DIDx); preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); @@ -245,22 +229,19 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { } FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp_tensor = at::randn({b, s, h}, tensor_options.dtype(at::kHalf)); - - at::Tensor mha_w0_tensor = - at::randn({3 * d * h, h}, tensor_options.dtype(at::kHalf)); - at::Tensor sharded_mha_w0 = shardTensor(mha_w0_tensor, 0, mesh); + at::Tensor 
unsharded_inp_tensor = at::randn({b, s, d * a, 3 * h / a}, tensor_options.dtype(at::kHalf)); + at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); - KernelArgumentHolder args = {inp_tensor, sharded_mha_w0}; + KernelArgumentHolder args = {inp_tensor}; auto outputs = executor_cache.runFusionWithInputs(args); at::Tensor nvf_out = outputs[0].as(); - at::Tensor ref_out = reference_mha(inp_tensor, sharded_mha_w0); + at::Tensor ref_out = reference_mha(inp_tensor); testValidate( executor_cache.fusion(), {nvf_out}, - {inp_tensor, sharded_mha_w0}, + {inp_tensor}, {ref_out}, __LINE__, __FILE__); From 7919ef2c6cdca635a62ede0342a6a71f338aad80 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 10 Apr 2025 17:12:04 -0700 Subject: [PATCH 35/70] undo changes in lower.cpp --- csrc/host_ir/lower.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 51bde80e74b..32febda37a0 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -435,9 +435,11 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) { // stream-parallelized on axis 0. auto* a = linear->inA()->as(); auto* b = linear->inB()->as(); + auto* bias = + (linear->has_bias() ? linear->bias()->as() : nullptr); auto* out = linear->out()->as(); - return !isSharded(b) && - !(linear->has_bias() && isSharded(linear->bias())) && !isSharded(out) && + return !isSharded(b) && !(linear->has_bias() && isSharded(bias)) && + !isSharded(out) && a->axis(0)->getParallelType() == ParallelType::Serial && getShardedLogicalAxis(a, ParallelType::DIDx) == 1 && out->axis(0)->getParallelType() == ParallelType::Stream; @@ -463,7 +465,7 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( auto* linear = expr->as(); tva = linear->inA()->as(); tvb = linear->inB()->as(); - tv_bias = linear->bias()->as(); + tv_bias = (linear->has_bias() ? 
linear->bias()->as() : nullptr); tv_out = linear->out()->as(); NVF_ERROR( !(linear->has_bias() && isSharded(tv_bias)), @@ -637,7 +639,6 @@ std::unique_ptr HostIrLower::lower( }; for (auto group : workspace.group_run_order) { - std::vector host_exprs; NVF_ERROR(!group->exprs().empty(), "invalid segmentation"); if (involvedDevices(group->exprs().at(0)).count(my_device_index) == 0) { continue; From 7a34ea46f7ad83aa0a0922e1d6a95be5d23ecd31 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 10 Apr 2025 17:21:40 -0700 Subject: [PATCH 36/70] undo changes in scheduler_utils merged in another PR --- csrc/scheduler/utils.cpp | 109 ++++++++++++++++++--------------------- csrc/scheduler/utils.h | 17 +++--- 2 files changed, 59 insertions(+), 67 deletions(-) diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 70b5cc60849..2e566f5bb17 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2126,33 +2126,34 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos) { namespace { -void applySplitTransform(const Split* split, std::vector& ids) { - auto find_it = - std::find(ids.begin(), ids.end(), split->in()); - if (find_it == ids.end()) { - // Transformations before rfactor, ignore those. - return; - } +void applySplitTransform(Split* split, std::vector& ids) { + auto find_it = std::find(ids.begin(), ids.end(), split->in()); + NVF_ERROR( + find_it != ids.end(), + "Split input ", + split->in()->toString(), + " not found in given ids: ", + ids); auto pos = std::distance(ids.begin(), find_it); ids[pos] = split->inner(); ids.insert(ids.begin() + pos, split->outer()); } -void applyMergeTransform(const Merge* merge, std::vector& ids) { - auto find_it_0 = - std::find(ids.begin(), ids.end(), merge->outer()); - auto find_it_1 = - std::find(ids.begin(), ids.end(), merge->inner()); - if (find_it_0 == ids.end() && - find_it_1 == ids.end()) { - // Transformations before rfactor, ignore those. 
- return; - } +void applyMergeTransform(Merge* merge, std::vector& ids) { + auto find_it_0 = std::find(ids.begin(), ids.end(), merge->outer()); + auto find_it_1 = std::find(ids.begin(), ids.end(), merge->inner()); NVF_ERROR( - find_it_0 != ids.end() && find_it_1 != ids.end(), - "Error in transformations of ", - ids, - "\nTransformations before rfactor should not mix with transformations after rfactor."); + find_it_0 != ids.end(), + "Merge outer ", + merge->outer()->toString(), + " not found in given ids: ", + ids); + NVF_ERROR( + find_it_1 != ids.end(), + "Merge inner ", + merge->inner()->toString(), + " not found in given ids: ", + ids); auto pos0 = std::distance(ids.begin(), find_it_0); auto pos1 = std::distance(ids.begin(), find_it_1); if (pos0 > pos1) { @@ -2168,62 +2169,54 @@ void applyMergeTransform(const Merge* merge, std::vector& ids) { ids[--pos1] = merge->out(); } -void applyResizeTransform(const Resize* resize, std::vector& ids) { - auto find_it = - std::find(ids.begin(), ids.end(), resize->in()); - if (find_it == ids.end()) { - // Transformations before rfactor, ignore those. - return; - } +void applyResizeTransform(Resize* resize, std::vector& ids) { + auto find_it = std::find(ids.begin(), ids.end(), resize->in()); + NVF_ERROR( + find_it != ids.end(), + "Resize input ", + resize->in()->toString(), + " not found in given ids: ", + ids); *find_it = resize->out(); } } // namespace -// Update the vector of ids_to_transform as progressing through the transformation -// expressions. We'll always insert the result of split in the location of the -// input, and insert the merge result in the position of the inner dimension. -// After transformations, ids_to_reorder should be a permutation of ids_to_transform. -// Returns a reorder map from ids_to_reorder to ids_to_transform. 
-std::unordered_map createReorderMapUnderTransforms( - const std::vector& ids_to_reorder, +void applyTransforms( std::vector& ids_to_transform, const std::vector& transform_exprs) { - for (const auto* expr : transform_exprs) { - if (const Split* split = dynamic_cast(expr)) { + for (auto* expr : transform_exprs) { + if (Split* split = dynamic_cast(expr)) { applySplitTransform(split, ids_to_transform); - } else if (const Merge* merge = dynamic_cast(expr)) { + } else if (Merge* merge = dynamic_cast(expr)) { applyMergeTransform(merge, ids_to_transform); - } else if (const Resize* resize = dynamic_cast(expr)) { + } else if (Resize* resize = dynamic_cast(expr)) { applyResizeTransform(resize, ids_to_transform); } else { NVF_ERROR(expr != nullptr); NVF_THROW("Unexpected expression: ", expr->toString()); } } - - std::unordered_map old2new; - for (auto idx : c10::irange((int64_t)ids_to_reorder.size())) { - auto orig_id = ids_to_reorder.at(idx); - auto find_it = std::find(ids_to_transform.begin(), ids_to_transform.end(), orig_id); - NVF_ERROR( - find_it != ids_to_transform.end(), - "Reordering map creation failed, uninitialized iterdomain, ", - orig_id->toString(), - " likely something is wrong with the transformations between the logical and loop domain."); - int64_t new_pos = (int64_t)std::distance(ids_to_transform.begin(), find_it); - old2new[idx] = new_pos; - } - return old2new; } -// Returns a map reordering the loop domain of the tensor view as the logical domain -std::unordered_map domainReorderAsLogicalMap(TensorView* tv) { +// Returns a permutation reordering the loop domain of the tensor view as the +// logical domain +std::vector domainReorderAsLogicalMap(TensorView* tv) { FusionGuard fg(tv->fusion()); - auto transform_exprs = StmtSort::getExprsTo( + auto transform_exprs = DependencyCheck::getAllExprsBetween( + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - auto ids_to_transform = 
tv->getLogicalDomain(); - return createReorderMapUnderTransforms(tv->getLoopDomain(), ids_to_transform, transform_exprs); + std::vector ids_to_transform = tv->getLogicalDomain(); + applyTransforms(ids_to_transform, transform_exprs); + std::optional> permutation = + ir_utils::computePermutation(ids_to_transform, tv->getLoopDomain()); + NVF_ERROR( + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + ids_to_transform); + return *permutation; } std::unordered_map maybeReorderAsAllocationMap( diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index e882324e034..67385a61648 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -649,22 +649,21 @@ DisjointSets disjointLogicalSets(Fusion* fusion); // [1, 0, 0] pos 1 would return true bool breakIsDisjoint(std::vector group_ids, int64_t pos); -// Transform the ids_to_transform as progressing through the transform_exprs -// and return a reorder map from ids_to_reorder to the transformed ids. -// This is used to reorder the loop domain as the logical or the allocation order. -std::unordered_map createReorderMapUnderTransforms( - const std::vector& ids_to_reorder, +// Update the vector of ids_to_transform as progressing through the +// `transform_exprs`. We'll always insert the result of split in the +// location of the input, and insert the merge result in the position of the +// inner dimension. +void applyTransforms( std::vector& ids_to_transform, const std::vector& transform_exprs); -// Generates an old to new map to reorder tv's domain as the logical order. +// Generates a permutation to reorder tv's domain as the logical order. 
// Priority is given to inner most dimensions for example: // logical [i0, i1, i2] // domain [i0*i2, i1] -// will produce the map {{0, 1}, {1, 0}} +// will produce the permutation {1, 0} // This is somewhat similar to orderTiledConcreteIdAsRoot -std::unordered_map domainReorderAsLogicalMap(TensorView* tv); - +std::vector domainReorderAsLogicalMap(TensorView* tv); // Generates an old to new map to reorder tv's loop domain as its allocation // order. This only handles the simple case where allocation is a permutation of From 6e6b1d02ae31db4a07e77afd9d301b64560ee261 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 10 Apr 2025 17:41:22 -0700 Subject: [PATCH 37/70] fix reorderLoopAsAllocation --- csrc/preseg_passes/propagate_shardings.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 9f66943a479..a1810c9b790 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -151,10 +151,12 @@ class PropagateShardingsSelector : public SetSelector { }; void reorderLoopAsAllocation(std::vector tvs) { - // Use maybeAllocationDomain to transform - // Transform using exprs between logical and loop and get the map. + // Transform the maybe allocation domain to the loop domain. + // using exprs between logical and loop and get the permutation required to + // reorder the loop domain in the same relative order as the allocation domain. for (auto tv : tvs) { auto alloc_dom = tv->getMaybeAllocationDomain(); + // Allocation domain should be a permutation of logical domain at this point. 
std::vector transform_exprs = DependencyCheck::getAllExprsBetween( {alloc_dom.begin(), alloc_dom.end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); @@ -164,11 +166,15 @@ void reorderLoopAsAllocation(std::vector tvs) { transform_exprs.end(), [](Expr* expr) { return expr->isA(); }), "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); - auto reorder_map = scheduler_utils::createReorderMapUnderTransforms( - /*ids_to_reorder=*/tv->getLoopDomain(), - /*ids_to_transform=*/alloc_dom, - /*transform_exprs=*/transform_exprs); - tv->reorder(reorder_map); + scheduler_utils::applyTransforms(alloc_dom, transform_exprs); + std::optional> permutation = ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); + NVF_ERROR( + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + alloc_dom); + tv->reorder(permutation.value()); } } From cdc8d5d9d84369682a7db37a55ac6798805ce0a4 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 10 Apr 2025 18:28:24 -0700 Subject: [PATCH 38/70] set allocation --- csrc/preseg_passes/propagate_shardings.cpp | 102 +++++++++------------ 1 file changed, 43 insertions(+), 59 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index a1810c9b790..9a0fd5cdea9 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -19,34 +19,22 @@ namespace nvfuser::preseg_passes { namespace { -void validateMeshes(Fusion* fusion) { +// Validates meshes and returns true if any TensorView has a device mesh. +bool validateMeshes(Fusion* fusion) { // Validate that meshes are assigned to all TensorViews or none. 
- TensorView* tv_with_mesh = nullptr; - TensorView* tv_without_mesh = nullptr; - for (TensorView* tv : fusion->allTvs()) { - auto update_if_null = [](TensorView*& lhs, TensorView* rhs) { - if (lhs == nullptr) { - lhs = rhs; - } - }; - + bool tv_with_mesh_found = false; + bool tv_without_mesh_found = false; + + for (auto tv : fusion->allTvs()) { if (tv->isCpuScalar()) { continue; } - - if (tv->hasDeviceMesh()) { - update_if_null(tv_with_mesh, tv); - } else { - update_if_null(tv_without_mesh, tv); - } + tv->hasDeviceMesh() ? tv_with_mesh_found = true : tv_without_mesh_found = true; } NVF_CHECK( - tv_with_mesh == nullptr || tv_without_mesh == nullptr, - "Found ", - tv_with_mesh, - " assigned a mesh and ", - tv_without_mesh, - " not."); + !(tv_with_mesh_found && tv_without_mesh_found), + "Cannot have some TensorViews with device mesh and some without."); + return tv_with_mesh_found; } int64_t numDeviceDims(TensorView* tv) { @@ -279,21 +267,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Apply parallelization on the outputs without mesh. shardAllLike(ref_input, outputs_without_mesh); - - // Undo the reordering of the DID axis in ref_input so it is in the correct order - // again. - ref_input->reorder(new2old); } - - // Reorder the loop domain since the transform propagator may - // have reordered the iterdomains in loop domain. For example: Consider - // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After - // transformation, the loop domain of linear output is [DIDx(d), n/d, b, - // m, r{k}]. Since, we set allocation to be the same as loop, we reorder it - // as allocation domain in the interim. Ideally, this should follow logical - // domain and DIDx axis at the front. The allocation domain should follow - // any stride order specified/inferred. - reorderLoopAsAllocation(outputs_without_mesh); } // Back-propagate device meshes. 
This makes sure all TensorViews have a mesh @@ -327,7 +301,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // modify their sharding. For non-fusion inputs, we try to propagate shardings // from the reference output for parallel types that are not already present. const auto& inputs = ir_utils::filterByType(expr->inputs()); - std::vector inputs_to_shard; + std::vector unsharded_inputs; for (auto* tv : inputs) { if (tv->isFusionInput()) { if (!tv->hasDeviceMesh()) { @@ -335,35 +309,45 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } continue; } - inputs_to_shard.push_back(tv); + if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { + unsharded_inputs.push_back(tv); + } } - if (inputs_to_shard.empty()) { + if (unsharded_inputs.empty()) { continue; } - // Each input can have different shardings, so attempt to propagate independently. - for (auto tv : inputs_to_shard) { - std::unordered_set existing_parallel_types = getTvParallelTypes({tv}); - std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, existing_parallel_types); - int64_t did_pos = new2old.size(); - - // Note: We do not have to manually shard for reshape here. - // TransformPropagator can handle reshapes when going from consumer to - // producer. - propagateDIDTransform( - /*ref=*/ref_output, - /*tvs=*/{tv}, - /*did_pos=*/did_pos, - /*allow_c2p=*/true, - /*allow_p2c=*/false); - ref_output->reorder(new2old); - } - shardAllLike(ref_output, inputs_to_shard); - reorderLoopAsAllocation(inputs_to_shard); + std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, {}); + int64_t did_pos = new2old.size(); + + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. 
+ propagateDIDTransform( + /*ref=*/ref_output, + /*tvs=*/unsharded_inputs, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, + /*allow_p2c=*/false); + shardAllLike(ref_output, unsharded_inputs); } - validateMeshes(fusion); + bool has_mesh = validateMeshes(fusion); + if (has_mesh) { + // Reorder the loop domain since the transform propagator may + // have reordered the iterdomains in loop domain. For example: Consider + // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After + // transformation, the loop domain of linear output is [DIDx(d), n/d, b, + // m, r{k}]. Since, we set allocation to be the same as loop, we reorder it + // as allocation domain in the interim. Ideally, this should follow logical + // domain and DIDx axis at the front. The allocation domain should follow + // any stride order specified/inferred. + reorderLoopAsAllocation(fusion->allTvs()); + for (auto tv : fusion->allTvs()) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + } } } // namespace nvfuser::preseg_passes From 07fbcb77c6c162e4318f4c65b31400703bbefcc7 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 10:49:44 -0700 Subject: [PATCH 39/70] shard each input individually in backprop --- csrc/multidevice/utils.cpp | 12 ++++++-- csrc/multidevice/utils.h | 2 +- csrc/preseg_passes/pre_segmenter.cpp | 13 ++++---- csrc/preseg_passes/propagate_shardings.cpp | 36 ++++++++++++---------- 4 files changed, 36 insertions(+), 27 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 593dd8793cc..3d5e25625e6 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -571,13 +571,19 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs) { +void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set existing_parallel_types) { for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } if (!tvs.empty()) { - scheduler_utils::parallelizeAllLike( - 
ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); + std::unordered_set parallel_types; + parallel_types.insert(ParallelType::Serial); + for (auto pt : kParallelTypeDIDs) { + if (!existing_parallel_types.count(pt)) { + parallel_types.insert(pt); + } + } + scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); } // parallelAllLke, tries to DID-parallelize diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index db560097230..6ed9ba8e49d 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -62,7 +62,7 @@ bool haveDifferentShardings( bool isInnerResharding(Expr* expr); // Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs); +void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set existing_parallel_types={}); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 042f03191f7..1449a9bfc8a 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -39,12 +39,6 @@ namespace nvfuser::preseg_passes { debug() << "========================================" << std::endl; } - // For resharding across GPUs. - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - // Replace TensorViews with zero extent. Outputs and inputs may still be empty OptimizationPass::runPass(fusion); // This pass should be placed before ConsecutiveCastPass as more @@ -81,6 +75,13 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); + + // For resharding across GPUs. 
+ OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 9a0fd5cdea9..281c17f08ea 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -266,7 +266,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { propagateDIDTransform(/*ref=*/ref_input, /*tvs=*/outputs_without_mesh, /*did_pos=*/num_device_dims, /*allow_c2p=*/false, /*allow_p2c=*/true); // Apply parallelization on the outputs without mesh. - shardAllLike(ref_input, outputs_without_mesh); + shardAllLike(ref_input, outputs_without_mesh, existing_parallel_types); } } @@ -301,7 +301,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // modify their sharding. For non-fusion inputs, we try to propagate shardings // from the reference output for parallel types that are not already present. const auto& inputs = ir_utils::filterByType(expr->inputs()); - std::vector unsharded_inputs; + std::vector sharding_candidates; for (auto* tv : inputs) { if (tv->isFusionInput()) { if (!tv->hasDeviceMesh()) { @@ -310,27 +310,29 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { continue; } if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { - unsharded_inputs.push_back(tv); + sharding_candidates.push_back(tv); } } - if (unsharded_inputs.empty()) { + if (sharding_candidates.empty()) { continue; } - std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, {}); - int64_t did_pos = new2old.size(); - - // Note: We do not have to manually shard for reshape here. - // TransformPropagator can handle reshapes when going from consumer to - // producer. 
- propagateDIDTransform( - /*ref=*/ref_output, - /*tvs=*/unsharded_inputs, - /*did_pos=*/did_pos, - /*allow_c2p=*/true, - /*allow_p2c=*/false); - shardAllLike(ref_output, unsharded_inputs); + for (auto tv : sharding_candidates) { + std::unordered_set existing_parallel_types = getTvParallelTypes({tv}); + std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, existing_parallel_types); + int64_t did_pos = new2old.size(); + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. + propagateDIDTransform( + /*ref=*/ref_output, + /*tvs=*/{tv}, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, + /*allow_p2c=*/false); + shardAllLike(ref_output, {tv}, existing_parallel_types); + } } bool has_mesh = validateMeshes(fusion); From 7956dacb554ba9bf502b36090a6bea0c9126ab99 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 11:04:23 -0700 Subject: [PATCH 40/70] cleanup --- csrc/multidevice/utils.cpp | 16 ++----- csrc/multidevice/utils.h | 6 +-- csrc/preseg_passes/propagate_shardings.cpp | 56 ++++++++++------------ 3 files changed, 30 insertions(+), 48 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 6341f516858..8da9d2afb1b 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -344,8 +344,6 @@ std::pair computeLoopIndex( return id_to_index.at(id); } -} // namespace - std::vector getInputsInTargetDomain( IterDomain* loop_id, const std::vector& target_domain) { @@ -362,6 +360,8 @@ std::vector getInputsInTargetDomain( return inputs_as_iter_domains; } +} // namespace + bool haveDifferentShardings( const TensorView* producer, const TensorView* consumer) { @@ -595,19 +595,13 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set existing_parallel_types) { +void shardAllLike(TensorView* ref, std::vector tvs) { for (auto tv : tvs) 
{ tv->setDeviceMesh(ref->getDeviceMesh()); } if (!tvs.empty()) { - std::unordered_set parallel_types; - parallel_types.insert(ParallelType::Serial); - for (auto pt : kParallelTypeDIDs) { - if (!existing_parallel_types.count(pt)) { - parallel_types.insert(pt); - } - } - scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); + scheduler_utils::parallelizeAllLike( + ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); } } diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 6ed9ba8e49d..34c510ccb2e 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -39,10 +39,6 @@ bool isSharded(const TensorView*); // Returns number of device dimensions in a TensorView's loop domain. int64_t numDeviceDims(const TensorView*); -std::vector getInputsInTargetDomain( - IterDomain* loop_id, - const std::vector& target_domain); - // Returns the subset of tvs which elements have the different multi-device // sharding as ref std::unordered_set getTvsWithDifferentSharding( @@ -62,7 +58,7 @@ bool haveDifferentShardings( bool isInnerResharding(Expr* expr); // Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set existing_parallel_types={}); +void shardAllLike(TensorView* ref, std::vector tvs); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 281c17f08ea..85878c835ff 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -19,7 +19,7 @@ namespace nvfuser::preseg_passes { namespace { -// Validates meshes and returns true if any TensorView has a device mesh. +// Validates meshes (i.e. all TensorViews have a device mesh or none) and returns true if any TensorView has a device mesh. 
bool validateMeshes(Fusion* fusion) { // Validate that meshes are assigned to all TensorViews or none. bool tv_with_mesh_found = false; @@ -37,17 +37,6 @@ bool validateMeshes(Fusion* fusion) { return tv_with_mesh_found; } -int64_t numDeviceDims(TensorView* tv) { - return std::count_if( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - std::mem_fn(&IterDomain::isDeviceDim)); -} - -// Sort the given tvs by the number of device dimensions in descending order. -// Break ties by the total number of dimensions. -// Only includes TensorViews that have a device mesh. - template std::vector filterTvsWithMesh(const Range& tvs) { std::vector tvs_with_mesh; @@ -59,13 +48,23 @@ std::vector filterTvsWithMesh(const Range& tvs) { return tvs_with_mesh; } +// Sort the given tvs by the number of device dimensions in descending order. +// Break ties by the total number of dimensions. +// Only includes TensorViews that have a device mesh. template std::vector sortTvsByDeviceDims(const Range& tvs) { // Filter out TVs without a device mesh std::vector tvs_with_mesh = filterTvsWithMesh(tvs); + auto numDeviceDims = [](TensorView* tv) -> int64_t { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + std::mem_fn(&IterDomain::isDeviceDim)); + }; + // Then sort the filtered TVs - std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { + std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), [&numDeviceDims](auto a, auto b) { int64_t a_device_dims = numDeviceDims(a); int64_t b_device_dims = numDeviceDims(b); if (a_device_dims != b_device_dims) { @@ -168,26 +167,23 @@ void reorderLoopAsAllocation(std::vector tvs) { // Reorder the DID axis to the front only if it does not have a parallel type // already seen on the output (existing_parallel_types). -// Returns a map from the new position to the old position to undo the -// reordering later. 
-std::unordered_map selectiveReorderDIDToFront( +// Returns the number of device dimensions that were reordered to the front. +int64_t selectiveReorderDIDToFront( TensorView* tv, std::unordered_set existing_parallel_types) { std::unordered_map old2new; - std::unordered_map new2old; int64_t current_pos = 0; for (auto pos : c10::irange(tv->nDims())) { if (tv->axis(pos)->isDeviceDim() && !existing_parallel_types.count(tv->axis(pos)->getParallelType())) { old2new[pos] = current_pos; - new2old[current_pos] = pos; current_pos++; } } tv->reorder(old2new); - return new2old; + return current_pos; } // Returns the set of parallel types seen on the loop domain of the given tvs. @@ -254,16 +250,15 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { " has no device mesh."); // Reorder the DID axis to the front only if it does not have a parallel - // type already seen on the output. + // type already seen on the outputs. std::unordered_set existing_parallel_types = getTvParallelTypes(outputs_without_mesh); - std::unordered_map new2old = + + // This restricts the transform propagation to only the relevant DID axis. + int64_t did_pos = selectiveReorderDIDToFront(ref_input, existing_parallel_types); - // This restricts the transform propagation to the DID axis. - int64_t num_device_dims = new2old.size(); - // Propagate the DID loop split to the outputs without mesh. - propagateDIDTransform(/*ref=*/ref_input, /*tvs=*/outputs_without_mesh, /*did_pos=*/num_device_dims, /*allow_c2p=*/false, /*allow_p2c=*/true); + propagateDIDTransform(/*ref=*/ref_input, /*tvs=*/outputs_without_mesh, /*did_pos=*/did_pos, /*allow_c2p=*/false, /*allow_p2c=*/true); // Apply parallelization on the outputs without mesh. 
shardAllLike(ref_input, outputs_without_mesh, existing_parallel_types); @@ -280,11 +275,11 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { Expr* expr = *i_expr; const auto& outputs = ir_utils::filterByType(expr->outputs()); - std::vector sorted_outputs = sortTvsByDeviceDims(outputs); // All outputs of an expression (Welford, SDPA) should be uniformly sharded. // We pick the most parallel output as the reference. // This is to avoid picking seed/offset tvs in SDPA. - + std::vector sorted_outputs = sortTvsByDeviceDims(outputs); + if (sorted_outputs.empty()) { // No output with a device mesh. continue; @@ -309,9 +304,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } continue; } - if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { - sharding_candidates.push_back(tv); - } + sharding_candidates.push_back(tv); } if (sharding_candidates.empty()) { @@ -320,8 +313,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { for (auto tv : sharding_candidates) { std::unordered_set existing_parallel_types = getTvParallelTypes({tv}); - std::unordered_map new2old = selectiveReorderDIDToFront(ref_output, existing_parallel_types); - int64_t did_pos = new2old.size(); + int64_t did_pos = selectiveReorderDIDToFront(ref_output, existing_parallel_types); // Note: We do not have to manually shard for reshape here. // TransformPropagator can handle reshapes when going from consumer to // producer. 
From c271701bac1abfbd16c246d0da6765c8e5e4cbcb Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 11:37:09 -0700 Subject: [PATCH 41/70] rm reshape tests --- tests/cpp/test_multidevice_preseg_passes.cpp | 113 ------------------- 1 file changed, 113 deletions(-) diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 40286dd1aa5..4f66782673d 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -59,112 +59,6 @@ TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); NVF_CHECK(tv1->hasDeviceMesh()); NVF_CHECK(getShardedLogicalAxis(tv1, ParallelType::DIDx) == getShardedLogicalAxis(tv0, ParallelType::DIDx), "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); - // Set the allocation domain explicitly until the preseg pass is fixed. - for (auto* tv : {tv0, tv1, tv2}) { - reorderDIDToFront(tv); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - at::Tensor inp0 = at::randn({d, 4}, tensor_options); - at::Tensor sharded_inp0 = shardTensor(inp0, 0, mesh); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto nvf_out = - executor_cache.runFusionWithInputs({sharded_inp0}); - for (auto& out : nvf_out) { - EXPECT_THAT(out.as().sizes(), ElementsAre(1, 4)); - } -} - -TEST_F(MultiDevicePresegPassesTest, DISABLED_MultipleTransformReshape) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 4; - - TensorView* tv0 = makeContigConcreteTensor({d * b, s, h * e}); - TensorView* tv1 = reshape(tv0, {d * b, s, h * e}, {d * b * s * h, e}); - fusion->addInput(tv0); - fusion->addOutput(tv1); - - auto mesh = DeviceMesh::createForNumDevices(d); - tv0->setDeviceMesh(mesh); - tv0->split(0, d, /*inner_split=*/false); - 
tv0->axis(0)->parallelize(ParallelType::DIDx); - - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - for (auto* tv : {tv0, tv1}) { - reorderDIDToFront(tv); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - NVF_CHECK(getShardedLogicalAxis(tv1, ParallelType::DIDx) == 0); - at::Tensor inp = at::randn({d * b, s, h * e}, tensor_options); - at::Tensor sharded_inp = shardTensor(inp, 0, mesh); - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor nvf_out = - executor_cache.runFusionWithInputs({sharded_inp})[0].as(); - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_inp}, - {sharded_inp.view({b * s * h, e})}, - __LINE__, - __FILE__); -} - -TEST_F(MultiDevicePresegPassesTest, DISABLED_SliceReshapePermute) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 128, a = 8; - - auto mesh = DeviceMesh::createForNumDevices(d); - - TensorView* tv0 = makeConcreteTensor({b, s, 3 * d * h}); - TensorView* tv1 = reshape(tv0, {b, s, 3 * d * h}, {b, s, d * a, 3 * h / a}); - TensorView* tv2 = slice(tv1, {0, 0, 0, 0}, {b, s, d * a, h / a}); - TensorView* tv3 = permute(tv2, {0, 2, 1, 3}); - - fusion->addInput(tv0); - fusion->addOutput(tv3); - - tv0->setDeviceMesh(mesh); - tv0->split(-1, d, /*inner_split=*/false); - tv0->axis(-2)->parallelize(ParallelType::DIDx); - - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - - for (auto* tv : fusion->allTvs()) { - reorderDIDToFront(tv); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp = at::randn({b, s, 3 * d * h}, tensor_options); - at::Tensor sharded_inp = shardTensor(inp, -1, mesh); - at::Tensor nvf_out = - executor_cache.runFusionWithInputs({sharded_inp})[0].as(); - - at::Tensor reference_out = sharded_inp.view({b, s, a, 3 * h / a}) 
- .index( - {at::indexing::Slice(0), - at::indexing::Slice(0), - at::indexing::Slice(0), - at::indexing::Slice(0, h / a)}) - .transpose(1, 2); - - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_inp}, - {reference_out}, - __LINE__, - __FILE__); } namespace { @@ -221,13 +115,6 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { qkv->outer_split(2, d); qkv->axis(2)->parallelize(ParallelType::DIDx); - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - for (auto tv : fusion->allTvs()) { - reorderDIDToFront(tv); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor unsharded_inp_tensor = at::randn({b, s, d * a, 3 * h / a}, tensor_options.dtype(at::kHalf)); at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); From c61f5480cf4f3f74bbabeb65ad10d9c95b2b10ff Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 11:41:02 -0700 Subject: [PATCH 42/70] extraneous change --- tests/cpp/test_multidevice_sharding.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index bdda0192b39..5a26d5d6622 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -774,6 +774,7 @@ TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { // Loop split and parallelize input tv0->setDeviceMesh(mesh); + tv1->setDeviceMesh(mesh); tv0->split(-2, d, /*inner_split=*/false); tv0->axis(-3)->parallelize(ParallelType::DIDx); // in: loop domain: {b, s, DIDx{d}, h, e} From 61dbdc6fcd2326413caa8c170cf1655155e476c3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 12:15:08 -0700 Subject: [PATCH 43/70] shardAllLike --- csrc/multidevice/utils.cpp | 13 +++++++++---- csrc/multidevice/utils.h | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git 
a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 8da9d2afb1b..8732766ac95 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -595,15 +595,20 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs) { +void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set excluded_parallel_types) { for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } if (!tvs.empty()) { - scheduler_utils::parallelizeAllLike( - ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); + std::unordered_set parallel_types; + parallel_types.insert(ParallelType::Serial); + for (auto pt : kParallelTypeDIDs) { + if (!excluded_parallel_types.count(pt)) { + parallel_types.insert(pt); + } + } + scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); } -} void shardBetween( const std::vector& from, diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 34c510ccb2e..3ae5aebefac 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -57,8 +57,8 @@ bool haveDifferentShardings( // Returns whether a resharding expr reshards an inner axis bool isInnerResharding(Expr* expr); -// Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs); +// Shards all tensors in tvs like reference except for the parallel types in excluded_parallel_types +void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set excluded_parallel_types={}); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. 
This is required for (1) expressions like rng_uniform that create a From 24331ad4893c100192371a3339483ed81475374b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 15:49:37 -0700 Subject: [PATCH 44/70] fix build error, test without serial --- csrc/multidevice/utils.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 8732766ac95..0d33e0f35da 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -601,7 +601,7 @@ void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_ } if (!tvs.empty()) { std::unordered_set parallel_types; - parallel_types.insert(ParallelType::Serial); + // parallel_types.insert(ParallelType::Serial); for (auto pt : kParallelTypeDIDs) { if (!excluded_parallel_types.count(pt)) { parallel_types.insert(pt); @@ -609,6 +609,7 @@ void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_ } scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); } +} void shardBetween( const std::vector& from, From 3f146b6678dd6bbb036854fd2dfae85ba21436b3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 16:02:41 -0700 Subject: [PATCH 45/70] test without changing preseg order --- csrc/preseg_passes/pre_segmenter.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 1449a9bfc8a..042f03191f7 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -39,6 +39,12 @@ namespace nvfuser::preseg_passes { debug() << "========================================" << std::endl; } + // For resharding across GPUs. + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + // Replace TensorViews with zero extent. 
Outputs and inputs may still be empty OptimizationPass::runPass(fusion); // This pass should be placed before ConsecutiveCastPass as more @@ -75,13 +81,6 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); - - // For resharding across GPUs. - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); From 916a693033c318169b5e17c82a3a8bbccf8dbee8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 18:27:19 -0700 Subject: [PATCH 46/70] include parallel type serial --- csrc/multidevice/utils.cpp | 2 +- csrc/preseg_passes/propagate_shardings.cpp | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 0d33e0f35da..ce446f4cc3a 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -601,7 +601,7 @@ void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_ } if (!tvs.empty()) { std::unordered_set parallel_types; - // parallel_types.insert(ParallelType::Serial); + parallel_types.insert(ParallelType::Serial); for (auto pt : kParallelTypeDIDs) { if (!excluded_parallel_types.count(pt)) { parallel_types.insert(pt); diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 85878c835ff..e2923ca6c15 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -337,10 +337,14 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // as allocation domain in the interim. Ideally, this should follow logical // domain and DIDx axis at the front. The allocation domain should follow // any stride order specified/inferred. 
+ // debug() << "reorderLoopAsAllocation\n"; + // debug() << "before: " << fusion->toString() << "\n"; reorderLoopAsAllocation(fusion->allTvs()); - for (auto tv : fusion->allTvs()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); - } + + // for (auto tv : fusion->allTvs()) { + // tv->setAllocationDomain(tv->getLoopDomain(), true); + // } + // debug() << "after: " << fusion->toString() << "\n"; } } From c0e96cfddffc79f1d8eebfe9e3550b485b241fb9 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 18:31:35 -0700 Subject: [PATCH 47/70] undo debug changes --- csrc/preseg_passes/propagate_shardings.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index e2923ca6c15..7132177517d 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -337,14 +337,11 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // as allocation domain in the interim. Ideally, this should follow logical // domain and DIDx axis at the front. The allocation domain should follow // any stride order specified/inferred. 
- // debug() << "reorderLoopAsAllocation\n"; - // debug() << "before: " << fusion->toString() << "\n"; reorderLoopAsAllocation(fusion->allTvs()); - // for (auto tv : fusion->allTvs()) { - // tv->setAllocationDomain(tv->getLoopDomain(), true); - // } - // debug() << "after: " << fusion->toString() << "\n"; + for (auto tv : fusion->allTvs()) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } } } From 632f74390bd2e17f49305270c88c371141fe5b55 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 20:00:08 -0700 Subject: [PATCH 48/70] propagate only for nonsharded inputs --- csrc/preseg_passes/propagate_shardings.cpp | 25 ++++++++++---------- tests/cpp/test_multidevice_preseg_passes.cpp | 9 +------ 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 7132177517d..c44d79ea7d5 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -48,6 +48,13 @@ std::vector filterTvsWithMesh(const Range& tvs) { return tvs_with_mesh; } +int64_t numDeviceDims(TensorView* tv) { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + std::mem_fn(&IterDomain::isDeviceDim)); +}; + // Sort the given tvs by the number of device dimensions in descending order. // Break ties by the total number of dimensions. // Only includes TensorViews that have a device mesh. 
@@ -56,15 +63,8 @@ std::vector sortTvsByDeviceDims(const Range& tvs) { // Filter out TVs without a device mesh std::vector tvs_with_mesh = filterTvsWithMesh(tvs); - auto numDeviceDims = [](TensorView* tv) -> int64_t { - return std::count_if( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - std::mem_fn(&IterDomain::isDeviceDim)); - }; - // Then sort the filtered TVs - std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), [&numDeviceDims](auto a, auto b) { + std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { int64_t a_device_dims = numDeviceDims(a); int64_t b_device_dims = numDeviceDims(b); if (a_device_dims != b_device_dims) { @@ -261,7 +261,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { propagateDIDTransform(/*ref=*/ref_input, /*tvs=*/outputs_without_mesh, /*did_pos=*/did_pos, /*allow_c2p=*/false, /*allow_p2c=*/true); // Apply parallelization on the outputs without mesh. - shardAllLike(ref_input, outputs_without_mesh, existing_parallel_types); + shardAllLike(ref_input, outputs_without_mesh); } } @@ -304,7 +304,9 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } continue; } - sharding_candidates.push_back(tv); + if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { + sharding_candidates.push_back(tv); + } } if (sharding_candidates.empty()) { @@ -323,7 +325,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { /*did_pos=*/did_pos, /*allow_c2p=*/true, /*allow_p2c=*/false); - shardAllLike(ref_output, {tv}, existing_parallel_types); + shardAllLike(ref_output, {tv}); } } @@ -338,7 +340,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // domain and DIDx axis at the front. The allocation domain should follow // any stride order specified/inferred. 
reorderLoopAsAllocation(fusion->allTvs()); - for (auto tv : fusion->allTvs()) { tv->setAllocationDomain(tv->getLoopDomain(), true); } diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 4f66782673d..7e2f1a001fd 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -124,14 +124,7 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { at::Tensor nvf_out = outputs[0].as(); at::Tensor ref_out = reference_mha(inp_tensor); - - testValidate( - executor_cache.fusion(), - {nvf_out}, - {inp_tensor}, - {ref_out}, - __LINE__, - __FILE__); + EXPECT_TRUE(at::allclose(nvf_out, ref_out)); } } // namespace nvfuser From 2bb46243b9c2a557bda0165992e2c2fe861fd9b0 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 11 Apr 2025 20:18:48 -0700 Subject: [PATCH 49/70] revert to only propagating for unsharded inputs --- csrc/preseg_passes/propagate_shardings.cpp | 25 ++++++++++------------ 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index c44d79ea7d5..32bd89cd533 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -313,20 +313,17 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { continue; } - for (auto tv : sharding_candidates) { - std::unordered_set existing_parallel_types = getTvParallelTypes({tv}); - int64_t did_pos = selectiveReorderDIDToFront(ref_output, existing_parallel_types); - // Note: We do not have to manually shard for reshape here. - // TransformPropagator can handle reshapes when going from consumer to - // producer. 
- propagateDIDTransform( - /*ref=*/ref_output, - /*tvs=*/{tv}, - /*did_pos=*/did_pos, - /*allow_c2p=*/true, - /*allow_p2c=*/false); - shardAllLike(ref_output, {tv}); - } + int64_t did_pos = selectiveReorderDIDToFront(ref_output, {}); + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. + propagateDIDTransform( + /*ref=*/ref_output, + /*tvs=*/sharding_candidates, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, + /*allow_p2c=*/false); + shardAllLike(ref_output, sharding_candidates); } bool has_mesh = validateMeshes(fusion); From ad931e463f7d5379e0ed2c898786d09278f4b8a3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 14 Apr 2025 13:46:07 -0700 Subject: [PATCH 50/70] derive contiguity through transforms --- csrc/preseg_passes/propagate_shardings.cpp | 86 ++++++++++++++------ tests/cpp/test_multidevice_preseg_passes.cpp | 57 ++++++++++++- 2 files changed, 113 insertions(+), 30 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 32bd89cd533..5d18b24ef5d 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -137,32 +137,66 @@ class PropagateShardingsSelector : public SetSelector { } }; -void reorderLoopAsAllocation(std::vector tvs) { - // Transform the maybe allocation domain to the loop domain. - // using exprs between logical and loop and get the permutation required to - // reorder the loop domain in the same relative order as the allocation domain. - for (auto tv : tvs) { - auto alloc_dom = tv->getMaybeAllocationDomain(); - // Allocation domain should be a permutation of logical domain at this point. 
- std::vector transform_exprs = DependencyCheck::getAllExprsBetween( - {alloc_dom.begin(), alloc_dom.end()}, - {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - NVF_ERROR( - std::all_of( - transform_exprs.begin(), - transform_exprs.end(), - [](Expr* expr) { return expr->isA(); }), - "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); - scheduler_utils::applyTransforms(alloc_dom, transform_exprs); - std::optional> permutation = ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); +// Transform the maybe allocation domain to the loop domain. +// using exprs between logical and loop and get the permutation required to +// reorder the loop domain in the same relative order as the allocation domain. +// Returns the contiguity of the transformed allocation domain. +std::vector> reorderLoopAsAllocation(TensorView* tv) { + auto alloc_dom = tv->getMaybeAllocationDomain(); + auto contiguity = tv->getContiguity(); + + auto splitContiguity = [](std::optional contiguity) -> std::pair, std::optional>{ + if (!contiguity.has_value()) { + return std::make_pair(std::nullopt, std::nullopt); + } + if (contiguity.value()) { + return std::make_pair(true, true); + } + return std::make_pair(true, false); + }; + + // Allocation domain should be a permutation of logical domain at this point. 
+ std::vector transform_exprs = DependencyCheck::getAllExprsBetween( + {alloc_dom.begin(), alloc_dom.end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + + NVF_ERROR( + std::all_of( + transform_exprs.begin(), + transform_exprs.end(), + [](Expr* expr) { return expr->isA(); }), + "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); + + for (auto* expr: transform_exprs) { + Split* split = dynamic_cast(expr); + auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in()); NVF_ERROR( - permutation.has_value(), - "Failed to find a valid permutation for reordering", - tv->getLoopDomain(), - " as ", - alloc_dom); - tv->reorder(permutation.value()); + find_it != alloc_dom.end(), + "Split input ", + split->in()->toString(), + " not found in given ids: ", + alloc_dom); + + auto pos = std::distance(alloc_dom.begin(), find_it); + auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity.at(pos)); + + alloc_dom[pos] = split->inner(); + alloc_dom.insert(alloc_dom.begin() + pos, split->outer()); + + contiguity[pos] = inner_contiguity; + contiguity.insert(contiguity.begin() + pos, outer_contiguity); } + + std::optional> permutation = ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); + NVF_ERROR( + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + alloc_dom); + tv->reorder(permutation.value()); + + return contiguity; } // Reorder the DID axis to the front only if it does not have a parallel type @@ -336,9 +370,9 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // as allocation domain in the interim. Ideally, this should follow logical // domain and DIDx axis at the front. The allocation domain should follow // any stride order specified/inferred. 
- reorderLoopAsAllocation(fusion->allTvs()); for (auto tv : fusion->allTvs()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); + auto contiguity = reorderLoopAsAllocation(tv); + tv->setAllocationDomain(tv->getLoopDomain(), contiguity); } } } diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 7e2f1a001fd..54540fc5acb 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -20,7 +20,7 @@ namespace nvfuser { -constexpr int64_t b = 2, s = 3, h = 128, a = 8; +constexpr int64_t b = 2, s = 3, h = 64, a = 8; constexpr double dropout_p = 0.0; constexpr bool is_causal = false; @@ -62,9 +62,15 @@ TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { } namespace { +at::Tensor reference_mlp(at::Tensor inp, at::Tensor w0, at::Tensor w1) { + auto linear0 = at::linear(inp, w0); + auto gelu = at::gelu(linear0, "tanh"); + auto linear1 = at::linear(gelu, w1); + return linear1; +} + at::Tensor reference_mha(at::Tensor inp) { - auto qkv = - inp.view({b, s, a, 3 * h / a}).transpose(1, 2).split(h / a, -1); + auto qkv = inp.transpose(1, 2).split(h / a, -1); double scale = 1.0 / std::sqrt(h / a); auto sdpa_out = at::_scaled_dot_product_flash_attention( qkv[0], @@ -79,6 +85,49 @@ at::Tensor reference_mha(at::Tensor inp) { } } // namespace +TEST_F(MultiDevicePresegPassesTest, MLP) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* inp = makeContigConcreteTensor({b, s, h}); + TensorView* w0 = makeContigConcreteTensor({4*d*h, h}); + TensorView* w1 = makeContigConcreteTensor({h, 4*d*h}); + + TensorView* linear0 = linear(inp, w0); + TensorView* gelu = tanh_gelu(linear0); + TensorView* linear1 = linear(gelu, w1); + + std::vector fusion_inputs {inp, w0, w1}; + for (auto tv: fusion_inputs){ + fusion->addInput(tv); + tv->setDeviceMesh(mesh); + } + 
fusion->addOutput(linear1); + + w0->outer_split(0, d); + w0->axis(0)->parallelize(ParallelType::DIDx); + w1->outer_split(1, d); + w1->axis(1)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp_tensor = at::randn({b, s, h}, tensor_options.dtype(at::kFloat)); + at::Tensor w0_tensor = at::randn({4*d*h, h}, tensor_options.dtype(at::kFloat)); + at::Tensor w1_tensor = at::randn({h, 4*d*h}, tensor_options.dtype(at::kFloat)); + + at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh); + at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh); + + KernelArgumentHolder args = {inp_tensor, w0_sharded, w1_sharded}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + + at::Tensor ref_out = reference_mlp(inp_tensor, w0_tensor, w1_tensor); + EXPECT_TRUE(at::allclose(nvf_out, ref_out)); +} + TEST_F(MultiDevicePresegPassesTest, MHAFwd) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -88,7 +137,7 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { auto mesh = DeviceMesh::createForNumDevices(d); - TensorView* qkv = makeConcreteTensor({b, s, d * a, 3 * h / a}, DataType::Half); + TensorView* qkv = makeContigConcreteTensor({b, s, d * a, 3 * h / a}, DataType::Half); TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d * a, h / a}); TensorView* k = slice(qkv, {0, 0, 0, h / a}, {b, s, d * a, 2 * h / a}); TensorView* v = slice(qkv, {0, 0, 0, 2 * h / a}, {b, s, d * a, 3 * h / a}); From 3757b3cd30a5be80aa7412e4bcfc2ecbc4532524 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 14 Apr 2025 17:06:48 -0700 Subject: [PATCH 51/70] fix test --- tests/cpp/test_multidevice_preseg_passes.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index 54540fc5acb..f3df9c31528 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ 
b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -20,7 +20,7 @@ namespace nvfuser { -constexpr int64_t b = 2, s = 3, h = 64, a = 8; +constexpr int64_t b = 2, s = 3, h = 16, a = 2; constexpr double dropout_p = 0.0; constexpr bool is_causal = false; @@ -133,7 +133,6 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { FusionGuard fg(fusion.get()); const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 128, a = 8; auto mesh = DeviceMesh::createForNumDevices(d); From bd945e3325d8c8bea84c11f37859fbddf9d53fd8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 14 Apr 2025 21:35:01 -0700 Subject: [PATCH 52/70] set allocation domain last --- .../make_resharding_contiguous.cpp | 135 +++++++++++++++--- csrc/preseg_passes/mark_aliases_prepare.cpp | 3 + csrc/preseg_passes/pre_segmenter.cpp | 10 +- csrc/preseg_passes/propagate_shardings.cpp | 15 -- 4 files changed, 127 insertions(+), 36 deletions(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 04fbe0d7173..5f6f14c078b 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -12,36 +12,135 @@ #include #include #include +#include namespace nvfuser::preseg_passes { namespace { -void setShardedAllocationDomain(TensorView* tv) { - if (!tv->hasAllocation()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); + + +// Validates meshes (i.e. all TensorViews have a device mesh or none) and returns true if any TensorView has a device mesh. +bool validateMeshes(Fusion* fusion) { + // Validate that meshes are assigned to all TensorViews or none. + bool tv_with_mesh_found = false; + bool tv_without_mesh_found = false; + + for (auto tv : fusion->allTvs()) { + if (tv->isCpuScalar()) { + continue; + } + tv->hasDeviceMesh() ? 
tv_with_mesh_found = true : tv_without_mesh_found = true; + } + NVF_CHECK( + !(tv_with_mesh_found && tv_without_mesh_found), + "Cannot have some TensorViews with device mesh and some without."); + return tv_with_mesh_found; +} + +// Transform the maybe allocation domain to the loop domain. +// using exprs between logical and loop and get the permutation required to +// reorder the loop domain in the same relative order as the allocation domain. +// Returns the contiguity of the transformed allocation domain. +std::vector> reorderLoopAsAllocation(TensorView* tv) { + auto alloc_dom = tv->getMaybeAllocationDomain(); + auto contiguity = tv->getContiguity(); + + auto splitContiguity = [](std::optional contiguity) -> std::pair, std::optional>{ + if (!contiguity.has_value()) { + return std::make_pair(std::nullopt, std::nullopt); + } + if (contiguity.value()) { + return std::make_pair(true, true); + } + return std::make_pair(true, false); + }; + + // Allocation domain should be a permutation of logical domain at this point. 
+ std::vector transform_exprs = DependencyCheck::getAllExprsBetween( + {alloc_dom.begin(), alloc_dom.end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + + NVF_ERROR( + std::all_of( + transform_exprs.begin(), + transform_exprs.end(), + [](Expr* expr) { return expr->isA(); }), + "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); + + for (auto* expr: transform_exprs) { + Split* split = dynamic_cast(expr); + auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in()); + NVF_ERROR( + find_it != alloc_dom.end(), + "Split input ", + split->in()->toString(), + " not found in given ids: ", + alloc_dom); + + auto pos = std::distance(alloc_dom.begin(), find_it); + auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity.at(pos)); + + alloc_dom[pos] = split->inner(); + alloc_dom.insert(alloc_dom.begin() + pos, split->outer()); + + contiguity[pos] = inner_contiguity; + contiguity.insert(contiguity.begin() + pos, outer_contiguity); + } + + std::optional> permutation = ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); + NVF_ERROR( + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + alloc_dom); + tv->reorder(permutation.value()); + + return contiguity; +} + +bool isTvContiguous(TensorView* tv) { + return std::all_of( + tv->getContiguity().begin(), + tv->getContiguity().end(), + [](std::optional c) { + return c.value_or(true); + } + ); +} + +template +void setShardedAllocationDomain(Range tvs) { + for (auto tv: tvs) { + auto contiguity = reorderLoopAsAllocation(tv); + tv->setAllocationDomain(tv->getLoopDomain(), contiguity); } } + } // namespace void MakeReshardingContiguousPass::runPass(Fusion* fusion) { + bool has_mesh = validateMeshes(fusion); + if (!has_mesh) { + return; + } + for (Expr* expr : fusion->exprs()) { - if (!isResharding(expr)) { - continue; - } - for (auto* tv : 
ir_utils::filterByType(expr->inputs())) { - for (auto c : tv->getContiguity()) { - if (c.has_value()) { - NVF_CHECK( - c.value(), - "Resharding expression input must be contiguous: ", - expr); + auto inputs = ir_utils::filterByType(expr->inputs()); + auto outputs = ir_utils::filterByType(expr->outputs()); + + if (isResharding(expr)) { + NVF_CHECK(std::all_of( + inputs.begin(), + inputs.end(), + [](TensorView* tv) { + return isTvContiguous(tv); } - } - setShardedAllocationDomain(tv); - } - for (auto tv : ir_utils::filterByType(expr->outputs())) { - setShardedAllocationDomain(tv); + ), "Resharding expression inputs must be contiguous: ", expr); } + + setShardedAllocationDomain(inputs); + setShardedAllocationDomain(outputs); } } diff --git a/csrc/preseg_passes/mark_aliases_prepare.cpp b/csrc/preseg_passes/mark_aliases_prepare.cpp index 227d1e3784e..b7136e15854 100644 --- a/csrc/preseg_passes/mark_aliases_prepare.cpp +++ b/csrc/preseg_passes/mark_aliases_prepare.cpp @@ -13,6 +13,7 @@ #include #include #include +#include namespace nvfuser::preseg_passes { @@ -110,12 +111,14 @@ void insertSegmentSetAfter( copy->setAllocationDomain( replayed_domain->allocation(), replayed_domain->contiguity()); } + copy->setLoopDomain(replayed_domain->loop()); std::for_each(first_user, last_user, [&](const Use& use) { ir_utils::replaceValInExprInputs(use.user, use_of, copy); }); if (use_of->isFusionOutput()) { use_of->fusion()->replaceOutput(use_of, copy); } + shardAllLike(use_of, {copy}); } } // namespace diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 042f03191f7..2f8106ea28b 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -41,9 +41,8 @@ namespace nvfuser::preseg_passes { // For resharding across GPUs. OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); + + // Replace TensorViews with zero extent. 
Outputs and inputs may still be empty OptimizationPass::runPass(fusion); @@ -81,6 +80,11 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); + + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 5d18b24ef5d..7b4ccd5ec75 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -360,21 +360,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { shardAllLike(ref_output, sharding_candidates); } - bool has_mesh = validateMeshes(fusion); - if (has_mesh) { - // Reorder the loop domain since the transform propagator may - // have reordered the iterdomains in loop domain. For example: Consider - // linear op: in = [b, m, k] weight = [DIDx(d), n/d, k] After - // transformation, the loop domain of linear output is [DIDx(d), n/d, b, - // m, r{k}]. Since, we set allocation to be the same as loop, we reorder it - // as allocation domain in the interim. Ideally, this should follow logical - // domain and DIDx axis at the front. The allocation domain should follow - // any stride order specified/inferred. 
- for (auto tv : fusion->allTvs()) { - auto contiguity = reorderLoopAsAllocation(tv); - tv->setAllocationDomain(tv->getLoopDomain(), contiguity); - } - } } } // namespace nvfuser::preseg_passes From a40cf3d2ae3f9b63841cef4bafc8676958843b4c Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 15 Apr 2025 14:26:52 -0700 Subject: [PATCH 53/70] undo shardAllLike changes --- csrc/multidevice/utils.cpp | 50 +++++++++++++++++++++++++++++++------- csrc/multidevice/utils.h | 4 +-- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index ce446f4cc3a..475d852028b 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -380,6 +381,43 @@ bool haveDifferentShardings( return true; } + // Special handling of SelectOp for a quick fix + // TODO: work on a proper implementation + if (consumer->definition()->isA()) { + auto* select_op = consumer->definition()->as(); + NVF_ERROR( + select_op->input(0) == producer, "SelectOp input 0 is not producer"); + // If we select into the sharded axis, the op is resharding because the + // axis doesn't exist in the consumer and so becomes "replicated". + // + // tv0 = makeContigTensor(2); // [DIDx(4), 8] on mesh {0,1,2,3} + // tv1 = select(tv0, /*axis=*/0, /*index=*/1); // [8] on mesh {0,1,2,3} + // + // The long term better solution would actually to "select" into the + // DeviceMesh, e.g., + // + // tv0 = makeContigTensor(2); // [DIDx(4), 8] on mesh {0,1,2,3} + // tv1 = select(tv0, /*axis=*/0, /*index=*/1); // [8] on mesh {1} + // But for achieving this with symbolic "index" we need to make DeviceMesh + // symbolic. + if (select_op->getIndexedID()->isDeviceDim()) { + return true; + } + // If the sharded axis is not selected into, then we still need to check + // that other axis do not get resharded. 
+ const std::unordered_map& c2p = + PairwiseLogicalDomainMap(producer, consumer) + .mapBroadcast(false) + .mapConsumerToProducer(); + return !std::all_of( + consumer->getLoopDomain().begin(), + consumer->getLoopDomain().end(), + [&c2p](IterDomain* c_id) { + auto p_id = c2p.at(c_id); + return c_id->isDeviceDim() == p_id->isDeviceDim(); + }); + } + // The rest of this function tries to do the following: for each pair of // logical-domain-mapped IterDomains (i.e. those mapped by // PairwiseLogicalDomainMap), check if they are sharded consistently. If not, @@ -595,19 +633,13 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set excluded_parallel_types) { +void shardAllLike(TensorView* ref, std::vector tvs) { for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } if (!tvs.empty()) { - std::unordered_set parallel_types; - parallel_types.insert(ParallelType::Serial); - for (auto pt : kParallelTypeDIDs) { - if (!excluded_parallel_types.count(pt)) { - parallel_types.insert(pt); - } - } - scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); + scheduler_utils::parallelizeAllLike( + ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); } } diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 3ae5aebefac..34c510ccb2e 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -57,8 +57,8 @@ bool haveDifferentShardings( // Returns whether a resharding expr reshards an inner axis bool isInnerResharding(Expr* expr); -// Shards all tensors in tvs like reference except for the parallel types in excluded_parallel_types -void shardAllLike(TensorView* ref, std::vector tvs, std::unordered_set excluded_parallel_types={}); +// Shards all tensors in tvs like reference +void shardAllLike(TensorView* ref, std::vector tvs); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. 
This is required for (1) expressions like rng_uniform that create a From 3b7e77ff4e82ef91ba52c31f34c7689b53fc8bb5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 15 Apr 2025 14:28:07 -0700 Subject: [PATCH 54/70] move allocation domain to makeReshardingContiguous --- .../make_resharding_contiguous.cpp | 1 - .../make_resharding_contiguous.h | 15 +++- csrc/preseg_passes/mark_aliases_prepare.cpp | 2 - csrc/preseg_passes/pre_segmenter.cpp | 9 +-- csrc/preseg_passes/propagate_shardings.cpp | 80 ------------------- 5 files changed, 14 insertions(+), 93 deletions(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 5f6f14c078b..44efd5ae799 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -18,7 +18,6 @@ namespace nvfuser::preseg_passes { namespace { - // Validates meshes (i.e. all TensorViews have a device mesh or none) and returns true if any TensorView has a device mesh. bool validateMeshes(Fusion* fusion) { // Validate that meshes are assigned to all TensorViews or none. diff --git a/csrc/preseg_passes/make_resharding_contiguous.h b/csrc/preseg_passes/make_resharding_contiguous.h index 60ded24f76d..d9c7d002db0 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.h +++ b/csrc/preseg_passes/make_resharding_contiguous.h @@ -15,11 +15,18 @@ namespace nvfuser::preseg_passes { -// Resharding expressions are mapped to collective libraries which expect +// This pass: +// 1. Validates that all TensorViews have a device mesh or none. +// 2. Resharding expressions are mapped to collective libraries which expect // contiguous tensors and output contiguous buffers. This pass checks that -// inputs are contiguous and sets the allocation domain of inputs and outputs of -// all resharding expressions. This pass should run after all passes that add or -// update resharding expressions. +// inputs are contiguous. +// 3. 
Sets the allocation domain of all fusion tvs if they have a device mesh. +// The allocation domain is obtained by transforming the `maybeAllocationDomain` using +// the transforms to loop domain. This ensures that the allocation domain has DID loop splits. +// All iterdomains derived from a given logical iterdomain are placed together. +// See `reorderLoopAsAllocation` for more details. +// Eventually, this pass should run after `markAliasesPrepare` and `AllocationDomainPass` +// after they are fixed. class MakeReshardingContiguousPass : public OptimizationPass { friend class OptimizationPass; diff --git a/csrc/preseg_passes/mark_aliases_prepare.cpp b/csrc/preseg_passes/mark_aliases_prepare.cpp index b7136e15854..b2176d16695 100644 --- a/csrc/preseg_passes/mark_aliases_prepare.cpp +++ b/csrc/preseg_passes/mark_aliases_prepare.cpp @@ -111,14 +111,12 @@ void insertSegmentSetAfter( copy->setAllocationDomain( replayed_domain->allocation(), replayed_domain->contiguity()); } - copy->setLoopDomain(replayed_domain->loop()); std::for_each(first_user, last_user, [&](const Use& use) { ir_utils::replaceValInExprInputs(use.user, use_of, copy); }); if (use_of->isFusionOutput()) { use_of->fusion()->replaceOutput(use_of, copy); } - shardAllLike(use_of, {copy}); } } // namespace diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 2f8106ea28b..2cb32b80c82 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -41,8 +41,9 @@ namespace nvfuser::preseg_passes { // For resharding across GPUs. OptimizationPass::runPass(fusion); - - + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); // Replace TensorViews with zero extent. 
Outputs and inputs may still be empty OptimizationPass::runPass(fusion); @@ -81,10 +82,6 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 7b4ccd5ec75..a670a82fd05 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -19,23 +19,6 @@ namespace nvfuser::preseg_passes { namespace { -// Validates meshes (i.e. all TensorViews have a device mesh or none) and returns true if any TensorView has a device mesh. -bool validateMeshes(Fusion* fusion) { - // Validate that meshes are assigned to all TensorViews or none. - bool tv_with_mesh_found = false; - bool tv_without_mesh_found = false; - - for (auto tv : fusion->allTvs()) { - if (tv->isCpuScalar()) { - continue; - } - tv->hasDeviceMesh() ? tv_with_mesh_found = true : tv_without_mesh_found = true; - } - NVF_CHECK( - !(tv_with_mesh_found && tv_without_mesh_found), - "Cannot have some TensorViews with device mesh and some without."); - return tv_with_mesh_found; -} template std::vector filterTvsWithMesh(const Range& tvs) { @@ -137,68 +120,6 @@ class PropagateShardingsSelector : public SetSelector { } }; -// Transform the maybe allocation domain to the loop domain. -// using exprs between logical and loop and get the permutation required to -// reorder the loop domain in the same relative order as the allocation domain. -// Returns the contiguity of the transformed allocation domain. 
-std::vector> reorderLoopAsAllocation(TensorView* tv) { - auto alloc_dom = tv->getMaybeAllocationDomain(); - auto contiguity = tv->getContiguity(); - - auto splitContiguity = [](std::optional contiguity) -> std::pair, std::optional>{ - if (!contiguity.has_value()) { - return std::make_pair(std::nullopt, std::nullopt); - } - if (contiguity.value()) { - return std::make_pair(true, true); - } - return std::make_pair(true, false); - }; - - // Allocation domain should be a permutation of logical domain at this point. - std::vector transform_exprs = DependencyCheck::getAllExprsBetween( - {alloc_dom.begin(), alloc_dom.end()}, - {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - - NVF_ERROR( - std::all_of( - transform_exprs.begin(), - transform_exprs.end(), - [](Expr* expr) { return expr->isA(); }), - "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); - - for (auto* expr: transform_exprs) { - Split* split = dynamic_cast(expr); - auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in()); - NVF_ERROR( - find_it != alloc_dom.end(), - "Split input ", - split->in()->toString(), - " not found in given ids: ", - alloc_dom); - - auto pos = std::distance(alloc_dom.begin(), find_it); - auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity.at(pos)); - - alloc_dom[pos] = split->inner(); - alloc_dom.insert(alloc_dom.begin() + pos, split->outer()); - - contiguity[pos] = inner_contiguity; - contiguity.insert(contiguity.begin() + pos, outer_contiguity); - } - - std::optional> permutation = ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); - NVF_ERROR( - permutation.has_value(), - "Failed to find a valid permutation for reordering", - tv->getLoopDomain(), - " as ", - alloc_dom); - tv->reorder(permutation.value()); - - return contiguity; -} - // Reorder the DID axis to the front only if it does not have a parallel type // already seen on the output (existing_parallel_types). 
// Returns the number of device dimensions that were reordered to the front. @@ -359,7 +280,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { /*allow_p2c=*/false); shardAllLike(ref_output, sharding_candidates); } - } } // namespace nvfuser::preseg_passes From 044cbc8144f48f83df2a2549b1ae531adaf722eb Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 15 Apr 2025 18:10:17 -0700 Subject: [PATCH 55/70] undo changes --- csrc/preseg_passes/pre_segmenter.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 2cb32b80c82..042f03191f7 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -43,7 +43,7 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); // Replace TensorViews with zero extent. 
Outputs and inputs may still be empty OptimizationPass::runPass(fusion); @@ -81,7 +81,6 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); From 29de7ff463a5a1ee6cb1af21d72fb735c657a3ee Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 15 Apr 2025 18:14:58 -0700 Subject: [PATCH 56/70] lintrunner --- .../make_resharding_contiguous.cpp | 62 ++++++++++--------- .../make_resharding_contiguous.h | 12 ++-- csrc/preseg_passes/mark_aliases_prepare.cpp | 1 - csrc/preseg_passes/propagate_shardings.cpp | 46 ++++++++------ tests/cpp/test_multidevice_preseg_passes.cpp | 30 +++++---- 5 files changed, 85 insertions(+), 66 deletions(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 44efd5ae799..b9345417ad6 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -18,17 +18,19 @@ namespace nvfuser::preseg_passes { namespace { -// Validates meshes (i.e. all TensorViews have a device mesh or none) and returns true if any TensorView has a device mesh. +// Validates meshes (i.e. all TensorViews have a device mesh or none) and +// returns true if any TensorView has a device mesh. bool validateMeshes(Fusion* fusion) { // Validate that meshes are assigned to all TensorViews or none. bool tv_with_mesh_found = false; bool tv_without_mesh_found = false; - + for (auto tv : fusion->allTvs()) { if (tv->isCpuScalar()) { continue; } - tv->hasDeviceMesh() ? tv_with_mesh_found = true : tv_without_mesh_found = true; + tv->hasDeviceMesh() ? 
tv_with_mesh_found = true + : tv_without_mesh_found = true; } NVF_CHECK( !(tv_with_mesh_found && tv_without_mesh_found), @@ -44,7 +46,8 @@ std::vector> reorderLoopAsAllocation(TensorView* tv) { auto alloc_dom = tv->getMaybeAllocationDomain(); auto contiguity = tv->getContiguity(); - auto splitContiguity = [](std::optional contiguity) -> std::pair, std::optional>{ + auto splitContiguity = [](std::optional contiguity) + -> std::pair, std::optional> { if (!contiguity.has_value()) { return std::make_pair(std::nullopt, std::nullopt); } @@ -53,20 +56,20 @@ std::vector> reorderLoopAsAllocation(TensorView* tv) { } return std::make_pair(true, false); }; - + // Allocation domain should be a permutation of logical domain at this point. std::vector transform_exprs = DependencyCheck::getAllExprsBetween( {alloc_dom.begin(), alloc_dom.end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - + NVF_ERROR( std::all_of( transform_exprs.begin(), transform_exprs.end(), [](Expr* expr) { return expr->isA(); }), "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); - - for (auto* expr: transform_exprs) { + + for (auto* expr : transform_exprs) { Split* split = dynamic_cast(expr); auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in()); NVF_ERROR( @@ -77,8 +80,9 @@ std::vector> reorderLoopAsAllocation(TensorView* tv) { alloc_dom); auto pos = std::distance(alloc_dom.begin(), find_it); - auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity.at(pos)); - + auto [outer_contiguity, inner_contiguity] = + splitContiguity(contiguity.at(pos)); + alloc_dom[pos] = split->inner(); alloc_dom.insert(alloc_dom.begin() + pos, split->outer()); @@ -86,13 +90,14 @@ std::vector> reorderLoopAsAllocation(TensorView* tv) { contiguity.insert(contiguity.begin() + pos, outer_contiguity); } - std::optional> permutation = ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); + std::optional> permutation = + 
ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); NVF_ERROR( - permutation.has_value(), - "Failed to find a valid permutation for reordering", - tv->getLoopDomain(), - " as ", - alloc_dom); + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + alloc_dom); tv->reorder(permutation.value()); return contiguity; @@ -100,17 +105,14 @@ std::vector> reorderLoopAsAllocation(TensorView* tv) { bool isTvContiguous(TensorView* tv) { return std::all_of( - tv->getContiguity().begin(), - tv->getContiguity().end(), - [](std::optional c) { - return c.value_or(true); - } - ); + tv->getContiguity().begin(), + tv->getContiguity().end(), + [](std::optional c) { return c.value_or(true); }); } template void setShardedAllocationDomain(Range tvs) { - for (auto tv: tvs) { + for (auto tv : tvs) { auto contiguity = reorderLoopAsAllocation(tv); tv->setAllocationDomain(tv->getLoopDomain(), contiguity); } @@ -129,13 +131,13 @@ void MakeReshardingContiguousPass::runPass(Fusion* fusion) { auto outputs = ir_utils::filterByType(expr->outputs()); if (isResharding(expr)) { - NVF_CHECK(std::all_of( - inputs.begin(), - inputs.end(), - [](TensorView* tv) { - return isTvContiguous(tv); - } - ), "Resharding expression inputs must be contiguous: ", expr); + NVF_CHECK( + std::all_of( + inputs.begin(), + inputs.end(), + [](TensorView* tv) { return isTvContiguous(tv); }), + "Resharding expression inputs must be contiguous: ", + expr); } setShardedAllocationDomain(inputs); diff --git a/csrc/preseg_passes/make_resharding_contiguous.h b/csrc/preseg_passes/make_resharding_contiguous.h index d9c7d002db0..b8ee5475d8a 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.h +++ b/csrc/preseg_passes/make_resharding_contiguous.h @@ -21,12 +21,12 @@ namespace nvfuser::preseg_passes { // contiguous tensors and output contiguous buffers. This pass checks that // inputs are contiguous. // 3. 
Sets the allocation domain of all fusion tvs if they have a device mesh. -// The allocation domain is obtained by transforming the `maybeAllocationDomain` using -// the transforms to loop domain. This ensures that the allocation domain has DID loop splits. -// All iterdomains derived from a given logical iterdomain are placed together. -// See `reorderLoopAsAllocation` for more details. -// Eventually, this pass should run after `markAliasesPrepare` and `AllocationDomainPass` -// after they are fixed. +// The allocation domain is obtained by transforming the `maybeAllocationDomain` +// using the transforms to loop domain. This ensures that the allocation domain +// has DID loop splits. All iterdomains derived from a given logical iterdomain +// are placed together. See `reorderLoopAsAllocation` for more details. +// Eventually, this pass should run after `markAliasesPrepare` and +// `AllocationDomainPass` after they are fixed. class MakeReshardingContiguousPass : public OptimizationPass { friend class OptimizationPass; diff --git a/csrc/preseg_passes/mark_aliases_prepare.cpp b/csrc/preseg_passes/mark_aliases_prepare.cpp index b2176d16695..227d1e3784e 100644 --- a/csrc/preseg_passes/mark_aliases_prepare.cpp +++ b/csrc/preseg_passes/mark_aliases_prepare.cpp @@ -13,7 +13,6 @@ #include #include #include -#include namespace nvfuser::preseg_passes { diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index a670a82fd05..f2f56d0b813 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -142,7 +142,8 @@ int64_t selectiveReorderDIDToFront( } // Returns the set of parallel types seen on the loop domain of the given tvs. 
-std::unordered_set getTvParallelTypes(std::vector tvs) { +std::unordered_set getTvParallelTypes( + std::vector tvs) { std::unordered_set parallel_types; for (auto tv : tvs) { for (auto id : tv->getLoopDomain()) { @@ -154,14 +155,16 @@ std::unordered_set getTvParallelTypes(std::vector tvs return parallel_types; } -void propagateDIDTransform(TensorView* ref, std::vector tvs, int64_t did_pos, bool allow_c2p, bool allow_p2c) { +void propagateDIDTransform( + TensorView* ref, + std::vector tvs, + int64_t did_pos, + bool allow_c2p, + bool allow_p2c) { TransformPropagator propagator(ref, did_pos); PropagateShardingsSelector selector( - {tvs.begin(), tvs.end()}, - allow_c2p, - allow_p2c); - MaxLogicalDomainInfoSpanningTree(ref, &selector) - .traverse(&propagator); + {tvs.begin(), tvs.end()}, allow_c2p, allow_p2c); + MaxLogicalDomainInfoSpanningTree(ref, &selector).traverse(&propagator); } } // namespace @@ -206,14 +209,20 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Reorder the DID axis to the front only if it does not have a parallel // type already seen on the outputs. - std::unordered_set existing_parallel_types = getTvParallelTypes(outputs_without_mesh); - + std::unordered_set existing_parallel_types = + getTvParallelTypes(outputs_without_mesh); + // This restricts the transform propagation to only the relevant DID axis. int64_t did_pos = selectiveReorderDIDToFront(ref_input, existing_parallel_types); // Propagate the DID loop split to the outputs without mesh. - propagateDIDTransform(/*ref=*/ref_input, /*tvs=*/outputs_without_mesh, /*did_pos=*/did_pos, /*allow_c2p=*/false, /*allow_p2c=*/true); + propagateDIDTransform( + /*ref=*/ref_input, + /*tvs=*/outputs_without_mesh, + /*did_pos=*/did_pos, + /*allow_c2p=*/false, + /*allow_p2c=*/true); // Apply parallelization on the outputs without mesh. 
shardAllLike(ref_input, outputs_without_mesh); @@ -234,7 +243,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // We pick the most parallel output as the reference. // This is to avoid picking seed/offset tvs in SDPA. std::vector sorted_outputs = sortTvsByDeviceDims(outputs); - + if (sorted_outputs.empty()) { // No output with a device mesh. continue; @@ -248,8 +257,9 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { " has no device mesh."); // For fusion inputs, only check if they have a device mesh. We do not - // modify their sharding. For non-fusion inputs, we try to propagate shardings - // from the reference output for parallel types that are not already present. + // modify their sharding. For non-fusion inputs, we try to propagate + // shardings from the reference output for parallel types that are not + // already present. const auto& inputs = ir_utils::filterByType(expr->inputs()); std::vector sharding_candidates; for (auto* tv : inputs) { @@ -273,11 +283,11 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // TransformPropagator can handle reshapes when going from consumer to // producer. 
propagateDIDTransform( - /*ref=*/ref_output, - /*tvs=*/sharding_candidates, - /*did_pos=*/did_pos, - /*allow_c2p=*/true, - /*allow_p2c=*/false); + /*ref=*/ref_output, + /*tvs=*/sharding_candidates, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, + /*allow_p2c=*/false); shardAllLike(ref_output, sharding_candidates); } } diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp b/tests/cpp/test_multidevice_preseg_passes.cpp index f3df9c31528..c9494fc64fe 100644 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ b/tests/cpp/test_multidevice_preseg_passes.cpp @@ -58,7 +58,10 @@ TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); NVF_CHECK(tv1->hasDeviceMesh()); - NVF_CHECK(getShardedLogicalAxis(tv1, ParallelType::DIDx) == getShardedLogicalAxis(tv0, ParallelType::DIDx), "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); + NVF_CHECK( + getShardedLogicalAxis(tv1, ParallelType::DIDx) == + getShardedLogicalAxis(tv0, ParallelType::DIDx), + "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); } namespace { @@ -93,15 +96,15 @@ TEST_F(MultiDevicePresegPassesTest, MLP) { auto mesh = DeviceMesh::createForNumDevices(d); TensorView* inp = makeContigConcreteTensor({b, s, h}); - TensorView* w0 = makeContigConcreteTensor({4*d*h, h}); - TensorView* w1 = makeContigConcreteTensor({h, 4*d*h}); + TensorView* w0 = makeContigConcreteTensor({4 * d * h, h}); + TensorView* w1 = makeContigConcreteTensor({h, 4 * d * h}); TensorView* linear0 = linear(inp, w0); TensorView* gelu = tanh_gelu(linear0); TensorView* linear1 = linear(gelu, w1); - std::vector fusion_inputs {inp, w0, w1}; - for (auto tv: fusion_inputs){ + std::vector fusion_inputs{inp, w0, w1}; + for (auto tv : fusion_inputs) { fusion->addInput(tv); tv->setDeviceMesh(mesh); } @@ -113,9 +116,12 @@ TEST_F(MultiDevicePresegPassesTest, MLP) { w1->axis(1)->parallelize(ParallelType::DIDx); 
FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp_tensor = at::randn({b, s, h}, tensor_options.dtype(at::kFloat)); - at::Tensor w0_tensor = at::randn({4*d*h, h}, tensor_options.dtype(at::kFloat)); - at::Tensor w1_tensor = at::randn({h, 4*d*h}, tensor_options.dtype(at::kFloat)); + at::Tensor inp_tensor = + at::randn({b, s, h}, tensor_options.dtype(at::kFloat)); + at::Tensor w0_tensor = + at::randn({4 * d * h, h}, tensor_options.dtype(at::kFloat)); + at::Tensor w1_tensor = + at::randn({h, 4 * d * h}, tensor_options.dtype(at::kFloat)); at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh); at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh); @@ -136,7 +142,8 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { auto mesh = DeviceMesh::createForNumDevices(d); - TensorView* qkv = makeContigConcreteTensor({b, s, d * a, 3 * h / a}, DataType::Half); + TensorView* qkv = + makeContigConcreteTensor({b, s, d * a, 3 * h / a}, DataType::Half); TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d * a, h / a}); TensorView* k = slice(qkv, {0, 0, 0, h / a}, {b, s, d * a, 2 * h / a}); TensorView* v = slice(qkv, {0, 0, 0, 2 * h / a}, {b, s, d * a, 3 * h / a}); @@ -164,7 +171,8 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { qkv->axis(2)->parallelize(ParallelType::DIDx); FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor unsharded_inp_tensor = at::randn({b, s, d * a, 3 * h / a}, tensor_options.dtype(at::kHalf)); + at::Tensor unsharded_inp_tensor = + at::randn({b, s, d * a, 3 * h / a}, tensor_options.dtype(at::kHalf)); at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); KernelArgumentHolder args = {inp_tensor}; @@ -174,5 +182,5 @@ TEST_F(MultiDevicePresegPassesTest, MHAFwd) { at::Tensor ref_out = reference_mha(inp_tensor); EXPECT_TRUE(at::allclose(nvf_out, ref_out)); } - + } // namespace nvfuser From af1c71d6a9722110d3b2065c08499a75ea8d399b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 15 Apr 2025 18:38:44 
-0700 Subject: [PATCH 57/70] modify tests --- tests/cpp/test_multidevice_matmul.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index b8aba89aa86..6ccaf5707fd 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -180,7 +180,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(1)); } TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { @@ -238,7 +238,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(1)); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { @@ -289,7 +289,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(1)); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) { From 61520553282c817373118c1b7064b57ae95c9ce9 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 16 Apr 2025 17:58:51 -0700 Subject: [PATCH 58/70] Update csrc/preseg_passes/make_resharding_contiguous.cpp Co-authored-by: Jingyue Wu --- csrc/preseg_passes/make_resharding_contiguous.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index b9345417ad6..98e3eada9f8 100644 --- 
a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -107,7 +107,7 @@ bool isTvContiguous(TensorView* tv) { return std::all_of( tv->getContiguity().begin(), tv->getContiguity().end(), - [](std::optional c) { return c.value_or(true); }); + [](const std::optional& c) { return c.value_or(true); }); } template From 7d54f588e469f3183d9dd6f314e020e0c4ba6a8d Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 16 Apr 2025 18:01:00 -0700 Subject: [PATCH 59/70] Update csrc/preseg_passes/make_resharding_contiguous.cpp Co-authored-by: Jingyue Wu --- csrc/preseg_passes/make_resharding_contiguous.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 98e3eada9f8..232aac15418 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -112,7 +112,7 @@ bool isTvContiguous(TensorView* tv) { template void setShardedAllocationDomain(Range tvs) { - for (auto tv : tvs) { + for (TensorView* tv : tvs) { auto contiguity = reorderLoopAsAllocation(tv); tv->setAllocationDomain(tv->getLoopDomain(), contiguity); } From 9f944846fa0b2a23f332721486c4839fd492ab17 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 18 Apr 2025 16:00:43 -0700 Subject: [PATCH 60/70] restore tests --- tests/cpp/test_multidevice_matmul.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 6ccaf5707fd..b8aba89aa86 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -180,7 +180,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - 
Contains(HeuristicIs(SchedulerType::ExprEval)).Times(1)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); } TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { @@ -238,7 +238,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(1)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { @@ -289,7 +289,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(1)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) { From 7c539a2ba0a1704a246e31b56525d7d1a4a55ac5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 18 Apr 2025 18:32:22 -0700 Subject: [PATCH 61/70] move tests --- CMakeLists.txt | 1 - tests/cpp/test_multidevice_preseg_passes.cpp | 186 ------------------- tests/cpp/test_multidevice_transformer.cpp | 137 ++++++++++++++ tests/cpp/test_sharding.cpp | 35 ++++ 4 files changed, 172 insertions(+), 187 deletions(-) delete mode 100644 tests/cpp/test_multidevice_preseg_passes.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d630466a24b..b9865da34a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -688,7 +688,6 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp - ${NVFUSER_ROOT}/tests/cpp/test_multidevice_preseg_passes.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp ) add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "") diff --git a/tests/cpp/test_multidevice_preseg_passes.cpp 
b/tests/cpp/test_multidevice_preseg_passes.cpp deleted file mode 100644 index c9494fc64fe..00000000000 --- a/tests/cpp/test_multidevice_preseg_passes.cpp +++ /dev/null @@ -1,186 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include "multidevice/utils.h" - -namespace nvfuser { - -constexpr int64_t b = 2, s = 3, h = 16, a = 2; -constexpr double dropout_p = 0.0; -constexpr bool is_causal = false; - -using testing::ElementsAre; -using MultiDevicePresegPassesTest = MultiDeviceTest; - -TEST_F(MultiDevicePresegPassesTest, ResidualAdd) { - // This is similar to the residual add after MHA dropout in the transformer. - // The output of linear following MHA is all-gathered and sharded on the - // sequence dim. This sharding can be propagated to the linear output through - // backpropagating the shardings from residual add. This information is not - // present during forward propagation. 
- auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - - TensorView* tv0 = makeContigConcreteTensor({d, 4}); - TensorView* tv1 = uniform( - shape(tv0), - fusion->zeroVal(DataType::Float), - fusion->oneVal(DataType::Float), - DataType::Float); - TensorView* tv2 = add(tv0, tv1); - - auto mesh = DeviceMesh::createForNumDevices(d); - tv0->setDeviceMesh(mesh); - tv0->split(0, d, /*inner_split=*/false); - tv0->axis(0)->parallelize(ParallelType::DIDx); - - fusion->addInput(tv0); - fusion->addOutput(tv1); - fusion->addOutput(tv2); - - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - NVF_CHECK(tv1->hasDeviceMesh()); - NVF_CHECK( - getShardedLogicalAxis(tv1, ParallelType::DIDx) == - getShardedLogicalAxis(tv0, ParallelType::DIDx), - "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); -} - -namespace { -at::Tensor reference_mlp(at::Tensor inp, at::Tensor w0, at::Tensor w1) { - auto linear0 = at::linear(inp, w0); - auto gelu = at::gelu(linear0, "tanh"); - auto linear1 = at::linear(gelu, w1); - return linear1; -} - -at::Tensor reference_mha(at::Tensor inp) { - auto qkv = inp.transpose(1, 2).split(h / a, -1); - double scale = 1.0 / std::sqrt(h / a); - auto sdpa_out = at::_scaled_dot_product_flash_attention( - qkv[0], - qkv[1], - qkv[2], - /*dropout_p=*/dropout_p, - /*is_causal=*/is_causal, - /*return_debug_mask=*/false, - scale); - auto attn = std::get<0>(sdpa_out); - return attn.transpose(1, 2); -} -} // namespace - -TEST_F(MultiDevicePresegPassesTest, MLP) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - auto mesh = DeviceMesh::createForNumDevices(d); - - TensorView* inp = makeContigConcreteTensor({b, s, h}); - TensorView* w0 = makeContigConcreteTensor({4 * d * h, h}); - TensorView* w1 = makeContigConcreteTensor({h, 4 * d * h}); - - TensorView* linear0 = linear(inp, w0); - TensorView* 
gelu = tanh_gelu(linear0); - TensorView* linear1 = linear(gelu, w1); - - std::vector fusion_inputs{inp, w0, w1}; - for (auto tv : fusion_inputs) { - fusion->addInput(tv); - tv->setDeviceMesh(mesh); - } - fusion->addOutput(linear1); - - w0->outer_split(0, d); - w0->axis(0)->parallelize(ParallelType::DIDx); - w1->outer_split(1, d); - w1->axis(1)->parallelize(ParallelType::DIDx); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor inp_tensor = - at::randn({b, s, h}, tensor_options.dtype(at::kFloat)); - at::Tensor w0_tensor = - at::randn({4 * d * h, h}, tensor_options.dtype(at::kFloat)); - at::Tensor w1_tensor = - at::randn({h, 4 * d * h}, tensor_options.dtype(at::kFloat)); - - at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh); - at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh); - - KernelArgumentHolder args = {inp_tensor, w0_sharded, w1_sharded}; - auto outputs = executor_cache.runFusionWithInputs(args); - at::Tensor nvf_out = outputs[0].as(); - - at::Tensor ref_out = reference_mlp(inp_tensor, w0_tensor, w1_tensor); - EXPECT_TRUE(at::allclose(nvf_out, ref_out)); -} - -TEST_F(MultiDevicePresegPassesTest, MHAFwd) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - - auto mesh = DeviceMesh::createForNumDevices(d); - - TensorView* qkv = - makeContigConcreteTensor({b, s, d * a, 3 * h / a}, DataType::Half); - TensorView* q = slice(qkv, {0, 0, 0, 0}, {b, s, d * a, h / a}); - TensorView* k = slice(qkv, {0, 0, 0, h / a}, {b, s, d * a, 2 * h / a}); - TensorView* v = slice(qkv, {0, 0, 0, 2 * h / a}, {b, s, d * a, 3 * h / a}); - - TensorView* q_permuted = permute(q, {0, 2, 1, 3}); - TensorView* k_permuted = permute(k, {0, 2, 1, 3}); - TensorView* v_permuted = permute(v, {0, 2, 1, 3}); - - SdpfaFwdResult sdpa_out = sdpfa_fwd( - q_permuted, - k_permuted, - v_permuted, - /*dropout_p=*/IrBuilder::create(dropout_p), - /*is_causal=*/IrBuilder::create(is_causal), - /*scale=*/nullptr); - - 
TensorView* attn = sdpa_out.output; - TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - - fusion->addInput(qkv); - fusion->addOutput(attn_permute); - - qkv->setDeviceMesh(mesh); - qkv->outer_split(2, d); - qkv->axis(2)->parallelize(ParallelType::DIDx); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor unsharded_inp_tensor = - at::randn({b, s, d * a, 3 * h / a}, tensor_options.dtype(at::kHalf)); - at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); - - KernelArgumentHolder args = {inp_tensor}; - auto outputs = executor_cache.runFusionWithInputs(args); - at::Tensor nvf_out = outputs[0].as(); - - at::Tensor ref_out = reference_mha(inp_tensor); - EXPECT_TRUE(at::allclose(nvf_out, ref_out)); -} - -} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 55f23bdc5a1..f122e4da73f 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -1016,6 +1016,143 @@ TEST_P(DistributedTransformerTest, Backward) { 0.02}); } +namespace { +at::Tensor reference_loop_split_mlp( + at::Tensor inp, + at::Tensor w0, + at::Tensor w1) { + auto linear0 = at::linear(inp, w0); + auto gelu = at::gelu(linear0, "tanh"); + auto linear1 = at::linear(gelu, w1); + return linear1; +} + +at::Tensor reference_loop_split_mha(at::Tensor inp) { + auto qkv = inp.transpose(1, 2).split(E / H, -1); + double scale = 1.0 / std::sqrt(E / H); + auto sdpa_out = at::_scaled_dot_product_flash_attention( + qkv[0], + qkv[1], + qkv[2], + /*dropout_p=*/kDropoutProb, + /*is_causal=*/true, + /*return_debug_mask=*/false, + scale); + auto attn = std::get<0>(sdpa_out); + return attn.transpose(1, 2); +} +} // namespace + +TEST_F(DistributedTransformerTest, LoopSplitMLP) { + if ((4 * E) % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide 4*E=" << 4 * E; + } + auto dtype = DataType::Float; + // auto dtype = GetParam(); + 
at::ScalarType at_dtype = data_type_to_aten(dtype); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* inp = makeContigConcreteTensor({B, S, E}, dtype); + TensorView* w0 = makeContigConcreteTensor({4 * E, E}, dtype); + TensorView* w1 = makeContigConcreteTensor({E, 4 * E}, dtype); + + TensorView* linear0 = linear(inp, w0); + TensorView* linear0_float = castOp(DataType::Float, linear0); + TensorView* gelu = tanh_gelu(linear0_float); + TensorView* gelu_dtype = castOp(dtype, gelu); + TensorView* linear1 = linear(gelu_dtype, w1); + + std::vector fusion_inputs{inp, w0, w1}; + for (auto tv : fusion_inputs) { + fusion->addInput(tv); + tv->setDeviceMesh(mesh); + } + fusion->addOutput(linear1); + + w0->outer_split(0, d); + w0->axis(0)->parallelize(ParallelType::DIDx); + w1->outer_split(1, d); + w1->axis(1)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp_tensor = at::randn({B, S, E}, tensor_options.dtype(at_dtype)); + at::Tensor w0_tensor = at::randn({4 * E, E}, tensor_options.dtype(at_dtype)); + at::Tensor w1_tensor = at::randn({E, 4 * E}, tensor_options.dtype(at_dtype)); + + at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh); + at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh); + + KernelArgumentHolder args = {inp_tensor, w0_sharded, w1_sharded}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + + at::Tensor ref_out = + reference_loop_split_mlp(inp_tensor, w0_tensor, w1_tensor); + validate({ref_out}, {nvf_out}, {0.02}); +} + +TEST_F(DistributedTransformerTest, LoopSplitMHAFwd) { + if (H % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide H=" << H; + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // auto dtype = GetParam(); + auto dtype = DataType::Half; + 
at::ScalarType at_dtype = data_type_to_aten(dtype); + + const int d = communicator_->size(); + + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* qkv = makeContigConcreteTensor({B, S, H, 3 * E / H}, dtype); + TensorView* q = slice(qkv, {0, 0, 0, 0}, {B, S, H, E / H}); + TensorView* k = slice(qkv, {0, 0, 0, E / H}, {B, S, H, 2 * E / H}); + TensorView* v = slice(qkv, {0, 0, 0, 2 * E / H}, {B, S, H, 3 * E / H}); + + TensorView* q_permuted = permute(q, {0, 2, 1, 3}); + TensorView* k_permuted = permute(k, {0, 2, 1, 3}); + TensorView* v_permuted = permute(v, {0, 2, 1, 3}); + + SdpfaFwdResult sdpa_out = sdpfa_fwd( + q_permuted, + k_permuted, + v_permuted, + /*dropout_p=*/IrBuilder::create(kDropoutProb), + /*is_causal=*/IrBuilder::create(true), + /*scale=*/nullptr); + + TensorView* attn = sdpa_out.output; + TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); + + fusion->addInput(qkv); + fusion->addOutput(attn_permute); + + qkv->setDeviceMesh(mesh); + qkv->outer_split(2, d); + qkv->axis(2)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor unsharded_inp_tensor = + at::randn({B, S, H, 3 * E / H}, tensor_options.dtype(at_dtype)); + at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); + + KernelArgumentHolder args = {inp_tensor}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + at::Tensor ref_out = reference_loop_split_mha(inp_tensor); + validate({ref_out}, {nvf_out}, {0.02}); +} + INSTANTIATE_TEST_SUITE_P( , DistributedTransformerTest, diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp index 1ce1d96d8d0..ffbbabb7402 100644 --- a/tests/cpp/test_sharding.cpp +++ b/tests/cpp/test_sharding.cpp @@ -234,6 +234,41 @@ TEST_F(ShardingTest, MultiDimDeviceMesh) { EXPECT_EQ(mesh3d.getSlice(18, ParallelType::DIDx), slice_didx); } +TEST_F(ShardingTest, ResidualAdd) { + // This is similar to the residual add after MHA dropout in the 
transformer. + // The output of linear following MHA is all-gathered and sharded on the + // sequence dim. This sharding can be propagated to the linear output through + // backpropagating the shardings from residual add. This information is not + // present during forward propagation. + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + DeviceMesh mesh({0, 1}); + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = uniform( + shape(tv0), + fusion->zeroVal(DataType::Float), + fusion->oneVal(DataType::Float), + DataType::Float); + TensorView* tv2 = add(tv0, tv1); + + tv0->setDeviceMesh(mesh); + tv0->outer_split(0, mesh.size()); + tv0->axis(0)->parallelize(ParallelType::DIDx); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + fusion->addOutput(tv2); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + NVF_CHECK(tv1->hasDeviceMesh()); + NVF_CHECK( + getShardedLogicalAxis(tv1, ParallelType::DIDx) == + getShardedLogicalAxis(tv0, ParallelType::DIDx), + "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); +} + INSTANTIATE_TEST_SUITE_P( , ShardingTest, From 11d3d0423a70feb1267bf2510888d6b8b23a0d7a Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 22 Apr 2025 13:42:00 -0700 Subject: [PATCH 62/70] clean tests --- tests/cpp/test_multidevice_transformer.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index f122e4da73f..025726698be 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -1043,13 +1043,15 @@ at::Tensor reference_loop_split_mha(at::Tensor inp) { } } // namespace +// TODO: Allow testing for float16 and bfloat16 for loop split mlp and mha +// This currently fails because privatizeUpcast clones cast operations, +// which fails segmentation since the transforms are not replicated. 
TEST_F(DistributedTransformerTest, LoopSplitMLP) { if ((4 * E) % D != 0) { GTEST_SKIP() << "Requires number of devices=" << D << " evenly divide 4*E=" << 4 * E; } auto dtype = DataType::Float; - // auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); auto fusion = std::make_unique(); @@ -1106,7 +1108,6 @@ TEST_F(DistributedTransformerTest, LoopSplitMHAFwd) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // auto dtype = GetParam(); auto dtype = DataType::Half; at::ScalarType at_dtype = data_type_to_aten(dtype); From 00d93fcb764c2b6afb4974e79997d0c5f2c7d4bd Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 22 Apr 2025 13:53:51 -0700 Subject: [PATCH 63/70] comment; --- .../make_resharding_contiguous.cpp | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 232aac15418..0bc15e7c691 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -38,11 +38,21 @@ bool validateMeshes(Fusion* fusion) { return tv_with_mesh_found; } -// Transform the maybe allocation domain to the loop domain. -// using exprs between logical and loop and get the permutation required to -// reorder the loop domain in the same relative order as the allocation domain. -// Returns the contiguity of the transformed allocation domain. -std::vector> reorderLoopAsAllocation(TensorView* tv) { +// Reorders the loop domain in the same relative order as the allocation domain. +// Specifically: +// 1. It uses the exprs between logical and loop domain to split the allocation +// domain +// 2. It reorders the loop domain to match the split allocation domain. +// 3. It computes the contiguity of the transformed allocation domain through +// the split exprs. +// 4. 
Sets the allocation domain to be the same as the loop domain with the +// computed contiguity. This preserves both the sharding and any stride order. +// Note: Ideally, the loop domain can follow the logical domain and the +// allocation domain can follow the stride order specified/inferred. However, we +// currently require loop domain to be the same as allocation domain. This +// behavior will be modified in the future with allocation and loop domain being +// propagated independently. +void setLoopAndAllocationDomain(TensorView* tv) { auto alloc_dom = tv->getMaybeAllocationDomain(); auto contiguity = tv->getContiguity(); @@ -99,8 +109,7 @@ std::vector> reorderLoopAsAllocation(TensorView* tv) { " as ", alloc_dom); tv->reorder(permutation.value()); - - return contiguity; + tv->setAllocationDomain(tv->getLoopDomain(), contiguity); } bool isTvContiguous(TensorView* tv) { @@ -110,14 +119,6 @@ bool isTvContiguous(TensorView* tv) { [](const std::optional& c) { return c.value_or(true); }); } -template -void setShardedAllocationDomain(Range tvs) { - for (TensorView* tv : tvs) { - auto contiguity = reorderLoopAsAllocation(tv); - tv->setAllocationDomain(tv->getLoopDomain(), contiguity); - } -} - } // namespace void MakeReshardingContiguousPass::runPass(Fusion* fusion) { @@ -139,9 +140,12 @@ void MakeReshardingContiguousPass::runPass(Fusion* fusion) { "Resharding expression inputs must be contiguous: ", expr); } - - setShardedAllocationDomain(inputs); - setShardedAllocationDomain(outputs); + for (auto tv : inputs) { + setLoopAndAllocationDomain(tv); + } + for (auto tv : outputs) { + setLoopAndAllocationDomain(tv); + } } } From bbb835e6df2a5e0b5abfac0fb86cbe9732871d60 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 22 Apr 2025 13:55:06 -0700 Subject: [PATCH 64/70] comment --- csrc/preseg_passes/make_resharding_contiguous.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.h 
b/csrc/preseg_passes/make_resharding_contiguous.h index b8ee5475d8a..8a719683004 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.h +++ b/csrc/preseg_passes/make_resharding_contiguous.h @@ -24,7 +24,7 @@ namespace nvfuser::preseg_passes { // The allocation domain is obtained by transforming the `maybeAllocationDomain` // using the transforms to loop domain. This ensures that the allocation domain // has DID loop splits. All iterdomains derived from a given logical iterdomain -// are placed together. See `reorderLoopAsAllocation` for more details. +// are placed together. See `setLoopAndAllocationDomain` for more details. // Eventually, this pass should run after `markAliasesPrepare` and // `AllocationDomainPass` after they are fixed. class MakeReshardingContiguousPass From 3192018c36924fdd59ec68feb41cbf3216432548 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 22 Apr 2025 16:55:51 -0700 Subject: [PATCH 65/70] propagate only selected types --- csrc/preseg_passes/propagate_shardings.cpp | 46 +++++++++++++--------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index f2f56d0b813..54bb3a0de6f 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -47,15 +47,16 @@ std::vector sortTvsByDeviceDims(const Range& tvs) { std::vector tvs_with_mesh = filterTvsWithMesh(tvs); // Then sort the filtered TVs - std::sort(tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { - int64_t a_device_dims = numDeviceDims(a); - int64_t b_device_dims = numDeviceDims(b); - if (a_device_dims != b_device_dims) { - return a_device_dims >= b_device_dims; - } - // Break ties by the total number of dimensions - return a->nDims() >= b->nDims(); - }); + std::stable_sort( + tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { + int64_t a_device_dims = numDeviceDims(a); + int64_t b_device_dims = 
numDeviceDims(b); + if (a_device_dims != b_device_dims) { + return a_device_dims > b_device_dims; + } + // Break ties by the total number of dimensions + return a->nDims() > b->nDims(); + }); return tvs_with_mesh; } @@ -125,13 +126,13 @@ class PropagateShardingsSelector : public SetSelector { // Returns the number of device dimensions that were reordered to the front. int64_t selectiveReorderDIDToFront( TensorView* tv, - std::unordered_set existing_parallel_types) { + std::unordered_set selected_parallel_types) { std::unordered_map old2new; int64_t current_pos = 0; for (auto pos : c10::irange(tv->nDims())) { if (tv->axis(pos)->isDeviceDim() && - !existing_parallel_types.count(tv->axis(pos)->getParallelType())) { + selected_parallel_types.count(tv->axis(pos)->getParallelType())) { old2new[pos] = current_pos; current_pos++; } @@ -142,17 +143,24 @@ int64_t selectiveReorderDIDToFront( } // Returns the set of parallel types seen on the loop domain of the given tvs. -std::unordered_set getTvParallelTypes( +std::unordered_set getParallelTypesToPropagate( std::vector tvs) { - std::unordered_set parallel_types; + // Get the set of parallel types seen on the loop domain of the given tvs. + std::unordered_set existing_parallel_types; for (auto tv : tvs) { for (auto id : tv->getLoopDomain()) { if (id->isDeviceDim()) { - parallel_types.insert(id->getParallelType()); + existing_parallel_types.insert(id->getParallelType()); } } } - return parallel_types; + std::unordered_set selected_parallel_types; + for (ParallelType pt : kParallelTypeDIDs) { + if (!existing_parallel_types.count(pt)) { + selected_parallel_types.insert(pt); + } + } + return selected_parallel_types; } void propagateDIDTransform( @@ -209,12 +217,12 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Reorder the DID axis to the front only if it does not have a parallel // type already seen on the outputs. 
- std::unordered_set existing_parallel_types = - getTvParallelTypes(outputs_without_mesh); + std::unordered_set selected_parallel_types = + getParallelTypesToPropagate(outputs_without_mesh); // This restricts the transform propagation to only the relevant DID axis. int64_t did_pos = - selectiveReorderDIDToFront(ref_input, existing_parallel_types); + selectiveReorderDIDToFront(ref_input, selected_parallel_types); // Propagate the DID loop split to the outputs without mesh. propagateDIDTransform( @@ -225,7 +233,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { /*allow_p2c=*/true); // Apply parallelization on the outputs without mesh. - shardAllLike(ref_input, outputs_without_mesh); + shardAllLike(ref_input, outputs_without_mesh, selected_parallel_types); } } From afb580ca27ea31eca96ae7b9879035e25c85b605 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 23 Apr 2025 13:29:58 -0700 Subject: [PATCH 66/70] check contiguity for both inp and output --- .../make_resharding_contiguous.cpp | 19 ++++++++++--------- csrc/preseg_passes/propagate_shardings.cpp | 15 +++++++++++---- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 0bc15e7c691..359f562011d 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -131,21 +131,22 @@ void MakeReshardingContiguousPass::runPass(Fusion* fusion) { auto inputs = ir_utils::filterByType(expr->inputs()); auto outputs = ir_utils::filterByType(expr->outputs()); - if (isResharding(expr)) { - NVF_CHECK( - std::all_of( - inputs.begin(), - inputs.end(), - [](TensorView* tv) { return isTvContiguous(tv); }), - "Resharding expression inputs must be contiguous: ", - expr); - } for (auto tv : inputs) { setLoopAndAllocationDomain(tv); } for (auto tv : outputs) { setLoopAndAllocationDomain(tv); } + + if (isResharding(expr)) { + auto 
check_contiguity = [&](const auto& tvs) { + return std::all_of(tvs.begin(), tvs.end(), isTvContiguous); + }; + NVF_CHECK( + check_contiguity(inputs) && check_contiguity(outputs), + "Resharding expression must have contiguous inputs and outputs: ", + expr); + } } } diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 54bb3a0de6f..9ed3f83243a 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -36,7 +36,7 @@ int64_t numDeviceDims(TensorView* tv) { tv->getLoopDomain().begin(), tv->getLoopDomain().end(), std::mem_fn(&IterDomain::isDeviceDim)); -}; +} // Sort the given tvs by the number of device dimensions in descending order. // Break ties by the total number of dimensions. @@ -121,9 +121,9 @@ class PropagateShardingsSelector : public SetSelector { } }; -// Reorder the DID axis to the front only if it does not have a parallel type -// already seen on the output (existing_parallel_types). +// Reorder the DID axis with the given parallel types to the front. // Returns the number of device dimensions that were reordered to the front. +// This allows us to limit propagation to only the relevant DID axis. int64_t selectiveReorderDIDToFront( TensorView* tv, std::unordered_set selected_parallel_types) { @@ -216,7 +216,14 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { " has no device mesh."); // Reorder the DID axis to the front only if it does not have a parallel - // type already seen on the outputs. + // type already seen on the outputs. This avoids propagating the same + // parallel type on multiple axis of the output when using multiple + // reference inputs. Consider out [M, N] = linear (inp [M, K], weight (N, + // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N + // ([DIDy(d), N/d, K]). We propagate from weights first, so the output + // will be [M, DIDx(d), N/d]. 
When we propagate from inp next, we should + // not propagate DIDx parallel type to the output. Otherwise, the output + // will have multiple DIDx shardings which is invalid. std::unordered_set selected_parallel_types = getParallelTypesToPropagate(outputs_without_mesh); From 3b480144229198942f73e43950bf369b6d40b8dc Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:54:47 -0700 Subject: [PATCH 67/70] Update csrc/preseg_passes/propagate_shardings.cpp Co-authored-by: Jingyue Wu --- csrc/preseg_passes/propagate_shardings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 9ed3f83243a..d8b86f9f38f 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -250,7 +250,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // `uniform` used in dropout. See MultiDeviceTest.BackpropMeshes for an // example. For non-fusion inputs, we also propagate shardings from outputs to // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. 
- for (auto i_expr = exprs.rbegin(); i_expr != exprs.rend(); i_expr++) { + for (Expr* expr : exprs | std::views::reverse) { Expr* expr = *i_expr; const auto& outputs = ir_utils::filterByType(expr->outputs()); From 354f8dade1c8cb4537874175fd1a376251cf3a3f Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:54:56 -0700 Subject: [PATCH 68/70] Update csrc/preseg_passes/propagate_shardings.cpp Co-authored-by: Jingyue Wu --- csrc/preseg_passes/propagate_shardings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index d8b86f9f38f..e867dac754a 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -126,7 +126,7 @@ class PropagateShardingsSelector : public SetSelector { // This allows us to limit propagation to only the relevant DID axis. int64_t selectiveReorderDIDToFront( TensorView* tv, - std::unordered_set selected_parallel_types) { + const std::unordered_set& selected_parallel_types) { std::unordered_map old2new; int64_t current_pos = 0; From 52573e7c4a7f4b53a1240c061a182a6ad63487a4 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 23 Apr 2025 15:12:48 -0700 Subject: [PATCH 69/70] review comments --- csrc/preseg_passes/propagate_shardings.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index e867dac754a..b3f7344a8e7 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -130,9 +130,9 @@ int64_t selectiveReorderDIDToFront( std::unordered_map old2new; int64_t current_pos = 0; - for (auto pos : c10::irange(tv->nDims())) { - if (tv->axis(pos)->isDeviceDim() && - selected_parallel_types.count(tv->axis(pos)->getParallelType())) { + for (auto&& [pos, id] : 
enumerate(tv->getLoopDomain())) { + if (id->isDeviceDim() && + selected_parallel_types.count(id->getParallelType())) { old2new[pos] = current_pos; current_pos++; } @@ -251,8 +251,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // example. For non-fusion inputs, we also propagate shardings from outputs to // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. for (Expr* expr : exprs | std::views::reverse) { - Expr* expr = *i_expr; - const auto& outputs = ir_utils::filterByType(expr->outputs()); // All outputs of an expression (Welford, SDPA) should be uniformly sharded. // We pick the most parallel output as the reference. From 5eea3abe0a1740036191cd331ee9889b6999596b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 23 Apr 2025 16:18:10 -0700 Subject: [PATCH 70/70] Simplify some tests since sharding propagation is in place --- tests/python/multidevice/test_matmul.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/python/multidevice/test_matmul.py b/tests/python/multidevice/test_matmul.py index 76ce1939edf..1c877b4acb5 100644 --- a/tests/python/multidevice/test_matmul.py +++ b/tests/python/multidevice/test_matmul.py @@ -81,7 +81,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.bias, self.out]: + for t in [self.inp, self.weight, self.bias]: self.sched._set_device_mesh(t, mesh) # Shard N for weight (N, K) and bias (N) @@ -90,12 +90,6 @@ def multidevice_schedule(self): self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x) self.sched.set_allocation_as_loop(t) - # Output of linear: {.., i{M}, i{N}, r{K}} - # Shard N -> axis(-2) - self.sched.split(self.out, -2, d, False) - self.sched.parallelize(self.out, -3, nvfuser.ParallelType.mesh_x) - self.sched.set_allocation_as_loop(self.out) - torch.cuda.set_device(multidevice_test.local_rank) b, s = 2, 1024 @@ -135,7 +129,7 @@ def definition(self): self.add_output(self.out) def 
multidevice_schedule(self): - for t in [self.inp, self.weight, self.out]: + for t in [self.inp, self.weight]: self.sched._set_device_mesh(t, mesh) self.sched.split(t, -1, d, False) self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x)