From a58ce3e28194d0e5089a029ebc3529ea2068db31 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 10 Jul 2023 11:03:20 -0400 Subject: [PATCH 1/7] Refactor to traverse in forward direction only --- csrc/dynamic_transform.cpp | 233 ++++++++++++------------------------- 1 file changed, 74 insertions(+), 159 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 90d7dce7b27..9c5eaa06a8b 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -357,16 +357,20 @@ class DynamicTransformConcretizer : public OptOutMutator { using OptOutMutator::mutate; - void mutate(TensorView* tv) final; - void mutate(TensorDomain* td) final; - - //! Concretizes the root domain of a symbolic consumer tensor from - //! its producer domains. Returns true if any root ID is concretized. - bool propagateFromProducerToConsumer(TensorView* consumer); + void mutate(IterDomain* id) final; private: const DynamicTransformConcretizationInfo* info_; + + //! This map is used during concretization to identify, for a given IterDomain + //! the set of all IterDomains which are "aligned" with it in some TensorView + //! expression. This enables us to write mutate(IterDomain*) and propagate + //! information from producer IterDomains to consumers, which is otherwise not + //! represented in the graph since we do not connect IterDomains between + //! TensorViews with expressions. + std::unordered_map> + id_producers_; }; void DynamicTransformConcretizer::concretize() { @@ -376,13 +380,37 @@ void DynamicTransformConcretizer::concretize() { // Set output IterTypes for dynamic resize ops concretizeResize(); - // Finally, propagate concretized domains + // The methods above do not traverse the graph. Instead they fill in + // root->rfactor expressions by replacing the dynamic reshaped TV with a + // static reshaped one, and by registering concretization of dynamic Resized + // IterDomains. From this point forward, we will not modify any TensorView + // expressions. This restriction makes it safe for us to to traverse forward + // through the graph and mutate IterDomains and TensorDomains in order to + // properly propagate IterTypes and concretized extent expressions, without + // breaking the topological ordering of these expressions. + // + // When propagating IterTypes across expressions, we need to know the producer + // IterDomains corresponding to a consumer ID. This mapping helps facilitate + // this and is used later in mutate(IterDomain*). auto all_stmts = StmtSort::getStmts(info_->fusion(), true); - for (auto stmt : all_stmts) { - if (stmt->isA()) { - mutate(stmt); + for (auto expr : ir_utils::filterByType(all_stmts)) { + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + for (auto producer : ir_utils::filterByType(expr->inputs())) { + PairwiseRootDomainMap root_map(producer, consumer); + for (auto [cid, pid] : root_map.mapConsumerToProducer( + consumer->domain(), producer->domain())) { + // Initialize set of producer IDs, if we haven't already + auto& producers = id_producers_.emplace(cid, 0).first->second; + producers.insert(pid); + } + } } } + + // Finally, propagate concretized domains with forward traversal + for (auto stmt : all_stmts) { + mutate(stmt); + } } void DynamicTransformConcretizer::concretizeReshape() { @@ -441,103 +469,6 @@ void DynamicTransformConcretizer::checkConcretizedUses( } } -// Concretizes inherited symbolic domains. Note that when this is -// called, it is assumed that all dynamic ops themselves are -// concretized. Since symbolic IDs may be propagated down to -// consumers, those domains need to be concretized accordingly. -void DynamicTransformConcretizer::mutate(TensorView* tv) { - if (!tv->domain()->hasSymbolicAxis()) { - return; - } - - // First, try to concretize the root domain as there may be symbolic - // axes inherited from the producers - propagateFromProducerToConsumer(tv); - - // If no root domain is altered by producer, we don't need to propagate back - // up to rfactor. We could return early, but instead we go ahead and check the - // root to rfactor transforms to be sure we have concretized any intermediate - // IterDomains. - - // At this point, there should be no expr beyond rfactor root - TORCH_INTERNAL_ASSERT( - tv->getLeafDomain() == tv->getMaybeRFactorDomain(), - "Invalid tensor: ", - tv->toString()); - - // If it has an rfactor root domain, the IterTypes of the rfactor - // IDs may need to be updated as well. Traverse the rfactor exprs - // and mutate the IterTypes of output IDs if symbolic. - if (tv->hasRFactor()) { - // Note that it is assumed that theres's no further expression - // beyond the rfactor domain as asserted above - auto all_id_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {tv->getRootDomain().begin(), tv->getRootDomain().end()}, - {tv->getMaybeRFactorDomain().begin(), - tv->getMaybeRFactorDomain().end()}); - for (auto expr : all_id_exprs) { - // Assume outputs of IterDomain exprs are always IterDomains. If - // the assumption is invalidated, the logic here would need to - // be updated. Assert the assumption to immediately detect such - // a case if happened. - for (auto out_val : expr->outputs()) { - TORCH_INTERNAL_ASSERT( - out_val->isA(), - "Unexpected output: ", - out_val->toString(), - ". IterDomain was expected."); - } - - // NOTE: We do not return early if all outputs are concrete as there may - // still be concrete inputs. For example, a Symbolic IterDomain might be - // padded with constant pad widths (1, 1), in which case although we do - // not know the exact extent of the output, we know it is at least as - // large as the sum of the pad widths, 2. In such cases, the output - // IterDomain is concrete at definition, since if the extent is >1 we know - // the IterType is Iteration. In these cases, we must continue to - // concretize intermediate expressions between the root and R-factor - // domain. See test DynamicTransform5_CUDA which demonstrates this - // behavior. - // NOTE: We also do not assume that if one output ID is symbolic, that - // they all must be. See test FusionSliceForNanoGPT3_CUDA for an example - // that does a static split by a factor of 16 of a symbolic input domain. - // The static split in that case results in a concrete IterDomain with - // extent 16 along with a symbolic one (extent ceilDiv(n / 16)). - - // Determine the output IterType - IterType iter_type = IterType::Symbolic; - for (auto inp_id : ir_utils::filterByType(expr->inputs())) { - auto updated_id = maybeMutated(inp_id)->as(); - iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); - } - TORCH_INTERNAL_ASSERT( - iter_type != IterType::Symbolic, - "Failed to concretize an output IterType for expression: ", - expr->toString()); - - // Update the IterType of each output - for (auto out_id : ir_utils::filterByType(expr->outputs())) { - if (!out_id->isSymbolic()) { - continue; - } - auto concretized_out_id = - IterDomainBuilder(out_id).iter_type(iter_type).build(); - registerConcretization(out_id, concretized_out_id); - } - - // The expr itself needs to be mutated as well in case the outputs are - // mutated, which can be done by the mutate method - OptOutMutator::mutate(expr); - } - } - - // Root and rfactor domains are updated. First mutate the - // TensorDomain and then TensorView - mutate(tv->domain()); - OptOutMutator::mutate(tv); -} - // Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but // the contiguity vector may need to be updated as well as symbolic // domains may be mutated to broadcast domains, which means contiguity @@ -594,75 +525,59 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { registerConcretization(td, mutated_val); } -bool DynamicTransformConcretizer::propagateFromProducerToConsumer( - TensorView* consumer) { - if (consumer->definition() == nullptr || - !consumer->domain()->hasSymbolicAxis()) { - return false; +void DynamicTransformConcretizer::mutate(IterDomain* id) { + // id might have already been mutated if its definition was updated + id = maybeMutated(id)->as(); + if (!id->isSymbolic()) { + return; } - - const auto& root_domain = consumer->getRootDomain(); - - auto def = consumer->definition(); - - bool is_concretized = false; - - for (const auto i : c10::irange(root_domain.size())) { - auto root_id = root_domain.at(i); - if (root_id->getIterType() != IterType::Symbolic) { - continue; + if (auto def = id->definition()) { + // Determine concrete IterType based on promotion of inputs to def + IterType iter_type = IterType::Symbolic; + for (auto inp_id : ir_utils::filterByType(def->inputs())) { + auto updated_id = maybeMutated(inp_id)->as(); + iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); + } + TORCH_INTERNAL_ASSERT( + iter_type != IterType::Symbolic, + "Failed to concretize an output IterType for expression: ", + def->toString()); + auto concretized_id = IterDomainBuilder(id).iter_type(iter_type).build(); + registerConcretization(id, concretized_id); + } else { + // IterDomains without definitions might be root domains for the output of a + // TensorView expression. If so, we should propagate their concretization in + // the producer to consumer direction. + + auto producers_it = id_producers_.find(id); + if (producers_it == id_producers_.end()) { + // id was not a consumer root ID in any TV expression + return; } - - // Figure out the right IterType of this consumer root ID from its - // corresponding producer IDs std::optional id_type; - - for (auto producer : ir_utils::filterByType(def->inputs())) { - PairwiseRootDomainMap root_map(producer, consumer); - auto c2p = root_map.mapConsumerToProducer( - consumer->domain(), producer->domain()); - - TORCH_INTERNAL_ASSERT( - c2p.find(root_id) != c2p.end(), - "No input ID found to map with output ID: ", - root_id->toString()); - - auto input_id = c2p.at(root_id); - TORCH_INTERNAL_ASSERT( - input_id->getIterType() != IterType::Symbolic, - "Producer ID not concretized: ", - input_id->toString()); - + for (auto producer_id : producers_it->second) { + producer_id = maybeMutated(producer_id)->as(); if (id_type.has_value()) { - id_type = ops::promoteIterType(*id_type, input_id->getIterType()); + id_type = ops::promoteIterType(*id_type, producer_id->getIterType()); } else { - id_type = input_id->getIterType(); + id_type = producer_id->getIterType(); } } TORCH_INTERNAL_ASSERT( id_type.has_value(), "Did not find id_type for consumer root domain ", - root_id->toString(), - ". Perhaps consumer def has no inputs. Consumer definition = ", - def->toString()); + id->toString(), + ". Perhaps consumer def has no inputs."); TORCH_INTERNAL_ASSERT( - id_type != IterType::Symbolic, - "Failed to concretize ", - root_id->toString(), - " of ", - consumer->toString()); + id_type != IterType::Symbolic, "Failed to concretize ", id->toString()); - auto concretized_id = - IterDomainBuilder(root_id).iter_type(*id_type).build(); + auto concretized_id = IterDomainBuilder(id).iter_type(*id_type).build(); - registerConcretization(root_id, concretized_id); - is_concretized = true; + registerConcretization(id, concretized_id); } - - return is_concretized; } DynamicTransformInitialInfo DynamicTransform::getInitialInfo(Fusion* fusion) { From 88f5d5e8578dc87cb0335940cc7073e2e5edf25a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 10 Jul 2023 11:03:33 -0400 Subject: [PATCH 2/7] Comment out PadShmoo {{3, 5}, {0, -4}, true} case --- test/test_dynamic_transform.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 64a8827d32f..bec420bc6b8 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -944,7 +944,10 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) { //{{3, 5}, {-3, -2}, false}, // output is zero-dimensional // Output has size 1 so is set to broadcast. - {{3, 5}, {0, -4}, true}, + // This was previously "working" by concretizing the size-1 pad to + // Iteration, even though it should be Broadcast. When set properly to + // Broadcast, it fails with an error in ConcretizedBroadcastDomains. + //{{3, 5}, {0, -4}, true}, // Test full negative shifts, so output doesn't overlap input {{3, 5}, {-5, 2}, false}, From bdec70649760fa2e790344fd666e82ba3756288f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 10 Jul 2023 12:27:24 -0400 Subject: [PATCH 3/7] Add comment about topo ordering in iter_visitor.h --- csrc/iter_visitor.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 425463cbc36..9fd0ebf54c9 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -27,6 +27,12 @@ class Fusion; * the dag will be called with handle(Statement*) in topolgical order inputs of * the fusion to outputs of the fusion. * + * Note that for any Val whose definition is non-null, the following are + * processed in order: definition, attributes, members. In particular, this + * means that a TensorView's domain() is processed after its definition, meaning + * producer TVs and their IterDomains are all processed before those of + * consumers. + * * TODO: We may want a BFS version of this code to extract ILP, not implemented * yet. * From e3f5a35d59777b4b308b779bf3441b295c568283 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 12 Jul 2023 08:28:44 -0400 Subject: [PATCH 4/7] More carefully handle extents in mutate(IterDomain*) Also register extents in concretizeReshape --- csrc/dynamic_transform.cpp | 96 +++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 37 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 9c5eaa06a8b..ace9100a5eb 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -433,6 +433,18 @@ void DynamicTransformConcretizer::concretizeReshape() { use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); } + // Replace old extents with new ones in downstream expressions + for (auto i : + c10::irange(incomplete_out_tv->getMaybeRFactorDomain().size())) { + auto old_extent = + incomplete_out_tv->getMaybeRFactorDomain().at(i)->extent(); + auto new_extent = + concrete_reshape_out_tv->getMaybeRFactorDomain().at(i)->extent(); + if (!new_extent->sameAs(old_extent)) { + registerConcretization(old_extent, new_extent); + } + } + if (incomplete_out_tv->isFusionOutput()) { incomplete_out_tv->fusion()->replaceOutput( incomplete_out_tv, concrete_reshape_out_tv); @@ -526,56 +538,66 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { } void DynamicTransformConcretizer::mutate(IterDomain* id) { - // id might have already been mutated if its definition was updated - id = maybeMutated(id)->as(); - if (!id->isSymbolic()) { - return; - } + // This will register id for mutation if start, stop, or extent are registered + // for mutation + OptOutMutator::mutate(id); + + // Use this to prototype new concretizations, since it will have replaced + // extent (see above) + auto mut_id = maybeMutated(id)->as(); + + IterDomain* concretized_id = nullptr; + if (auto def = id->definition()) { - // Determine concrete IterType based on promotion of inputs to def - IterType iter_type = IterType::Symbolic; - for (auto inp_id : ir_utils::filterByType(def->inputs())) { - auto updated_id = maybeMutated(inp_id)->as(); - iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); + IterType iter_type = id->getIterType(); + if (iter_type == IterType::Symbolic) { + // Determine concrete IterType based on promotion of inputs to def + for (auto inp_id : ir_utils::filterByType(def->inputs())) { + auto updated_id = maybeMutated(inp_id)->as(); + iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); + } + TORCH_INTERNAL_ASSERT( + iter_type != IterType::Symbolic, + "Failed to concretize an output IterType for expression: ", + def->toString()); } - TORCH_INTERNAL_ASSERT( - iter_type != IterType::Symbolic, - "Failed to concretize an output IterType for expression: ", - def->toString()); - auto concretized_id = IterDomainBuilder(id).iter_type(iter_type).build(); - registerConcretization(id, concretized_id); + concretized_id = IterDomainBuilder(mut_id).iter_type(iter_type).build(); } else { // IterDomains without definitions might be root domains for the output of a // TensorView expression. If so, we should propagate their concretization in // the producer to consumer direction. auto producers_it = id_producers_.find(id); - if (producers_it == id_producers_.end()) { - // id was not a consumer root ID in any TV expression - return; - } - - std::optional id_type; - for (auto producer_id : producers_it->second) { - producer_id = maybeMutated(producer_id)->as(); - if (id_type.has_value()) { - id_type = ops::promoteIterType(*id_type, producer_id->getIterType()); - } else { - id_type = producer_id->getIterType(); + if (producers_it != id_producers_.end()) { + // id was a consumer root ID in some TV expression + + std::optional id_type; + for (auto producer_id : producers_it->second) { + producer_id = maybeMutated(producer_id)->as(); + if (id_type.has_value()) { + id_type = ops::promoteIterType(*id_type, producer_id->getIterType()); + } else { + id_type = producer_id->getIterType(); + } } - } - TORCH_INTERNAL_ASSERT( - id_type.has_value(), - "Did not find id_type for consumer root domain ", - id->toString(), - ". Perhaps consumer def has no inputs."); + TORCH_INTERNAL_ASSERT( + id_type.has_value(), + "Did not find id_type for consumer root domain ", + id->toString(), + ". Perhaps consumer def has no inputs."); - TORCH_INTERNAL_ASSERT( - id_type != IterType::Symbolic, "Failed to concretize ", id->toString()); + TORCH_INTERNAL_ASSERT( + id_type.value() != IterType::Symbolic, + "Failed to concretize ", + id->toString()); - auto concretized_id = IterDomainBuilder(id).iter_type(*id_type).build(); + concretized_id = + IterDomainBuilder(mut_id).iter_type(id_type.value()).build(); + } + } + if (concretized_id) { registerConcretization(id, concretized_id); } } From 4de65f9d096d99b6ee9804fd6fbcce4520b557e4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 12 Jul 2023 08:43:07 -0400 Subject: [PATCH 5/7] Traverse in three passes: Most Vals, Exprs, TV/TDs --- csrc/dynamic_transform.cpp | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index ace9100a5eb..1384d5a55e5 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -407,9 +407,38 @@ void DynamicTransformConcretizer::concretize() { } } - // Finally, propagate concretized domains with forward traversal + // In this first pass, we only mutate Vals that are not TensorDomains or + // TensorViews. This does not change the traversal order as no expressions are + // changed. for (auto stmt : all_stmts) { - mutate(stmt); + if (stmt->isVal() && !stmt->isA() && + !stmt->isA()) { + mutate(stmt); + } + } + // After this point, we should only call `registerExpr` on the unhandled + // TensorDomains. + + // In the second pass, we only mutate Exprs. For each expr, if any of its + // inputs, outputs, or attributes are registered for mutation, a new expr will + // be created and the original expr will be removed. This is the mechanism + // OptOutMutator provides for setting the `definition()` of replaced Vals. + for (auto stmt : all_stmts) { + if (stmt->isExpr()) { + mutate(stmt); + } + } + + // In the third pass, we mutate the TensorDomains and TensorViews, without + // touching any other Vals or Exprs. The only change made to the Fusion at + // this stage is that TensorViews have their domain() replaced if any of their + // IterDomains are registered for mutation. This must happen last, as Expr + // mutation is required in order to properly connect root and rfactor domains, + // which is checked when creating new TensorDomains. + for (auto stmt : all_stmts) { + if (stmt->isA() || stmt->isA()) { + mutate(stmt); + } } } @@ -532,6 +561,9 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { } } + // NOTE: definitions for replacement rfactor IDs must have been properly set + // at this point. This will only happen after `mutate(Expr)` has been called + // on each intermediate ID expression connecting root to rfactor. Val* mutated_val = IrBuilder::create( td->container(), root_dom, rfactor_dom, domain, contig); registerConcretization(td, mutated_val); From 8dab796f590dce980ad4bf4858b2e89915f42539 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 12 Jul 2023 09:54:59 -0400 Subject: [PATCH 6/7] Fix logic and clean up comments --- csrc/dynamic_transform.cpp | 149 +++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 74 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 1384d5a55e5..55293b73708 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -380,20 +380,32 @@ void DynamicTransformConcretizer::concretize() { // Set output IterTypes for dynamic resize ops concretizeResize(); - // The methods above do not traverse the graph. Instead they fill in - // root->rfactor expressions by replacing the dynamic reshaped TV with a - // static reshaped one, and by registering concretization of dynamic Resized - // IterDomains. From this point forward, we will not modify any TensorView - // expressions. This restriction makes it safe for us to to traverse forward - // through the graph and mutate IterDomains and TensorDomains in order to - // properly propagate IterTypes and concretized extent expressions, without - // breaking the topological ordering of these expressions. - // + // This fixes the set of statements we will process over the course of + // multiple passes. + // Since we will traverse these Statements after some have been removed, we + // will not be able to safely check the types of each Statement in later + // loops. To avoid segfaults, we first split all_statements into subsets for + // each traversal. + std::vector non_tds_tvs; + std::vector all_exprs; + std::vector tvs_and_tds; + for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) { + if (stmt->isExpr()) { + all_exprs.push_back(stmt->asExpr()); + } else { + auto val = stmt->asVal(); + if (val->isA() || val->isA()) { + tvs_and_tds.push_back(val); + } else { + non_tds_tvs.push_back(val); + } + } + } + // When propagating IterTypes across expressions, we need to know the producer // IterDomains corresponding to a consumer ID. This mapping helps facilitate - // this and is used later in mutate(IterDomain*). - auto all_stmts = StmtSort::getStmts(info_->fusion(), true); - for (auto expr : ir_utils::filterByType(all_stmts)) { + // this and is used later by mutate(IterDomain*) in the first pass below. + for (auto expr : all_exprs) { for (auto consumer : ir_utils::filterByType(expr->outputs())) { for (auto producer : ir_utils::filterByType(expr->inputs())) { PairwiseRootDomainMap root_map(producer, consumer); @@ -408,25 +420,22 @@ void DynamicTransformConcretizer::concretize() { } // In this first pass, we only mutate Vals that are not TensorDomains or - // TensorViews. This does not change the traversal order as no expressions are - // changed. - for (auto stmt : all_stmts) { - if (stmt->isVal() && !stmt->isA() && - !stmt->isA()) { - mutate(stmt); - } + // TensorViews. + // + // This pass does not modify the Fusion. + for (auto val : non_tds_tvs) { + mutate(val); } - // After this point, we should only call `registerExpr` on the unhandled - // TensorDomains. // In the second pass, we only mutate Exprs. For each expr, if any of its // inputs, outputs, or attributes are registered for mutation, a new expr will // be created and the original expr will be removed. This is the mechanism // OptOutMutator provides for setting the `definition()` of replaced Vals. - for (auto stmt : all_stmts) { - if (stmt->isExpr()) { - mutate(stmt); - } + // + // This pass may add and remove Exprs, so elements of all_exprs are invalid + // after this pass. + for (auto expr : all_exprs) { + mutate(expr); } // In the third pass, we mutate the TensorDomains and TensorViews, without @@ -435,10 +444,11 @@ void DynamicTransformConcretizer::concretize() { // IterDomains are registered for mutation. This must happen last, as Expr // mutation is required in order to properly connect root and rfactor domains, // which is checked when creating new TensorDomains. - for (auto stmt : all_stmts) { - if (stmt->isA() || stmt->isA()) { - mutate(stmt); - } + // + // This pass modifies the Fusion by creating new TensorDomains and swapping + // them into TensorViews. + for (auto val : tvs_and_tds) { + mutate(val); } } @@ -462,18 +472,6 @@ void DynamicTransformConcretizer::concretizeReshape() { use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); } - // Replace old extents with new ones in downstream expressions - for (auto i : - c10::irange(incomplete_out_tv->getMaybeRFactorDomain().size())) { - auto old_extent = - incomplete_out_tv->getMaybeRFactorDomain().at(i)->extent(); - auto new_extent = - concrete_reshape_out_tv->getMaybeRFactorDomain().at(i)->extent(); - if (!new_extent->sameAs(old_extent)) { - registerConcretization(old_extent, new_extent); - } - } - if (incomplete_out_tv->isFusionOutput()) { incomplete_out_tv->fusion()->replaceOutput( incomplete_out_tv, concrete_reshape_out_tv); @@ -580,9 +578,9 @@ void DynamicTransformConcretizer::mutate(IterDomain* id) { IterDomain* concretized_id = nullptr; - if (auto def = id->definition()) { - IterType iter_type = id->getIterType(); - if (iter_type == IterType::Symbolic) { + if (mut_id->isSymbolic()) { + if (auto def = id->definition()) { + IterType iter_type = mut_id->getIterType(); // Determine concrete IterType based on promotion of inputs to def for (auto inp_id : ir_utils::filterByType(def->inputs())) { auto updated_id = maybeMutated(inp_id)->as(); @@ -592,40 +590,43 @@ void DynamicTransformConcretizer::mutate(IterDomain* id) { iter_type != IterType::Symbolic, "Failed to concretize an output IterType for expression: ", def->toString()); - } - concretized_id = IterDomainBuilder(mut_id).iter_type(iter_type).build(); - } else { - // IterDomains without definitions might be root domains for the output of a - // TensorView expression. If so, we should propagate their concretization in - // the producer to consumer direction. - - auto producers_it = id_producers_.find(id); - if (producers_it != id_producers_.end()) { - // id was a consumer root ID in some TV expression - - std::optional id_type; - for (auto producer_id : producers_it->second) { - producer_id = maybeMutated(producer_id)->as(); - if (id_type.has_value()) { - id_type = ops::promoteIterType(*id_type, producer_id->getIterType()); - } else { - id_type = producer_id->getIterType(); + concretized_id = IterDomainBuilder(mut_id).iter_type(iter_type).build(); + } else { + // IterDomains without definitions might be root domains for the output of + // a TensorView expression. If so, we should propagate their + // concretization in the producer to consumer direction. + + auto producers_it = id_producers_.find(id); + if (producers_it != id_producers_.end()) { + // id was a consumer root ID in some TV expression + + std::optional id_type; + for (auto producer_id : producers_it->second) { + producer_id = maybeMutated(producer_id)->as(); + if (id_type.has_value()) { + id_type = + ops::promoteIterType(*id_type, producer_id->getIterType()); + } else { + id_type = producer_id->getIterType(); + } } - } - TORCH_INTERNAL_ASSERT( - id_type.has_value(), - "Did not find id_type for consumer root domain ", - id->toString(), - ". Perhaps consumer def has no inputs."); + TORCH_INTERNAL_ASSERT( + id_type.has_value(), + "Did not find id_type for consumer root domain ", + id->toString(), + ". Perhaps consumer def has no inputs."); - TORCH_INTERNAL_ASSERT( - id_type.value() != IterType::Symbolic, - "Failed to concretize ", - id->toString()); + TORCH_INTERNAL_ASSERT( + id_type.value() != IterType::Symbolic, + "Failed to concretize ", + id->toString()); + + if (id_type.value() != id->getIterType()) - concretized_id = - IterDomainBuilder(mut_id).iter_type(id_type.value()).build(); + concretized_id = + IterDomainBuilder(mut_id).iter_type(id_type.value()).build(); + } } } From 9b78b1731f811dea477f8fe9a2de31b23c77d2f1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 12 Jul 2023 10:17:59 -0400 Subject: [PATCH 7/7] Move contiguity update to OptOutMutator::mutate(TensorDomain*) --- csrc/dynamic_transform.cpp | 60 -------------------------------------- csrc/mutator.cpp | 29 ++++++++++++++---- 2 files changed, 23 insertions(+), 66 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 55293b73708..f1a9f075363 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -357,7 +357,6 @@ class DynamicTransformConcretizer : public OptOutMutator { using OptOutMutator::mutate; - void mutate(TensorDomain* td) final; void mutate(IterDomain* id) final; private: @@ -508,65 +507,6 @@ void DynamicTransformConcretizer::checkConcretizedUses( } } -// Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but -// the contiguity vector may need to be updated as well as symbolic -// domains may be mutated to broadcast domains, which means contiguity -// may need to be changed to nullopt -void DynamicTransformConcretizer::mutate(TensorDomain* td) { - bool mutated = false; - - auto updateIdVec = [&](const std::vector& ids) { - std::vector updated_ids; - for (auto id : ids) { - auto updated_id = maybeMutated(id)->as(); - updated_ids.push_back(updated_id); - if (!updated_id->sameAs(id)) { - mutated = true; - } - } - return updated_ids; - }; - - std::vector root_dom = updateIdVec(td->root()); - std::vector rfactor_dom = td->hasRFactor() - ? updateIdVec(td->maybeRFactor()) - : std::vector(); - std::vector domain = updateIdVec(td->leaf()); - - if (!mutated) { - return; - } - - // Update the contiguity vector. Drop the contig val if mutated to broadcast - auto contig = td->contiguity(); - - for (const auto i : c10::irange(td->maybeRFactor().size())) { - auto original_id = td->maybeRFactor().at(i); - if (original_id->getIterType() != IterType::Symbolic) { - continue; - } - - TORCH_INTERNAL_ASSERT( - contig.at(i), - "Unexpected to have a non-contig symbolic domain: ", - original_id->toString()); - - auto updated_id = td->hasRFactor() ? rfactor_dom.at(i) : root_dom.at(i); - - // If the concretized ID is a broadcast domain, drop the contig val - if (updated_id->isBroadcast()) { - contig.at(i) = std::nullopt; - } - } - - // NOTE: definitions for replacement rfactor IDs must have been properly set - // at this point. This will only happen after `mutate(Expr)` has been called - // on each intermediate ID expression connecting root to rfactor. - Val* mutated_val = IrBuilder::create( - td->container(), root_dom, rfactor_dom, domain, contig); - registerConcretization(td, mutated_val); -} - void DynamicTransformConcretizer::mutate(IterDomain* id) { // This will register id for mutation if start, stop, or extent are registered // for mutation diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index e33bfaceb50..468d31b6133 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -107,13 +107,30 @@ void OptOutMutator::mutate(TensorDomain* td) { return; } + // Update the contiguity vector. Drop the contig val if mutated to broadcast + auto contig = td->contiguity(); + + for (const auto i : c10::irange(td->maybeRFactor().size())) { + auto original_id = td->maybeRFactor().at(i); + if (original_id->getIterType() != IterType::Symbolic) { + continue; + } + + TORCH_INTERNAL_ASSERT( + contig.at(i), + "Unexpected to have a non-contig symbolic domain: ", + original_id->toString()); + + auto updated_id = td->hasRFactor() ? rfactor_dom.at(i) : root_dom.at(i); + + // If the mutated ID is a broadcast domain, drop the contig val + if (updated_id->isBroadcast()) { + contig.at(i) = std::nullopt; + } + } + Val* mutated_val = IrBuilder::create( - td->container(), - root_dom, - rfactor_dom, - allocation_dom, - domain, - td->contiguity()); + td->container(), root_dom, rfactor_dom, allocation_dom, domain, contig); registerMutation(td, mutated_val); }