Skip trivial Resize, handle resize of Broadcast at concretization.#800
Skip trivial Resize, handle resize of Broadcast at concretization.#800jacobhinkle merged 5 commits intomainfrom
Conversation
Avoids trivial resize
|
!build |
csrc/ir/nodes.cpp
Outdated
| if (left == 0 && right == 0) { | ||
| // Trivial Resize | ||
| return in; |
There was a problem hiding this comment.
It seems odd to check the trivial case here as this part is to figure out the iteration type. Why not just do this check early? Also, maybe we should check the optional iteration type is not given or the same as the input iteration type.
There was a problem hiding this comment.
Good point. I moved the check to the beginning of the function and added an assertion that the specified IterType does not clash, if it is given.
csrc/dynamic_transform.cpp
Outdated
| // dimension, then we should not retain Symbolic. To work around this, | ||
| // we always overwrite Symbolic with the first concrete IterType we | ||
| // encounter. | ||
| iter_type = iter_type == IterType::Symbolic |
There was a problem hiding this comment.
Here, updated_id->getIterType() should be always non symbolic, right?
There was a problem hiding this comment.
This feels more intuitive to me.
for (const auto i: c10::irange(...)) {
auto updated_id = ...;
TORCH_INTERNAL_ASSERT(updated_id->getIterType() != IterType::Symbolic);
if (i == 0) {
iter_type = updated_id->getIterType();
} else {
iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
}
}
There was a problem hiding this comment.
Here,
updated_id->getIterType()should be always non symbolic, right?
I think it could still be symbolic, since updated_id could have been mutated because its extent was mutated, without setting the IterType yet.
There was a problem hiding this comment.
Could you give me an example?
There was a problem hiding this comment.
Actually I think you might be right that this won't happen.
We currently replace derived extents when we concretize reshape using registerMutation. So if a downstream TensorView has the old extent then the default mutate will create a new Symbolic ID with the new extent here https://github.com/NVIDIA/Fuser/blob/main/csrc/dynamic_transform.cpp#L570 .
For example:
T1[ ?S2{div(i0, s0)}, ?S3{s0} ] = reshape(T0[ div(i0, s0), s0 ]);
T2[ ?S4{div(i0, s0)}, ?S5{s0} ] = neg(T1);
Then depending on the inputs we might concretize this as
T1[ iS6{ceilDiv(i0, 3)}, iS7{3} ] = reshape(T0[ div(i0, s0), s0 ]);
T2[ iS10{ceilDiv(i0, 3)}, iS11{3} ] = neg(T1);
When mutate is called for T2, first div(i0, s0) will be recognized as mutated, so ?S4{div(i0, s0)} will be mutated as ?S10{ceilDiv(i0, 3)}, then later in the mutate we will set its IterType in propagateFromProducerToConsumer.
This happens for root domains of T2, but anything downstream of root up until rfactor should not be mutated yet, unless it was a Resize output, in which case the IterType should've been set before this traversal. I switched to the indexed loop as you suggested. I'll add this check for Symbolic iter type as well. Thanks!
naoyam
left a comment
There was a problem hiding this comment.
The changes make sense to me. Just left some suggestions. Feel free to merge.
naoyam
left a comment
There was a problem hiding this comment.
LGTM. Thanks for the update!
) This updates the root to rfactor propagation in IterType concretization of dynamic fusions. Previously, although we only overwrote Symbolic IterDomains in this step, we still asserted that we could infer an IterType for each I moved that check so that it is only applied when we need to make a change. Additionally, we previously propagated Broadcast-only IterDomains as Symbolic, since we combine with our previous estimate using promoteIterType. As mentioned in a comment, this means Broadcast gets propagated as Symbolic. Instead we now only fall back to promoteIterType when there are multiple input IterTypes to the IterDomain expression. Fixes #798
This updates the root to rfactor propagation in IterType concretization of dynamic fusions.
Previously, although we only overwrote Symbolic IterDomains in this step, we still asserted that we could infer an IterType for each I moved that check so that it is only applied when we need to make a change.
Additionally, we previously propagated Broadcast-only IterDomains as Symbolic, since we combine with our previous estimate using promoteIterType. As mentioned in a comment, this means Broadcast gets propagated as Symbolic. Instead we now only fall back to promoteIterType when there are multiple input IterTypes to the IterDomain expression.
Fixes #798