-
Notifications
You must be signed in to change notification settings - Fork 79
Refactor concretization traversal #576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
a58ce3e
Refactor to traverse in forward direction only
jacobhinkle 88f5d5e
Comment out PadShmoo {{3, 5}, {0, -4}, true} case
jacobhinkle bdec706
Add comment about topo ordering in iter_visitor.h
jacobhinkle e3f5a35
More carefully handle extents in mutate(IterDomain*)
jacobhinkle 4de65f9
Traverse in three passes: Most Vals, Exprs, TV/TDs
jacobhinkle 8dab796
Fix logic and clean up comments
jacobhinkle d278b02
Merge branch 'main' into concretization_topo_order
jacobhinkle 9b78b17
Move contiguity update to OptOutMutator::mutate(TensorDomain*)
jacobhinkle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -357,16 +357,19 @@ class DynamicTransformConcretizer : public OptOutMutator { | |
|
|
||
| using OptOutMutator::mutate; | ||
|
|
||
| void mutate(TensorView* tv) final; | ||
|
|
||
| void mutate(TensorDomain* td) final; | ||
|
|
||
| //! Concretizes the root domain of a symbolic consumer tensor from | ||
| //! its producer domains. Returns true if any root ID is concretized. | ||
| bool propagateFromProducerToConsumer(TensorView* consumer); | ||
| void mutate(IterDomain* id) final; | ||
|
|
||
| private: | ||
| const DynamicTransformConcretizationInfo* info_; | ||
|
|
||
| //! This map is used during concretization to identify, for a given IterDomain | ||
| //! the set of all IterDomains which are "aligned" with it in some TensorView | ||
| //! expression. This enables us to write mutate(IterDomain*) and propagate | ||
| //! information from producer IterDomains to consumers, which is otherwise not | ||
| //! represented in the graph since we do not connect IterDomains between | ||
| //! TensorViews with expressions. | ||
| std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> | ||
| id_producers_; | ||
| }; | ||
|
|
||
| void DynamicTransformConcretizer::concretize() { | ||
|
|
@@ -376,13 +379,76 @@ void DynamicTransformConcretizer::concretize() { | |
| // Set output IterTypes for dynamic resize ops | ||
| concretizeResize(); | ||
|
|
||
| // Finally, propagate concretized domains | ||
| auto all_stmts = StmtSort::getStmts(info_->fusion(), true); | ||
| for (auto stmt : all_stmts) { | ||
| if (stmt->isA<Val>()) { | ||
| mutate(stmt); | ||
| // This fixes the set of statements we will process over the course of | ||
| // multiple passes. | ||
| // Since we will traverse these Statements after some have been removed, we | ||
| // will not be able to safely check the types of each Statement in later | ||
| // loops. To avoid segfaults, we first split all_statements into subsets for | ||
| // each traversal. | ||
| std::vector<Val*> non_tds_tvs; | ||
| std::vector<Expr*> all_exprs; | ||
| std::vector<Val*> tvs_and_tds; | ||
| for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) { | ||
| if (stmt->isExpr()) { | ||
| all_exprs.push_back(stmt->asExpr()); | ||
| } else { | ||
| auto val = stmt->asVal(); | ||
| if (val->isA<TensorView>() || val->isA<TensorDomain>()) { | ||
| tvs_and_tds.push_back(val); | ||
| } else { | ||
| non_tds_tvs.push_back(val); | ||
| } | ||
|
Comment on lines
+388
to
+400
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What was previously a single loop over |
||
| } | ||
| } | ||
|
|
||
| // When propagating IterTypes across expressions, we need to know the producer | ||
| // IterDomains corresponding to a consumer ID. This mapping helps facilitate | ||
| // this and is used later by mutate(IterDomain*) in the first pass below. | ||
| for (auto expr : all_exprs) { | ||
| for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) { | ||
| for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) { | ||
| PairwiseRootDomainMap root_map(producer, consumer); | ||
| for (auto [cid, pid] : root_map.mapConsumerToProducer( | ||
| consumer->domain(), producer->domain())) { | ||
| // Initialize set of producer IDs, if we haven't already | ||
| auto& producers = id_producers_.emplace(cid, 0).first->second; | ||
| producers.insert(pid); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // In this first pass, we only mutate Vals that are not TensorDomains or | ||
| // TensorViews. | ||
| // | ||
| // This pass does not modify the Fusion. | ||
| for (auto val : non_tds_tvs) { | ||
| mutate(val); | ||
| } | ||
|
|
||
| // In the second pass, we only mutate Exprs. For each expr, if any of its | ||
| // inputs, outputs, or attributes are registered for mutation, a new expr will | ||
| // be created and the original expr will be removed. This is the mechanism | ||
| // OptOutMutator provides for setting the `definition()` of replaced Vals. | ||
| // | ||
| // This pass may add and remove Exprs, so elements of all_exprs are invalid | ||
| // after this pass. | ||
| for (auto expr : all_exprs) { | ||
| mutate(expr); | ||
| } | ||
|
|
||
| // In the third pass, we mutate the TensorDomains and TensorViews, without | ||
| // touching any other Vals or Exprs. The only change made to the Fusion at | ||
| // this stage is that TensorViews have their domain() replaced if any of their | ||
| // IterDomains are registered for mutation. This must happen last, as Expr | ||
| // mutation is required in order to properly connect root and rfactor domains, | ||
| // which is checked when creating new TensorDomains. | ||
| // | ||
| // This pass modifies the Fusion by creating new TensorDomains and swapping | ||
| // them into TensorViews. | ||
| for (auto val : tvs_and_tds) { | ||
| mutate(val); | ||
| } | ||
| } | ||
|
|
||
| void DynamicTransformConcretizer::concretizeReshape() { | ||
|
|
@@ -441,228 +507,72 @@ void DynamicTransformConcretizer::checkConcretizedUses( | |
| } | ||
| } | ||
|
|
||
| // Concretizes inherited symbolic domains. Note that when this is | ||
| // called, it is assumed that all dynamic ops themselves are | ||
| // concretized. Since symbolic IDs may be propagated down to | ||
| // consumers, those domains need to be concretized accordingly. | ||
| void DynamicTransformConcretizer::mutate(TensorView* tv) { | ||
| if (!tv->domain()->hasSymbolicAxis()) { | ||
| return; | ||
| } | ||
|
|
||
| // First, try to concretize the root domain as there may be symbolic | ||
| // axes inherited from the producers | ||
| propagateFromProducerToConsumer(tv); | ||
|
|
||
| // If no root domain is altered by producer, we don't need to propagate back | ||
| // up to rfactor. We could return early, but instead we go ahead and check the | ||
| // root to rfactor transforms to be sure we have concretized any intermediate | ||
| // IterDomains. | ||
|
|
||
| // At this point, there should be no expr beyond rfactor root | ||
| TORCH_INTERNAL_ASSERT( | ||
| tv->getLeafDomain() == tv->getMaybeRFactorDomain(), | ||
| "Invalid tensor: ", | ||
| tv->toString()); | ||
|
|
||
| // If it has an rfactor root domain, the IterTypes of the rfactor | ||
| // IDs may need to be updated as well. Traverse the rfactor exprs | ||
| // and mutate the IterTypes of output IDs if symbolic. | ||
| if (tv->hasRFactor()) { | ||
| // Note that it is assumed that theres's no further expression | ||
| // beyond the rfactor domain as asserted above | ||
| auto all_id_exprs = StmtSort::getExprsBetween( | ||
| tv->fusion(), | ||
| {tv->getRootDomain().begin(), tv->getRootDomain().end()}, | ||
| {tv->getMaybeRFactorDomain().begin(), | ||
| tv->getMaybeRFactorDomain().end()}); | ||
| for (auto expr : all_id_exprs) { | ||
| // Assume outputs of IterDomain exprs are always IterDomains. If | ||
| // the assumption is invalidated, the logic here would need to | ||
| // be updated. Assert the assumption to immediately detect such | ||
| // a case if happened. | ||
| for (auto out_val : expr->outputs()) { | ||
| TORCH_INTERNAL_ASSERT( | ||
| out_val->isA<IterDomain>(), | ||
| "Unexpected output: ", | ||
| out_val->toString(), | ||
| ". IterDomain was expected."); | ||
| } | ||
| void DynamicTransformConcretizer::mutate(IterDomain* id) { | ||
| // This will register id for mutation if start, stop, or extent are registered | ||
| // for mutation | ||
| OptOutMutator::mutate(id); | ||
|
|
||
| // Use this to prototype new concretizations, since it will have replaced | ||
| // extent (see above) | ||
| auto mut_id = maybeMutated(id)->as<IterDomain>(); | ||
|
|
||
| IterDomain* concretized_id = nullptr; | ||
|
|
||
| // NOTE: We do not return early if all outputs are concrete as there may | ||
| // still be concrete inputs. For example, a Symbolic IterDomain might be | ||
| // padded with constant pad widths (1, 1), in which case although we do | ||
| // not know the exact extent of the output, we know it is at least as | ||
| // large as the sum of the pad widths, 2. In such cases, the output | ||
| // IterDomain is concrete at definition, since if the extent is >1 we know | ||
| // the IterType is Iteration. In these cases, we must continue to | ||
| // concretize intermediate expressions between the root and R-factor | ||
| // domain. See test DynamicTransform5_CUDA which demonstrates this | ||
| // behavior. | ||
| // NOTE: We also do not assume that if one output ID is symbolic, that | ||
| // they all must be. See test FusionSliceForNanoGPT3_CUDA for an example | ||
| // that does a static split by a factor of 16 of a symbolic input domain. | ||
| // The static split in that case results in a concrete IterDomain with | ||
| // extent 16 along with a symbolic one (extent ceilDiv(n / 16)). | ||
|
|
||
| // Determine the output IterType | ||
| IterType iter_type = IterType::Symbolic; | ||
| for (auto inp_id : ir_utils::filterByType<IterDomain>(expr->inputs())) { | ||
| if (mut_id->isSymbolic()) { | ||
| if (auto def = id->definition()) { | ||
| IterType iter_type = mut_id->getIterType(); | ||
| // Determine concrete IterType based on promotion of inputs to def | ||
| for (auto inp_id : ir_utils::filterByType<IterDomain>(def->inputs())) { | ||
| auto updated_id = maybeMutated(inp_id)->as<IterDomain>(); | ||
| iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); | ||
| } | ||
| TORCH_INTERNAL_ASSERT( | ||
| iter_type != IterType::Symbolic, | ||
| "Failed to concretize an output IterType for expression: ", | ||
| expr->toString()); | ||
|
|
||
| // Update the IterType of each output | ||
| for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) { | ||
| if (!out_id->isSymbolic()) { | ||
| continue; | ||
| def->toString()); | ||
| concretized_id = IterDomainBuilder(mut_id).iter_type(iter_type).build(); | ||
| } else { | ||
| // IterDomains without definitions might be root domains for the output of | ||
| // a TensorView expression. If so, we should propagate their | ||
| // concretization in the producer to consumer direction. | ||
|
|
||
| auto producers_it = id_producers_.find(id); | ||
| if (producers_it != id_producers_.end()) { | ||
| // id was a consumer root ID in some TV expression | ||
|
|
||
| std::optional<IterType> id_type; | ||
| for (auto producer_id : producers_it->second) { | ||
| producer_id = maybeMutated(producer_id)->as<IterDomain>(); | ||
| if (id_type.has_value()) { | ||
| id_type = | ||
| ops::promoteIterType(*id_type, producer_id->getIterType()); | ||
| } else { | ||
| id_type = producer_id->getIterType(); | ||
| } | ||
| } | ||
| auto concretized_out_id = | ||
| IterDomainBuilder(out_id).iter_type(iter_type).build(); | ||
| registerConcretization(out_id, concretized_out_id); | ||
| } | ||
|
|
||
| // The expr itself needs to be mutated as well in case the outputs are | ||
| // mutated, which can be done by the mutate method | ||
| OptOutMutator::mutate(expr); | ||
| } | ||
| } | ||
|
|
||
| // Root and rfactor domains are updated. First mutate the | ||
| // TensorDomain and then TensorView | ||
| mutate(tv->domain()); | ||
| OptOutMutator::mutate(tv); | ||
| } | ||
|
|
||
| // Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but | ||
| // the contiguity vector may need to be updated as well as symbolic | ||
| // domains may be mutated to broadcast domains, which means contiguity | ||
| // may need to be changed to nullopt | ||
| void DynamicTransformConcretizer::mutate(TensorDomain* td) { | ||
| bool mutated = false; | ||
|
|
||
| auto updateIdVec = [&](const std::vector<IterDomain*>& ids) { | ||
| std::vector<IterDomain*> updated_ids; | ||
| for (auto id : ids) { | ||
| auto updated_id = maybeMutated(id)->as<IterDomain>(); | ||
| updated_ids.push_back(updated_id); | ||
| if (!updated_id->sameAs(id)) { | ||
| mutated = true; | ||
| } | ||
| } | ||
| return updated_ids; | ||
| }; | ||
|
|
||
| std::vector<IterDomain*> root_dom = updateIdVec(td->root()); | ||
| std::vector<IterDomain*> rfactor_dom = td->hasRFactor() | ||
| ? updateIdVec(td->maybeRFactor()) | ||
| : std::vector<IterDomain*>(); | ||
| std::vector<IterDomain*> domain = updateIdVec(td->leaf()); | ||
|
|
||
| if (!mutated) { | ||
| return; | ||
| } | ||
|
|
||
| // Update the contiguity vector. Drop the contig val if mutated to broadcast | ||
| auto contig = td->contiguity(); | ||
|
|
||
| for (const auto i : c10::irange(td->maybeRFactor().size())) { | ||
| auto original_id = td->maybeRFactor().at(i); | ||
| if (original_id->getIterType() != IterType::Symbolic) { | ||
| continue; | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| contig.at(i), | ||
| "Unexpected to have a non-contig symbolic domain: ", | ||
| original_id->toString()); | ||
|
|
||
| auto updated_id = td->hasRFactor() ? rfactor_dom.at(i) : root_dom.at(i); | ||
|
|
||
| // If the concretized ID is a broadcast domain, drop the contig val | ||
| if (updated_id->isBroadcast()) { | ||
| contig.at(i) = std::nullopt; | ||
| } | ||
| } | ||
|
|
||
| Val* mutated_val = IrBuilder::create<TensorDomain>( | ||
| td->container(), root_dom, rfactor_dom, domain, contig); | ||
| registerConcretization(td, mutated_val); | ||
| } | ||
|
|
||
| bool DynamicTransformConcretizer::propagateFromProducerToConsumer( | ||
| TensorView* consumer) { | ||
| if (consumer->definition() == nullptr || | ||
| !consumer->domain()->hasSymbolicAxis()) { | ||
| return false; | ||
| } | ||
|
|
||
| const auto& root_domain = consumer->getRootDomain(); | ||
|
|
||
| auto def = consumer->definition(); | ||
|
|
||
| bool is_concretized = false; | ||
|
|
||
| for (const auto i : c10::irange(root_domain.size())) { | ||
| auto root_id = root_domain.at(i); | ||
| if (root_id->getIterType() != IterType::Symbolic) { | ||
| continue; | ||
| } | ||
|
|
||
| // Figure out the right IterType of this consumer root ID from its | ||
| // corresponding producer IDs | ||
|
|
||
| std::optional<IterType> id_type; | ||
|
|
||
| for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) { | ||
| PairwiseRootDomainMap root_map(producer, consumer); | ||
| auto c2p = root_map.mapConsumerToProducer( | ||
| consumer->domain(), producer->domain()); | ||
| TORCH_INTERNAL_ASSERT( | ||
| id_type.has_value(), | ||
| "Did not find id_type for consumer root domain ", | ||
| id->toString(), | ||
| ". Perhaps consumer def has no inputs."); | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| c2p.find(root_id) != c2p.end(), | ||
| "No input ID found to map with output ID: ", | ||
| root_id->toString()); | ||
| TORCH_INTERNAL_ASSERT( | ||
| id_type.value() != IterType::Symbolic, | ||
| "Failed to concretize ", | ||
| id->toString()); | ||
|
|
||
| auto input_id = c2p.at(root_id); | ||
| TORCH_INTERNAL_ASSERT( | ||
| input_id->getIterType() != IterType::Symbolic, | ||
| "Producer ID not concretized: ", | ||
| input_id->toString()); | ||
| if (id_type.value() != id->getIterType()) | ||
|
|
||
| if (id_type.has_value()) { | ||
| id_type = ops::promoteIterType(*id_type, input_id->getIterType()); | ||
| } else { | ||
| id_type = input_id->getIterType(); | ||
| concretized_id = | ||
| IterDomainBuilder(mut_id).iter_type(id_type.value()).build(); | ||
| } | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| id_type.has_value(), | ||
| "Did not find id_type for consumer root domain ", | ||
| root_id->toString(), | ||
| ". Perhaps consumer def has no inputs. Consumer definition = ", | ||
| def->toString()); | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| id_type != IterType::Symbolic, | ||
| "Failed to concretize ", | ||
| root_id->toString(), | ||
| " of ", | ||
| consumer->toString()); | ||
|
|
||
| auto concretized_id = | ||
| IterDomainBuilder(root_id).iter_type(*id_type).build(); | ||
|
|
||
| registerConcretization(root_id, concretized_id); | ||
| is_concretized = true; | ||
| } | ||
|
|
||
| return is_concretized; | ||
| if (concretized_id) { | ||
| registerConcretization(id, concretized_id); | ||
| } | ||
| } | ||
|
|
||
| DynamicTransformInitialInfo DynamicTransform::getInitialInfo(Fusion* fusion) { | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of only mutating
Vals, we now mutateExprs as well, which replaces theExprin place if any inputs or outputs have changed. Note that outputs ofExprs are mutated after their definition has been mutated, so we should be careful updating aValthat has a definition. But of course we should be careful in that case in the existing code too.