Skip to content

Skip trivial Resize, handle resize of Broadcast at concretization.#800

Merged
jacobhinkle merged 5 commits intomainfrom
fix-issue798
Aug 29, 2023
Merged

Skip trivial Resize, handle resize of Broadcast at concretization.#800
jacobhinkle merged 5 commits intomainfrom
fix-issue798

Conversation

@jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Aug 28, 2023

This updates the root to rfactor propagation in IterType concretization of dynamic fusions.

Previously, although we only overwrote Symbolic IterDomains in this step, we still asserted that we could infer an IterType for each I moved that check so that it is only applied when we need to make a change.

Additionally, we previously propagated Broadcast-only IterDomains as Symbolic, since we combine with our previous estimate using promoteIterType. As mentioned in a comment, this means Broadcast gets propagated as Symbolic. Instead we now only fall back to promoteIterType when there are multiple input IterTypes to the IterDomain expression.

Fixes #798

@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle jacobhinkle marked this pull request as ready for review August 29, 2023 09:12
@jacobhinkle jacobhinkle requested a review from naoyam August 29, 2023 15:30
Comment on lines +2828 to +2830
if (left == 0 && right == 0) {
// Trivial Resize
return in;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems odd to check the trivial case here as this part is to figure out the iteration type. Why not just do this check early? Also, maybe we should check the optional iteration type is not given or the same as the input iteration type.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I moved the check to the beginning of the function and added an assertion that the specified IterType does not clash, if it is given.

// dimension, then we should not retain Symbolic. To work around this,
// we always overwrite Symbolic with the first concrete IterType we
// encounter.
iter_type = iter_type == IterType::Symbolic
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, updated_id->getIterType() should be always non symbolic, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels more intuitive to me.

for (const auto i: c10::irange(...)) {
    auto updated_id = ...;
  TORCH_INTERNAL_ASSERT(updated_id->getIterType() != IterType::Symbolic);
  if (i == 0) {
    iter_type = updated_id->getIterType();
  } else {
    iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
  }
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, updated_id->getIterType() should be always non symbolic, right?

I think it could still be symbolic, since updated_id could have been mutated because its extent was mutated, without setting the IterType yet.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you give me an example?

Copy link
Collaborator Author

@jacobhinkle jacobhinkle Aug 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think you might be right that this won't happen.

We currently replace derived extents when we concretize reshape using registerMutation. So if a downstream TensorView has the old extent then the default mutate will create a new Symbolic ID with the new extent here https://github.com/NVIDIA/Fuser/blob/main/csrc/dynamic_transform.cpp#L570 .
For example:

T1[ ?S2{div(i0, s0)}, ?S3{s0} ] = reshape(T0[ div(i0, s0), s0 ]);
T2[ ?S4{div(i0, s0)}, ?S5{s0} ] = neg(T1);

Then depending on the inputs we might concretize this as

T1[ iS6{ceilDiv(i0, 3)}, iS7{3} ] = reshape(T0[ div(i0, s0), s0 ]);
T2[ iS10{ceilDiv(i0, 3)}, iS11{3} ] = neg(T1);

When mutate is called for T2, first div(i0, s0) will be recognized as mutated, so ?S4{div(i0, s0)} will be mutated as ?S10{ceilDiv(i0, 3)}, then later in the mutate we will set its IterType in propagateFromProducerToConsumer.

This happens for root domains of T2, but anything downstream of root up until rfactor should not be mutated yet, unless it was a Resize output, in which case the IterType should've been set before this traversal. I switched to the indexed loop as you suggested. I'll add this check for Symbolic iter type as well. Thanks!

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes make sense to me. Just left some suggestions. Feel free to merge.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for the update!

@jacobhinkle jacobhinkle merged commit e20a25c into main Aug 29, 2023
@jacobhinkle jacobhinkle deleted the fix-issue798 branch August 29, 2023 22:06
jacobhinkle added a commit that referenced this pull request Aug 30, 2023
)

This updates the root to rfactor propagation in IterType concretization
of dynamic fusions.

Previously, although we only overwrote Symbolic IterDomains in this
step, we still asserted that we could infer an IterType for each I moved
that check so that it is only applied when we need to make a change.

Additionally, we previously propagated Broadcast-only IterDomains as
Symbolic, since we combine with our previous estimate using
promoteIterType. As mentioned in a comment, this means Broadcast gets
propagated as Symbolic. Instead we now only fall back to promoteIterType
when there are multiple input IterTypes to the IterDomain expression.

Fixes #798
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Internal assert for pad: Failed to concretize an output IterType for expression

2 participants