Propagate modified reshape extents at concretization #630
This allows us to just do one traversal for the whole segment to replace vals instead of one traversal per input TV.
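To illustrate the single-traversal idea, here is a minimal sketch on a toy IR. It is not nvFuser code: `Node` and `replaceAll` are hypothetical names introduced only for this example. It assumes all replacements are collected into one map up front, so each node is visited once regardless of how many values are being replaced.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy IR node: a named value with pointers to its inputs.
struct Node {
  std::string name;
  std::vector<Node*> inputs;
};

// Visit every node exactly once and swap any input that has an entry in
// `replacements`. Returns the number of nodes visited.
int replaceAll(
    const std::vector<Node*>& nodes,
    const std::unordered_map<Node*, Node*>& replacements) {
  int visited = 0;
  for (Node* n : nodes) {
    ++visited;
    for (Node*& in : n->inputs) {
      auto it = replacements.find(in);
      if (it != replacements.end()) {
        in = it->second;
      }
    }
  }
  return visited;
}
```

Running one traversal per replaced value would instead visit every node once per value, so the cost would grow with the number of segment inputs.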
Note that an alternative approach would be to reset the extents of the concretized reshaped TV to the original expressions. I went with this approach since it more closely resembles the extents we would see using static reshapes.
TORCH_INTERNAL_ASSERT(
    !replacement_extents.empty() &&
    getMaybeRFactorDomain().size() == replacement_extents.size());
I omitted the check that replacement_extents is not empty since it did not seem necessary.
/*traverse_members*/ true,
/*traverse_attributes*/ true,
Since we propagate scalars, we should do this anywhere they might occur; namely in members and attributes. This was not necessary when we were only replacing IterTypes since IterDomains are traversed regardless of these settings.
Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?
  return toString(indent_size);
}

void TensorView::convertRfactorToRootDomain() {
Is this change completely unrelated to the propagation of concretized vals?
Yes. At first I thought it was going to be necessary. I realized it was not necessary for this PR, but it is a little faster and reduces the TensorView interface so I figured I would keep it. I can split it into another PR if you prefer.
!build
naoyam left a comment
The PR looks good to me, but I have one question on the rfactor-root replacement.
for (const auto& id : rfactor) {
  if (id->isRFactorProduct()) {
    // Create new symbolic extents for rfactor iterDomains
    auto domain_extent = (!tv_is_concrete)
I think this function is fine as it's mostly just copied from tensor_view.cpp. However, I don't know why we use all symbolic extents when tv_is_concrete is false. It seems this means that if there's any IterDomain with a symbolic extent, all IterDomains would have symbolic extents.
This is not about this PR itself, but does it make sense?
> Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?
This is an important point since it could lead to us concretizing one aspect of a Val but losing a previously concretized aspect (for example, losing the concretized extent when we later concretize IterType). But I think we are OK if we are careful to use maybeMutated when creating replacements. In one case we already do multiple mutations: we call OptOutMutator::mutate(id) at the beginning of our mutate(IterDomain* id) override.
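The concern can be shown with a small toy model. This is a sketch under stated assumptions, not the nvFuser implementation: `ToyId`, `ToyMutator`, and `withIterType` are hypothetical, and the `maybeMutated` here only mimics the idea of looking up the latest registered replacement. The point is that a second replacement built from the already-mutated value keeps the first change, while one built from the original value drops it.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for an IterDomain with two aspects that are
// concretized independently.
struct ToyId {
  std::string extent;
  std::string iter_type;
};

// Hypothetical mutator that records one replacement per original value.
class ToyMutator {
 public:
  // Return the registered replacement for `id`, or `id` itself if none.
  const ToyId* maybeMutated(const ToyId* id) const {
    auto it = mutations_.find(id);
    return it == mutations_.end() ? id : it->second;
  }

  void registerMutation(const ToyId* from, const ToyId* to) {
    mutations_[from] = to;
  }

 private:
  std::unordered_map<const ToyId*, const ToyId*> mutations_;
};

// Build a replacement that changes only the iter type, starting from the
// latest mutated version so earlier changes (e.g. the extent) are preserved.
ToyId withIterType(
    const ToyMutator& mutator,
    const ToyId* id,
    const std::string& iter_type) {
  ToyId out = *mutator.maybeMutated(id);
  out.iter_type = iter_type;
  return out;
}
```

Copying from the original `ToyId` instead of the result of `maybeMutated` would silently restore the old symbolic extent.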
> It seems this means that if there's any IterDomain with a symbolic extent, all IterDomains would have symbolic extents.
Yeah that is strange. All IterDomains in the rfactor domain are checked to determine concreteness, but then the condition is only applied to rfactor products. The purpose of this function is to make the root a "standalone" domain so that we can bind input shapes to it, so it seems like all extents, regardless of whether they were already rfactor products, should be either constant ints or pure symbolic.
We always need to create a new IterDomain for rfactor products in order to cut the connection to the root domain. It seems like a simpler condition would be

    auto domain_extent = id->extent()->isConstScalar()
        ? id->extent()
        : IrBuilder::create<Val>(DataType::Int);

And in light of the above, we could move this outside of the id->isRFactorProduct() check, so that we would replace IDs that are not rfactor products if they have non-constant derived extents.
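As a rough illustration of the proposed rule, the following toy sketch applies it to every extent rather than only to rfactor products. The types are hypothetical and deliberately simplified: `ToyExtent` models an extent as either a known constant or a symbol, with a `derived` flag standing in for "computed from other scalars". None of this is nvFuser API.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical toy extent: `constant` is set for constant extents and empty
// for symbolic ones; `derived` marks an extent defined by another expression.
struct ToyExtent {
  std::optional<long> constant;
  bool derived = false;

  bool isConst() const {
    return constant.has_value();
  }
};

// Keep constant extents as they are; replace every non-constant extent with
// a fresh, underived symbolic one so the result does not depend on any
// producer expression.
std::vector<ToyExtent> standaloneExtents(const std::vector<ToyExtent>& in) {
  std::vector<ToyExtent> out;
  out.reserve(in.size());
  for (const ToyExtent& e : in) {
    out.push_back(e.isConst() ? e : ToyExtent{std::nullopt, false});
  }
  return out;
}
```

With this rule each extent is decided on its own, so one symbolic extent no longer forces the constant ones to become symbolic as well.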
Please create a separate issue for this. Let's merge this PR as is.
Sounds good. Merging without that change. I will do that in another PR.
This PR propagates extent scalars after concretizing reshape. Previously the following fusion
would be concretized like this:
Now this concretizes as
This helps ensure that no scalars get lost during segmentation, which could previously occur if the reshape output became a segmentation edge (see #629).
I also took this opportunity to remove TensorView::convertRfactorToRootDomain(). It was only used at segmentation and did a traversal for each segment input TV. Now, its replacement is called once and replaces all extents for all input TVs using a single traversal.
Fixes #629 and fixes #418.