# Propagate modified reshape extents at concretization #630
The first hunk adds a free function, `convertInputRfactorsToRoots`, and is all new code:

`@@ -1437,6 +1437,62 @@ std::string toString(const SegmentedFusion* segmented_fusion)`

```cpp
  return ss.str();
}

//! Sets the rfactor as root and erases rfactor of all inputs in fusion. Any
//! non-constant expressions in those extents are replaced by new scalars with
//! no definition. These mutations are performed throughout the Fusion so that
//! downstream expressions dependent on the original inputs' rfactor extents
//! can be computed properly.
void convertInputRfactorsToRoots(Fusion* fusion) {
  FusionGuard fg(fusion);

  // Holds all Val replacements across all inputs
  std::unordered_map<Val*, Val*> replacement_map;

  for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
    // Create a new root domain and replacement TensorDomain.
    // Given an rfactor domain, create a new IterDomain.
    // Otherwise, clone the previous IterDomain
    std::vector<IterDomain*> new_root_domain;
    auto rfactor = tv->getMaybeRFactorDomain();
    new_root_domain.reserve(rfactor.size());

    // Does the domain (root / rfactor) contain all concrete sized extents?
    bool tv_is_concrete = true;
    for (auto id : rfactor) {
      if (!id->extent()->isConstScalar()) {
        tv_is_concrete = false;
        break;
      }
    }

    for (const auto& id : rfactor) {
      if (id->isRFactorProduct()) {
        // Create new symbolic extents for rfactor iterDomains
        auto domain_extent = (!tv_is_concrete)
            ? IrBuilder::create<Val>(DataType::Int)
            : id->extent();
        replacement_map.emplace(id->extent(), domain_extent);
        new_root_domain.push_back(IterDomainBuilder(id)
                                      .extent(domain_extent)
                                      .resetSchedulingParams()
                                      .build());
      } else {
        new_root_domain.push_back(id->cloneWithoutRFactor());
      }
    }

    TORCH_INTERNAL_ASSERT(
        new_root_domain.size() == tv->domain()->contiguity().size());
    auto new_td = IrBuilder::create<TensorDomain>(
        new_root_domain, tv->domain()->contiguity());
    replacement_map.emplace(tv->domain(), new_td);
  }

  // This will replace the values in the mapping replacement_map throughout
  // the Fusion
  ir_utils::replaceValue(fusion, replacement_map);
}

std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
  std::unique_ptr<Fusion> fusion_segment = std::make_unique<Fusion>();
```

Review thread on the `domain_extent` logic:

**Collaborator:** I think this function is fine, as it's mostly just copied from tensor_view.cpp. However, I don't know why we use all symbolic extents when `tv_is_concrete` is false. This is not about this PR itself, but does it make sense?

**Author:** This is an important point, since it could lead to us concretizing one aspect of a Val but losing a previously concretized aspect (for example, losing the concretized extent when we later concretize the IterType). But I think we are OK if we are careful to use …

**Author:** Yeah, that is strange. All IterDomains in the rfactor domain are checked to determine concreteness, but then the condition is only applied to rfactor products. The purpose of this function is to make the root a "standalone" domain so that we can bind input shapes to it, so it seems like all extents, regardless of whether they were already rfactor products, should be either constant ints or pure symbolic. We always need to create a new `domain_extent`:

```cpp
auto domain_extent = id->extent()->isConstScalar()
    ? id->extent()
    : IrBuilder::create<Val>(DataType::Int);
```

And in light of the above, we could move this outside of the `if (id->isRFactorProduct())` branch.

**Collaborator:** Please create a separate issue for this. Let's merge this PR as is.

**Author:** Sounds good. Merging without that change. I will do that in another PR.
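The follow-up suggested in this thread can be sketched as a standalone toy. The `Extent` struct and `convertExtents` helper below are invented for illustration and are not nvFuser types: constant extents survive, and every non-constant extent becomes a fresh scalar with no definition, independent of whether its IterDomain is an rfactor product.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for an IterDomain extent: either a constant size or a
// symbolic scalar identified by an id (with no defining expression).
struct Extent {
  bool is_const = false;
  long size = 0;      // valid when is_const
  int symbol_id = -1; // valid when !is_const
};

// The rule suggested in the review thread, applied uniformly to every
// extent of the input domain: keep constants, and replace everything else
// with a fresh symbol so that segment input shapes can be bound to it.
std::vector<Extent> convertExtents(const std::vector<Extent>& rfactor) {
  std::vector<Extent> out;
  int next_symbol = 100; // assumed fresh ids, distinct from existing symbols
  for (const Extent& e : rfactor) {
    if (e.is_const) {
      out.push_back(e); // constant sizes survive unchanged
    } else {
      out.push_back(Extent{false, 0, next_symbol++}); // fresh symbolic extent
    }
  }
  return out;
}
```

The difference from the merged code is that the decision is made per extent rather than keyed on whole-tensor concreteness, and it is not restricted to rfactor products.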
The second hunk swaps the per-TensorView conversion in `makeFusion` for the new whole-Fusion pass:

```diff
@@ -1469,9 +1525,9 @@ std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
     fusion_segment->addOutput(complete_to_segment_map.clone(out));
   }
 
-  for (auto tv : view_tvs) {
-    tv->convertRfactorToRootDomain();
-  }
+  // Replace all vals that are rfactor extents in fusion_segment->inputs()
+  // with new Vals so that they can be bound to the segment inputs.
+  convertInputRfactorsToRoots(fusion_segment.get());
 
   return fusion_segment;
 }
```
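Both hunks lean on `ir_utils::replaceValue` to apply the replacement map throughout the Fusion. A simplified standalone model of that mechanism (the `Node` type and `replaceAll` function are invented for this sketch, not nvFuser's actual mutator) looks like:

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Toy IR node: an operation whose inputs point at other nodes.
struct Node {
  std::vector<Node*> inputs;
};

// Simplified analogue of ir_utils::replaceValue: visit every node in the
// "fusion" and rewrite any input pointer found in the replacement map, so
// all downstream expressions see the new value at once.
void replaceAll(std::vector<Node*>& all_nodes,
                const std::unordered_map<Node*, Node*>& replacement_map) {
  for (Node* n : all_nodes) {
    for (Node*& in : n->inputs) {
      auto it = replacement_map.find(in);
      if (it != replacement_map.end()) {
        in = it->second;
      }
    }
  }
}
```

Because every use site is visited, downstream expressions that referenced the old rfactor extent all end up pointing at the new undefined scalar, which is what lets segment input shapes be bound cleanly.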
The second file removes `TensorView::convertRfactorToRootDomain` entirely:

```diff
@@ -240,69 +240,6 @@ std::string TensorView::toInlineString(int indent_size) const {
   return toString(indent_size);
 }
 
-void TensorView::convertRfactorToRootDomain() {
-  // For a given TensorView, does its domain (root / rfactor) contain any
-  // concrete sized extents?
-  auto is_concrete_tensor = [](TensorView* tv) {
-    for (auto id : tv->getMaybeRFactorDomain()) {
-      if (!id->extent()->isConstScalar()) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // Create a new root domain and replacement TensorDomain.
-  // Given an rfactor domain, create a new IterDomain.
-  // Otherwise, clone the previous IterDomain
-  auto createReplacementDomain =
-      [this](const std::vector<Val*>& replacement_extents) {
-        TORCH_INTERNAL_ASSERT(
-            !replacement_extents.empty() &&
-            getMaybeRFactorDomain().size() == replacement_extents.size());
-
-        size_t idx = 0;
-        std::vector<IterDomain*> new_root_domain(
-            getMaybeRFactorDomain().size());
-        for (const auto& id : getMaybeRFactorDomain()) {
-          if (replacement_extents[idx] != nullptr) {
-            new_root_domain[idx] = IterDomainBuilder(id)
-                                       .extent(replacement_extents[idx])
-                                       .resetSchedulingParams()
-                                       .build();
-            ++idx;
-          } else {
-            TORCH_INTERNAL_ASSERT(!id->isRFactorProduct());
-            new_root_domain[idx++] = id->cloneWithoutRFactor();
-          }
-        }
-
-        TORCH_INTERNAL_ASSERT(
-            new_root_domain.size() == domain()->contiguity().size());
-        setDomain(IrBuilder::create<TensorDomain>(
-            container(), new_root_domain, domain()->contiguity()));
-      };
-
-  std::vector<Val*> rfactor_extents;
-  std::unordered_map<Val*, Val*> replacement_map;
-  const auto kThisIsConcreteTensor = is_concrete_tensor(this);
-  for (const auto& id : getMaybeRFactorDomain()) {
-    if (id->isRFactorProduct()) {
-      // Create new symbolic extents for rfactor iterDomains
-      auto domain_extent = (!kThisIsConcreteTensor)
-          ? IrBuilder::create<Val>(container(), DataType::Int)
-          : id->extent();
-      rfactor_extents.push_back(domain_extent);
-      replacement_map.emplace(id->extent(), domain_extent);
-    } else {
-      rfactor_extents.push_back(nullptr);
-    }
-  }
-  createReplacementDomain(rfactor_extents);
-
-  // Propagate new extent throughout fusion using ValReplacementMutator
-  ir_utils::replaceValue(fusion(), replacement_map);
-}
-
 TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
     : Val(src, ir_cloner),
       domain_(ir_cloner->clone(src->domain_)),
```

Review thread on the removal:

**Collaborator:** Is this change completely unrelated to the propagation of concretized vals?

**Author:** Yes. At first I thought it was going to be necessary. I realized it was not necessary for this PR, but it is a little faster and reduces the TensorView interface, so I figured I would keep it. I can split it into another PR if you prefer.

**Comment on lines −260 to −262** (the `TORCH_INTERNAL_ASSERT` on `replacement_extents`):

**Author:** I omitted the check that …
A further thread on the replacement mutator:

**Author:** Since we propagate scalars, we should do this anywhere they might occur, namely in members and attributes. This was not necessary when we were only replacing `IterType`s, since `IterDomain`s are traversed regardless of these settings.

**Collaborator:** Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?
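On the duplicated-replacement concern: with a map-based mutator, reaching the same Val twice (once through an IterDomain, once through a member or attribute) should be benign as long as no replacement target is itself a key in the map, because the second lookup simply misses. A toy sketch of this reasoning (the `replaceRef` helper is invented, not nvFuser code):

```cpp
#include <cassert>
#include <unordered_map>

struct Val {};

// Apply a replacement map to one reference. Calling this again on the same
// reference is a no-op, because the substituted value is not itself a key.
void replaceRef(Val*& ref, const std::unordered_map<Val*, Val*>& map) {
  auto it = map.find(ref);
  if (it != map.end()) {
    ref = it->second;
  }
}
```

If a replacement target could itself be a key (chained replacements), then traversal order and repeat visits would start to matter; that would be the real hazard to rule out before traversing members as well.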