From 4e915670c4c5892e778a1e8d558c4c961b2866a4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 17 May 2023 14:58:44 -0400 Subject: [PATCH 1/9] Initial try at fixing output domain extent --- csrc/dynamic_transform.cpp | 21 ------------ csrc/ops/utils.cpp | 25 +++++++++++--- csrc/root_domain_map.cpp | 9 +++++ test/test_dynamic_transform.cpp | 61 +++++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 26 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 9c6d9c3b914..cd518657f23 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -143,27 +143,6 @@ void DynamicTransformInfoBuilder::handle(TensorView* tv) { "Cannot evaluate the extent of a resized IterDomain: ", id->toString()); - auto in_id = op->in()->as(); - auto in_extent_val = expr_eval_->evaluate(in_id->extent()); - TORCH_INTERNAL_ASSERT( - in_extent_val.has_value(), - "Cannot evaluate the extent of input to an IterDomain resize: ", - in_id->toString()); - - auto left = op->leftExpand()->as(); - auto left_val = expr_eval_->evaluate(left); - TORCH_INTERNAL_ASSERT( - left_val.has_value(), - "Cannot evaluate the left expansion of an IterDomain resize: ", - left->toString()); - - auto right = op->rightExpand()->as(); - auto right_val = expr_eval_->evaluate(right); - TORCH_INTERNAL_ASSERT( - right_val.has_value(), - "Cannot evaluate the right expansion of an IterDomain resize: ", - right->toString()); - auto out_itertype = out_extent_val->as() == 1 ? IterType::Broadcast : IterType::Iteration; diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 8d2c8d3fd2f..af5db71b855 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -186,13 +186,9 @@ IterType promoteIterType(IterType type1, IterType type2) { "Unexpected IterType: ", type2); - // If either is Iteration, the output type is also Iteration. If - // none of them is Iteration and either of them is Symbolic, the - // output is also Symbolic. + // If either is Iteration, the output type is also Iteration. if (type1 == IterType::Iteration || type2 == IterType::Iteration) { return IterType::Iteration; - } else if (type1 == IterType::Symbolic || type2 == IterType::Symbolic) { - return IterType::Symbolic; } else { return IterType::Broadcast; } @@ -236,6 +232,25 @@ std::vector newOutputDomain( " dimensions but expected ", out_domain.size()); for (const auto i : c10::irange(dom.size())) { + auto iter_type = dom[i]->getIterType(); + if (iter_types[i].has_value()) { + if (iter_types[i].value() == IterType::Symbolic) { + // If the best guess so far is that the output is Symbolic, then all + // the inputs must have been symbolic. If the current ID is Iteration, + // then we should prefer its extent instead of the Symbolic value. + if (iter_type == IterType::Iteration) { + extent_vals[i] = dom[i]->extent(); + } + } else if (iter_type == IterType::Symbolic) { + // If there is an input with a Symbolic ID and no other inputs, we + // can re-use its extent expression. Otherwise, a symbolic input ID + // could be either a Broadcast or an Iteration IterDomain, so its + // extent expression is not necessarily telling us the output extent. + // However, if there are any Iteration domains, we _can_ use their + // extent expressions, since they will resolve any broadcasts. + continue; + } + } if (dom[i]->isBroadcast()) { if (dom[i]->hasExpandedExtent()) { expanded_extent_vals[i] = diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index abb25ab449d..0feb2ee1703 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -137,6 +137,7 @@ std::unordered_map PairwiseRootDomainMap::map( // domains of torch_gather) // 3. Squeeze and unsqueeze // 4. Broadcast and non broadcast + // 5. Symbolic IDs // Condition 1: when the producer ID is the dim of a select-like op if (producer_id == indexed_producer_id) { @@ -187,6 +188,14 @@ std::unordered_map PairwiseRootDomainMap::map( continue; } + // Condition 5 + if (producer_id->getIterType() == IterType::Symbolic || + consumer_id->getIterType() == IterType::Symbolic) { + itc++; + itp++; + continue; + } + IterDomain* map_key_id = producer_id; IterDomain* map_value_id = consumer_id; if (!producer_to_consumer) { diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 50566371655..d7ede791cf9 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -954,4 +954,65 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) { reductionDynamicPadAddFusion(invocations); } +// Test dynamic pad followed by broadcast resolution +TEST_F(NVFuserTest, DynamicPadBroadcast_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + // 2d axis order here is YX + auto ypad = IrBuilder::create(); + fusion.addInput(ypad); + auto xpad = IrBuilder::create(); + fusion.addInput(xpad); + + // two-way resizes to cut square tv down to broadcastable size in each axis + auto tv0_pad = pad(tv0, {fusion.zeroVal(), xpad, fusion.zeroVal(), ypad}); + + // This will potentially resolve the y or x broadcast + auto p = mul(tv0_pad, tv1); + fusion.addOutput(p); + fusion.addOutput(tv0_pad); + + fusion.printMath(); + + FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_x = at::randn({5, 5}, options); + at::Tensor at_y = at::randn({5, 5}, options); + + // trivial resize + std::vector aten_inputs({at_x, at_y, 0, 0}); + std::vector outputs; + + /* + aten_inputs[2] = 0; + aten_inputs[3] = 0; + outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); + testValidate(fusion_executor_cache.fusion(), outputs, aten_inputs, {at_x * + at_y}, __LINE__, __FILE__); + */ + + // shrink first axis + aten_inputs[2] = -4; + aten_inputs[3] = 0; + outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); + std::cout << outputs << std::endl; + std::cout << at_x.slice(0, 0, 1) * at_y << std::endl; + std::cout << at_x.slice(0, 0, 1) << std::endl; + testValidate( + fusion_executor_cache.fusion(), + outputs, + aten_inputs, + {at_x.slice(0, 0, 1) * at_y, at_x.slice(0, 0, 1)}, + __LINE__, + __FILE__); +} + } // namespace nvfuser From 7ae44a4947eeea34960b93af0b4b17e95dadcf11 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 17 May 2023 19:30:20 -0400 Subject: [PATCH 2/9] Change how output iter_types are computed in newOutputDomain --- csrc/kernel_cache.cpp | 1 - csrc/ops/utils.cpp | 67 +++++++++++++++++++++++++++++++------------ 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 47eaa22b5fd..3c189b58212 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -439,7 +439,6 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( // these around during a subsequent fusion copy would lead to an attempt // to clone them, ending in a segfault. Instead, we reset the object // here, effectively as if it now describes a non-dynamic Fusion. - // cloned_conc_info.clear(); fusion->stopManaging(conc_info_index); } kernel_runtimes.emplace_back(std::make_unique( diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index af5db71b855..03e997ba59f 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -219,9 +219,10 @@ std::vector newOutputDomain( std::vector start_offsets(out_domain.size(), 0); std::vector stop_offsets(out_domain.size(), 0); std::vector extent_vals(out_domain.size(), nullptr); + std::vector mismatched_symbolic_extents(out_domain.size(), false); std::vector expanded_extent_vals(out_domain.size(), nullptr); - std::vector> iter_types( - out_domain.size(), c10::nullopt); + std::vector> iter_types( + out_domain.size(), std::nullopt); for (auto tv : tvs) { auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); @@ -233,22 +234,46 @@ std::vector newOutputDomain( out_domain.size()); for (const auto i : c10::irange(dom.size())) { auto iter_type = dom[i]->getIterType(); - if (iter_types[i].has_value()) { - if (iter_types[i].value() == IterType::Symbolic) { - // If the best guess so far is that the output is Symbolic, then all - // the inputs must have been symbolic. If the current ID is Iteration, - // then we should prefer its extent instead of the Symbolic value. - if (iter_type == IterType::Iteration) { - extent_vals[i] = dom[i]->extent(); - } - } else if (iter_type == IterType::Symbolic) { - // If there is an input with a Symbolic ID and no other inputs, we - // can re-use its extent expression. Otherwise, a symbolic input ID - // could be either a Broadcast or an Iteration IterDomain, so its - // extent expression is not necessarily telling us the output extent. - // However, if there are any Iteration domains, we _can_ use their - // extent expressions, since they will resolve any broadcasts. - continue; + // If there is any Iteration domain, we should use the first one's + // extent. + // + // If all inputs are Symbolic or Broadcast, then we can use the + // symbolic extent if all the symbolic extents agree. + // + // Otherwise, we don't know the output extent and iter_type should be + // Symbolic if there are any Symbolic inputs else Broadcast. + if (iter_type == IterType::Iteration && iter_types[i].has_value() && + iter_types[i].value() == IterType::Symbolic) { + // Current is Iteration, previous was Symbolic. Erase symbolic since + // we'll go with the Iteration extent. + extent_vals[i] = nullptr; + } + if (iter_type == IterType::Symbolic && iter_types[i].has_value()) { + switch (iter_types[i].value()) { + case IterType::Iteration: + // Previously found Iteration domain, so ignore all Symbolic domains + continue; + case IterType::Symbolic: + if (extent_vals[i]->sameAs(dom[i]->extent())) { + // Found another matching symbolic domain + continue; + } else { + // Mismatched symbolic input extents => Don't know output extent + // TODO: set mismatch_symbolic for this axis + mismatched_symbolic_extents[i] = true; + } + break; + case IterType::Broadcast: + // Previously found only Broadcast. Output should be symbolic and + // default to extent of dom[i] + iter_types[i] = std::nullopt; + extent_vals[i] = nullptr; + break; + default: + TORCH_CHECK( + false, + "Encountered unexpected IterType when creating new output domain: ", + iter_types[i].value()); } } if (dom[i]->isBroadcast()) { @@ -283,6 +308,12 @@ std::vector newOutputDomain( } } for (const auto dim_i : c10::irange(out_domain.size())) { + if (iter_types[dim_i] == IterType::Symbolic && + mismatched_symbolic_extents[dim_i]) { + // if we have a symbolic output but the input symbolic extents did not + // match, create a new extent + extent_vals[dim_i] = nullptr; + } if (extent_vals[dim_i] != nullptr) { TORCH_INTERNAL_ASSERT( iter_types[dim_i].has_value(), From 19aca3e75e08daa8968b24f2334bf5377d1fd0f0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 17 May 2023 19:30:51 -0400 Subject: [PATCH 3/9] Skip double mutating while concretizing root->rfactor --- csrc/dynamic_transform.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index cd518657f23..d3f791e5ff6 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -400,6 +400,10 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { + if (mutations_.find(out_id) != mutations_.end()) { + // Skip outputs that have already been registered for mutation + continue; + } auto concretized_out_id = IterDomainBuilder(out_id).iter_type(iter_type).build(); registerMutation(out_id, concretized_out_id); From b581c1327f8aea0d0ccbd31ab5b97817d4d88af1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 23 May 2023 11:45:56 -0400 Subject: [PATCH 4/9] Clean up logic in newOutputDomain --- csrc/ops/utils.cpp | 93 +++++++++++++++++++++++++++------------------- 1 file changed, 54 insertions(+), 39 deletions(-) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 03e997ba59f..37dc433959b 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -232,48 +232,61 @@ std::vector newOutputDomain( dom.size(), " dimensions but expected ", out_domain.size()); + // If there is any Iteration domain, we should use the first one's + // extent. + // + // If all inputs are Symbolic or Broadcast, then we can use the + // symbolic extent if all the symbolic extents agree. + // + // Otherwise, we don't know the output extent and iter_type should be + // Symbolic if there are any Symbolic inputs else Broadcast. for (const auto i : c10::irange(dom.size())) { auto iter_type = dom[i]->getIterType(); - // If there is any Iteration domain, we should use the first one's - // extent. - // - // If all inputs are Symbolic or Broadcast, then we can use the - // symbolic extent if all the symbolic extents agree. - // - // Otherwise, we don't know the output extent and iter_type should be - // Symbolic if there are any Symbolic inputs else Broadcast. - if (iter_type == IterType::Iteration && iter_types[i].has_value() && - iter_types[i].value() == IterType::Symbolic) { - // Current is Iteration, previous was Symbolic. Erase symbolic since - // we'll go with the Iteration extent. - extent_vals[i] = nullptr; - } - if (iter_type == IterType::Symbolic && iter_types[i].has_value()) { - switch (iter_types[i].value()) { - case IterType::Iteration: - // Previously found Iteration domain, so ignore all Symbolic domains - continue; - case IterType::Symbolic: - if (extent_vals[i]->sameAs(dom[i]->extent())) { - // Found another matching symbolic domain + if (iter_types[i].has_value()) { + // Clang-tidy complains about unchecked access to optional value here + // NOLINTNEXTLINE(bugprone-unchecked-optional-access,-warnings-as-errors) + auto prev_iter_type = iter_types[i].value(); + if (iter_type == IterType::Iteration && + prev_iter_type == IterType::Symbolic) { + // Prefer the Iteration extent, since Symbolic could be broadcast + extent_vals[i] = nullptr; + } else if (iter_type == IterType::Symbolic) { + switch (prev_iter_type) { + case IterType::Iteration: + // Previously found Iteration domain, so ignore all Symbolic + // domains continue; - } else { - // Mismatched symbolic input extents => Don't know output extent - // TODO: set mismatch_symbolic for this axis - mismatched_symbolic_extents[i] = true; - } - break; - case IterType::Broadcast: - // Previously found only Broadcast. Output should be symbolic and - // default to extent of dom[i] - iter_types[i] = std::nullopt; - extent_vals[i] = nullptr; - break; - default: - TORCH_CHECK( - false, - "Encountered unexpected IterType when creating new output domain: ", - iter_types[i].value()); + case IterType::Symbolic: + if (extent_vals[i]->sameAs(dom[i]->extent())) { + // matching symbolic extent + continue; + } else { + // Mismatched symbolic input extents. Any one of the symbolic + // inputs could be a Broadcast or Iteration domain. Until + // concretization, we will not know which one holds the true + // extent (or whether they all are Broadcast, so that the output + // is also Broadcast). We record that these symbolic extents + // mismatched so that we can introduce a new symbolic extent + // later. + mismatched_symbolic_extents[i] = true; + } + break; + case IterType::Broadcast: + // Previously found only broadcast, so this will either also + // broadcast or resolve those broadcasts. If the expanded + // extent of any of the broadcasts is not 1, then it will need to + // match that of the dom[i]. In either case, prefer dom[i]'s + // extent, so clear iter_types[i] and extent_vals[i] so that the + // rest of this iteration will mark output as Symbolic. + iter_types[i] = std::nullopt; + extent_vals[i] = nullptr; + break; + default: + TORCH_CHECK( + false, + "Encountered unexpected IterType when creating new output domain: ", + prev_iter_type); + } } } if (dom[i]->isBroadcast()) { @@ -286,6 +299,7 @@ std::vector newOutputDomain( extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent()); if (iter_types[i].has_value()) { iter_types[i] = + // NOLINTNEXTLINE(bugprone-unchecked-optional-access,-warnings-as-errors) promoteIterType(iter_types[i].value(), dom[i]->getIterType()); } else { iter_types[i] = dom[i]->getIterType(); @@ -322,6 +336,7 @@ std::vector newOutputDomain( IterDomainBuilder( IrBuilder::create(start_offsets[dim_i]), extent_vals[dim_i]) .stop_offset(IrBuilder::create(stop_offsets[dim_i])) + // NOLINTNEXTLINE(bugprone-unchecked-optional-access,-warnings-as-errors) .iter_type(iter_types[dim_i].value()) .build(); } else { From e304aed8a85f504dd61c7b1465916b023968b062 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 23 May 2023 11:53:30 -0400 Subject: [PATCH 5/9] Clean up clang-tidy warnings on vector of optionals --- csrc/ops/utils.cpp | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 3f1a0201deb..2da449063c1 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -242,16 +242,15 @@ std::vector newOutputDomain( // Symbolic if there are any Symbolic inputs else Broadcast. for (const auto i : c10::irange(dom.size())) { auto iter_type = dom[i]->getIterType(); - if (iter_types[i].has_value()) { + auto prev_iter_type = iter_types[i]; + if (prev_iter_type.has_value()) { // Clang-tidy complains about unchecked access to optional value here - // NOLINTNEXTLINE(bugprone-unchecked-optional-access,-warnings-as-errors) - auto prev_iter_type = iter_types[i].value(); if (iter_type == IterType::Iteration && - prev_iter_type == IterType::Symbolic) { + prev_iter_type.value() == IterType::Symbolic) { // Prefer the Iteration extent, since Symbolic could be broadcast extent_vals[i] = nullptr; } else if (iter_type == IterType::Symbolic) { - switch (prev_iter_type) { + switch (prev_iter_type.value()) { case IterType::Iteration: // Previously found Iteration domain, so ignore all Symbolic // domains @@ -285,7 +284,7 @@ std::vector newOutputDomain( TORCH_CHECK( false, "Encountered unexpected IterType when creating new output domain: ", - prev_iter_type); + prev_iter_type.value()); } } } @@ -297,10 +296,9 @@ std::vector newOutputDomain( continue; } extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent()); - if (iter_types[i].has_value()) { + if (prev_iter_type.has_value()) { iter_types[i] = - // NOLINTNEXTLINE(bugprone-unchecked-optional-access,-warnings-as-errors) - promoteIterType(iter_types[i].value(), dom[i]->getIterType()); + promoteIterType(prev_iter_type.value(), dom[i]->getIterType()); } else { iter_types[i] = dom[i]->getIterType(); } @@ -322,22 +320,21 @@ std::vector newOutputDomain( } } for (const auto dim_i : c10::irange(out_domain.size())) { - if (iter_types[dim_i] == IterType::Symbolic && - mismatched_symbolic_extents[dim_i]) { + auto iter_type = iter_types[dim_i]; + if (iter_type == IterType::Symbolic && mismatched_symbolic_extents[dim_i]) { // if we have a symbolic output but the input symbolic extents did not // match, create a new extent extent_vals[dim_i] = nullptr; } if (extent_vals[dim_i] != nullptr) { TORCH_INTERNAL_ASSERT( - iter_types[dim_i].has_value(), + iter_type.has_value(), "Could not deduce iter type for new tensor view."); out_domain[dim_i] = IterDomainBuilder( IrBuilder::create(start_offsets[dim_i]), extent_vals[dim_i]) .stop_offset(IrBuilder::create(stop_offsets[dim_i])) - // NOLINTNEXTLINE(bugprone-unchecked-optional-access,-warnings-as-errors) - .iter_type(iter_types[dim_i].value()) + .iter_type(iter_type.value()) .build(); } else { out_domain[dim_i] = IterDomainBuilder( From bd0d7c3a65612df9477aad59275027b288edca6b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 23 May 2023 12:14:55 -0400 Subject: [PATCH 6/9] Undo change to promoteIterType --- csrc/ops/utils.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 2da449063c1..c89485009b2 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -186,9 +186,13 @@ IterType promoteIterType(IterType type1, IterType type2) { "Unexpected IterType: ", type2); - // If either is Iteration, the output type is also Iteration. + // If either is Iteration, the output type is also Iteration. If + // none of them is Iteration and either of them is Symbolic, the + // output is also Symbolic. if (type1 == IterType::Iteration || type2 == IterType::Iteration) { return IterType::Iteration; + } else if (type1 == IterType::Symbolic || type2 == IterType::Symbolic) { + return IterType::Symbolic; } else { return IterType::Broadcast; } From 761c9c0a353cb96241e128959dbb96da057627f5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 25 May 2023 16:52:39 -0400 Subject: [PATCH 7/9] Handle {Cat,Pad,Slice}Op in ConcretizedBroadcastDomains With this, now failing at [E thread_pool.cpp:109] Exception in thread pool task: it != bcast_map_.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/root_domain_map.cpp":630, please report a bug to PyTorch. Not found: iS4{i0}rf in T2[ iS4{i0}rf iS6{i2}rf ] (Rfactor: [ bS10{( i0 + i5 )}rf iS11{( i2 + i6 )}rf ]) --- .../analysis/trivial_broadcast.cpp | 53 +++++++++++++++++++ .../device_lower/analysis/trivial_broadcast.h | 7 +++ csrc/ir/internal_nodes.h | 4 ++ csrc/root_domain_map.cpp | 4 +- 4 files changed, 67 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 6993699632b..27f148e10a8 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -64,6 +64,59 @@ void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) { } } +void ConcretizedBroadcastDomains::handle(CatOp* op) { + auto id = + op->out()->as()->getMaybeRFactorDomain().at(op->concatenatedDim()); + if (id->isBroadcast()) { + broadcast_origin_map_.emplace(id, std::unordered_set({id})); + } +} + +void ConcretizedBroadcastDomains::handle(PadOp* op) { + std::cout << "handle(PadOp* " << op->toString() << ")" << std::endl; + for (auto i : op->getPaddedAxes()) { + std::cout << "padded axis " << i << std::endl; + // Instead of the root domain of the output, as with BroadcastOp, we set the + // origin as the RFactor domain, since PadOp inserts Resize ops between root + // and rfactor + auto id = op->out()->as()->getMaybeRFactorDomain().at(i); + std::cout << "id = " << id->toString() << std::endl; + if (id->isBroadcast()) { + std::cout << "broadcast_origin_map_.emplace("; + std::cout << id->toString() << ", std::unordered_set({"; + std::cout << "id"/*id->toString()*/ << "}));" << std::endl;; + broadcast_origin_map_.emplace(id, std::unordered_set({id})); + } + } +} + +void ConcretizedBroadcastDomains::handle(SliceOp* op) { + auto consumer_root = op->out()->as()->getMaybeRFactorDomain(); + auto producer_rfactor = TensorDomain::noReductions( + op->in()->as()->getMaybeRFactorDomain()); + TORCH_INTERNAL_ASSERT( + consumer_root.size() == producer_rfactor.size(), + "Consumer root size ", + consumer_root.size(), + " does not match producer rfactor size ", + producer_rfactor.size()); + for (auto i : c10::irange(consumer_root.size())) { + auto cid = consumer_root.at(i); + auto pid = producer_rfactor.at(i); + if (cid->isBroadcast()) { + // Map to producer ID if it was already broadcast. Otherwise to consumer + // ID + if (pid->isBroadcast()) { + broadcast_origin_map_.emplace( + pid, std::unordered_set({cid, pid})); + } else { + broadcast_origin_map_.emplace( + cid, std::unordered_set({cid})); + } + } + } +} + void ConcretizedBroadcastDomains::handle(Expr* expr) { IterVisitor::handle(expr); diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h index 841b23c501f..ab94641d40a 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.h +++ b/csrc/device_lower/analysis/trivial_broadcast.h @@ -47,6 +47,13 @@ class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor { void handle(BroadcastOp* bop) final; + // After concretization, ops with Resized IterDomains in their outputs may set + // the broadcast flag, even though they are not BroadcastOps themselves. In + // these cases, we set the output as the origin. + void handle(CatOp* op) final; + void handle(PadOp* op) final; + void handle(SliceOp* op) final; + void handle(Expr* expr) final; void markAsConcretized( diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 4e99d5f58f2..f838b185d8f 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -1860,6 +1860,10 @@ class TORCH_CUDA_CU_API CatOp : public Expr { return attribute(0)->as>()->value; } + Val* out() const { + return output(0); + } + //! The index val that determines which input tensor should be used //! to fill the particular output position of this expression. Only //! valid after indexing diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 94644550ce0..3951cf41af9 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -870,7 +870,9 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped( } if (consumer_id->isBroadcast()) { - TORCH_INTERNAL_ASSERT(producer_id->isBroadcast()); + // Note that consumer may be broadcast even though producer is not if it is + // the output of a Resize op. + // Get bcast_map_ entry for consumer_id const auto consumer_bcast_domains = root_map_.getConcretizedKeys(consumer_td, consumer_id); From 9cbfa7ebd0518f3ccd909c90d918bd96856934b1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 2 Jun 2023 11:12:25 -0400 Subject: [PATCH 8/9] Add mapSymbolic option to PairwiseRootDomainMap Using this, we can enable mapping of Symbolic IterDomains when doing p2c propagation in concretization, and still keep it disabled for the purposes of exact mapping. This fixes some errors but I am now seeing errors like > [E thread_pool.cpp:110] Exception in thread pool task: is_overriden || index_map.find(alloc_dom[i]) != index_map.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1634, --- csrc/dynamic_transform.cpp | 4 +++- csrc/root_domain_map.cpp | 5 +++-- csrc/root_domain_map.h | 7 +++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index b7bc221392d..f56691c2f14 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -542,7 +542,8 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { - if (!out_id->isSymbolic() || mutations_.find(out_id) != mutations_.end()) { + if (!out_id->isSymbolic() || + mutations_.find(out_id) != mutations_.end()) { // Skip symbolic outputs and outputs that have already been registered // for mutation continue; @@ -646,6 +647,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( for (auto producer : ir_utils::filterByType(def->inputs())) { PairwiseRootDomainMap root_map(producer, consumer); + root_map.mapSymbolic(true); auto c2p = root_map.mapConsumerToProducer( consumer->domain(), producer->domain()); diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 3951cf41af9..aad690614be 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -183,8 +183,9 @@ std::unordered_map PairwiseRootDomainMap::map( } // Condition 5 - if (producer_id->getIterType() == IterType::Symbolic || - consumer_id->getIterType() == IterType::Symbolic) { + if (!map_symbolic_ && + (producer_id->getIterType() == IterType::Symbolic || + consumer_id->getIterType() == IterType::Symbolic)) { itc++; itp++; continue; diff --git a/csrc/root_domain_map.h b/csrc/root_domain_map.h index a3e7a30e977..ff0728b7e33 100644 --- a/csrc/root_domain_map.h +++ b/csrc/root_domain_map.h @@ -100,6 +100,11 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap { return *this; } + PairwiseRootDomainMap& mapSymbolic(bool b) { + map_symbolic_ = b; + return *this; + } + PairwiseRootDomainMap& mapDifferentExtents(bool b) { map_different_extents_ = b; return *this; @@ -136,6 +141,8 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap { //! Map broadcast and non-broadcast domains. Note that this is on by //! default bool map_broadcast_ = true; + //! Map symbolic domains with one another. + bool map_symbolic_ = false; //! Map domains that may have different extents, e.g., torch_gather bool map_different_extents_ = false; //! Map domains that are indirectly accessed, e.g., index_select From 89be4e45c933781bf761835d20d9025dc701755b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 2 Jun 2023 11:19:53 -0400 Subject: [PATCH 9/9] Remove debug print stmts --- csrc/device_lower/analysis/trivial_broadcast.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 27f148e10a8..677ac93e834 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -73,18 +73,12 @@ void ConcretizedBroadcastDomains::handle(CatOp* op) { } void ConcretizedBroadcastDomains::handle(PadOp* op) { - std::cout << "handle(PadOp* " << op->toString() << ")" << std::endl; for (auto i : op->getPaddedAxes()) { - std::cout << "padded axis " << i << std::endl; // Instead of the root domain of the output, as with BroadcastOp, we set the // origin as the RFactor domain, since PadOp inserts Resize ops between root // and rfactor auto id = op->out()->as()->getMaybeRFactorDomain().at(i); - std::cout << "id = " << id->toString() << std::endl; if (id->isBroadcast()) { - std::cout << "broadcast_origin_map_.emplace("; - std::cout << id->toString() << ", std::unordered_set({"; - std::cout << "id"/*id->toString()*/ << "}));" << std::endl;; broadcast_origin_map_.emplace(id, std::unordered_set({id})); } }