
Propagate modified reshape extents at concretization#630

Merged
jacobhinkle merged 11 commits into main from seg_lost_scalars
Jul 21, 2023
Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Jul 20, 2023

This PR propagates extent scalars after concretizing reshape. Previously the following fusion

  auto tv0 = makeSymbolicTensor(4);
  fusion->addInput(tv0);
  auto s0 = IrBuilder::create<Val>(DataType::Int);
  fusion->addInput(s0);

  auto sh = tensor_sizes(tv0);
  auto tv1 = reshape(tv0, {sh[0], div(sh[1], s0), s0, sh[2], sh[3]});
  // Reducing along axis 2 in tv1 is equivalent to a partial reduction across
  // axis 1 of tv0.
  auto vm = variance_mean(tv1, {2, 3, 4}, 0, true);
  fusion->addOutput(vm.mean);
  fusion->addOutput(vm.var);

would be concretized like this:

Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ], float
  i5, int64_t
Outputs:
  T7_g[ iS49{i0}, iS50{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ], float
  T6_g[ iS55{i0}, iS56{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ], float

%kernel_math {
T8_l[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] )
T2_l[ iS47{i0}, iS48{( i2 / i5 )}, rS15{i5}, rS16{i3}, rS17{i4} ](Avg),
T3_l[ iS51{i0}, iS52{( i2 / i5 )}, rS20{i5}, rS21{i3}, rS22{i4} ](Var),
T4_l[ iS57{i0}, iS58{( i2 / i5 )}, rS25{i5}, rS26{i3}, rS27{i4} ](Count)
 = Welford ( T8_l[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ](Avg),
  allreduce = false )
T7_g[ iS49{i0}, iS50{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ]
   = broadcast( T2_l[ iS47{i0}, iS48{( i2 / i5 )}, rS15{i5}, rS16{i3}, rS17{i4} ] )
d17 = (double)(i5);
d19 = double(1) * d17;
d21 = (double)(i3);
d23 = d19 * d21;
d25 = (double)(i4);
d27 = d23 * d25;
d33 = reciprocal(d27);
T5_l[ iS53{i0}, iS54{( i2 / i5 )} ]
   = T3_l[ iS51{i0}, iS52{( i2 / i5 )}, rS20{i5}, rS21{i3}, rS22{i4} ]
   * d33;
T6_g[ iS55{i0}, iS56{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ]
   = broadcast( T5_l[ iS53{i0}, iS54{( i2 / i5 )} ] )
}

Now this concretizes as

Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ], float
  i5, int64_t
Outputs:
  T7_g[ iS52{i0}, iS53{4}, bS37{1}, bS38{1}, bS39{1} ], float
  T6_g[ iS62{i0}, iS63{4}, bS32{1}, bS33{1}, bS34{1} ], float

%kernel_math {
T8_l[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] )
T2_l[ iS49{i0}, iS50{4}, rS48{( ceilDiv(i2, 4) )}, rS16{i3}, rS17{i4} ](Avg),
T3_l[ iS56{i0}, iS57{4}, rS55{( ceilDiv(i2, 4) )}, rS21{i3}, rS22{i4} ](Var),
T4_l[ iS66{i0}, iS67{4}, rS65{( ceilDiv(i2, 4) )}, rS26{i3}, rS27{i4} ](Count)
 = Welford ( T8_l[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ](Avg),
  allreduce = false )
T7_g[ iS52{i0}, iS53{4}, bS37{1}, bS38{1}, bS39{1} ]
   = broadcast( T2_l[ iS49{i0}, iS50{4}, rS48{( ceilDiv(i2, 4) )}, rS16{i3}, rS17{i4} ] )
d17 = (double)(i5);
d19 = double(1) * d17;
d21 = (double)(i3);
d23 = d19 * d21;
d25 = (double)(i4);
d27 = d23 * d25;
d33 = reciprocal(d27);
T5_l[ iS59{i0}, iS60{4} ]
   = T3_l[ iS56{i0}, iS57{4}, rS55{( ceilDiv(i2, 4) )}, rS21{i3}, rS22{i4} ]
   * d33;
T6_g[ iS62{i0}, iS63{4}, bS32{1}, bS33{1}, bS34{1} ]
   = broadcast( T5_l[ iS59{i0}, iS60{4} ] )
}
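The comment in the original fusion claims that reducing axes {2, 3, 4} of the reshaped tv1 is equivalent to a partial reduction across axis 1 (plus axes 2 and 3) of tv0. That equivalence can be checked numerically with a small pure-Python sketch; this is an illustration only, not nvFuser code, and the shapes and values below are chosen arbitrarily:

```python
# Input tensor tv0 with shape (d0, d1, d2, d3), stored flat (row-major).
d0, d1, d2, d3 = 2, 6, 3, 4
s = 2                      # runtime scalar s0; must evenly divide d1
data = [float(i % 7) for i in range(d0 * d1 * d2 * d3)]

def at(n0, n1, n2, n3):
    return data[((n0 * d1 + n1) * d2 + n2) * d3 + n3]

# Path 1: reshape to (d0, d1 // s, s, d2, d3), then mean over axes {2, 3, 4}.
# Because the reshape is a pure split of axis 1, element (n0, g, j, n2, n3)
# of tv1 is element (n0, g * s + j, n2, n3) of tv0.
mean_via_reshape = [
    [sum(at(n0, g * s + j, n2, n3)
         for j in range(s) for n2 in range(d2) for n3 in range(d3))
     / (s * d2 * d3)
     for g in range(d1 // s)]
    for n0 in range(d0)
]

# Path 2: reduce tv0 directly, but only over a window of s rows of axis 1
# (plus all of axes 2 and 3) -- the "partial reduction across axis 1".
mean_partial = [
    [sum(at(n0, n1, n2, n3)
         for n1 in range(g * s, (g + 1) * s)
         for n2 in range(d2) for n3 in range(d3))
     / (s * d2 * d3)
     for g in range(d1 // s)]
    for n0 in range(d0)
]

assert mean_via_reshape == mean_partial
```

The two paths even sum the elements in the same order, so the floating-point results match exactly, not just approximately.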

This helps ensure that no scalars get lost during segmentation, which could previously occur if the reshape output became a segmentation edge (see #629).

I also took this opportunity to remove TensorView::convertRfactorToRootDomain(). It was only used at segmentation, and it did a separate traversal for each segment input TV. Now we do a single traversal that replaces the extents of all input TVs at once.

Fixes #629 and fixes #418.

This allows us to just do one traversal for the whole segment to replace
vals instead of one traversal per input TV.
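The single-traversal idea can be sketched with a toy expression tree. These are hypothetical structures, not the nvFuser API: the point is that merging every input TV's replacement map up front lets one walk do the work of many.

```python
# Toy illustration (hypothetical structures, not the nvFuser API): replacing
# extents one input TV at a time costs one tree walk per TV; merging every
# TV's {old_extent: new_extent} map first needs only a single walk.
def replace_all(node, mapping):
    """Recursively rebuild a nested-tuple expression, applying `mapping`."""
    if node in mapping:
        return mapping[node]
    if isinstance(node, tuple):          # interior node: (op, *operands)
        return tuple(replace_all(child, mapping) for child in node)
    return node                          # leaf that is not being replaced

# Extents referenced by two input TVs, each with its own replacement map.
segment = ("fusion", ("mul", "i2", "i5"), ("add", "i3", "i4"))
per_tv_maps = [{"i2": 8}, {"i3": 5, "i4": 7}]

# One traversal per input TV (the old convertRfactorToRootDomain approach):
result_multi = segment
for m in per_tv_maps:
    result_multi = replace_all(result_multi, m)

# One traversal total, using the merged map:
merged = {k: v for m in per_tv_maps for k, v in m.items()}
result_single = replace_all(segment, merged)

assert result_multi == result_single == ("fusion", ("mul", 8, "i5"), ("add", 5, 7))
```

The merge is safe here because each TV's map touches disjoint keys; the savings grow with the number of segment inputs, since the tree is walked once instead of once per input.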
@jacobhinkle jacobhinkle changed the title from "Seg lost scalars" to "Propagate modified reshape extents at concretization" Jul 20, 2023
@jacobhinkle
Collaborator Author

Note that an alternative approach would be to reset the extents of the concretized reshaped TV to the original expressions. I went with this approach since it more closely resembles the extents we would see using static reshapes.

Comment on lines -260 to -262
TORCH_INTERNAL_ASSERT(
!replacement_extents.empty() &&
getMaybeRFactorDomain().size() == replacement_extents.size());
Collaborator Author

I omitted the check that replacement_extents is not empty since it did not seem necessary.

@jacobhinkle jacobhinkle marked this pull request as ready for review July 20, 2023 14:50
@jacobhinkle jacobhinkle marked this pull request as draft July 20, 2023 15:06
Comment on lines +457 to +458
/*traverse_members*/ true,
/*traverse_attributes*/ true,
Collaborator Author

Since we propagate scalars, we should do this anywhere they might occur; namely in members and attributes. This was not necessary when we were only replacing IterTypes since IterDomains are traversed regardless of these settings.

Collaborator

Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?

@jacobhinkle jacobhinkle marked this pull request as ready for review July 20, 2023 19:32
return toString(indent_size);
}

void TensorView::convertRfactorToRootDomain() {
Collaborator

Is this change completely unrelated to the propagation of concretized vals?

Collaborator Author

Yes. At first I thought it was going to be necessary. I realized it was not necessary for this PR, but it is a little faster and reduces the TensorView interface so I figured I would keep it. I can split it into another PR if you prefer.

@jacobhinkle
Collaborator Author

!build

Collaborator

@naoyam naoyam left a comment

The PR looks good to me, but I have one question on the rfactor-root replacement


for (const auto& id : rfactor) {
if (id->isRFactorProduct()) {
// Create new symbolic extents for rfactor iterDomains
auto domain_extent = (!tv_is_concrete)
Collaborator

I think this function is fine as it's mostly just copied from tensor_view.cpp. However, I don't know why we use all symbolic extents when tv_is_concrete is false. It seems to mean that if any IterDomain has a symbolic extent, all IterDomains get symbolic extents.

This is not about this PR itself, but does it make sense?

Collaborator Author

Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?

This is an important point, since it could lead to us concretizing one aspect of a Val but losing a previously concretized aspect (for example, losing the concretized extent when we later concretize the IterType). But I think we are OK if we are careful to use maybeMutated when creating replacements. In one case we already do multiple mutations: we call OptOutMutator::mutate(id) at the beginning of our mutate(IterDomain* id) override.
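The role of maybeMutated here can be sketched with a toy mutator. The names below are a hypothetical Python analogue of nvFuser's OptOutMutator registry, not its actual API; the point is that a second mutation must start from the already-mutated object, or the first mutation is silently dropped.

```python
from dataclasses import dataclass, replace

# Toy analogue of an IterDomain with two independently-concretizable aspects.
@dataclass(frozen=True)
class ToyIterDomain:
    extent: object     # e.g. a symbolic extent expression or a concrete int
    iter_type: str     # e.g. "Symbolic" or "Iteration"

registry = {}                          # old object -> latest replacement

def maybe_mutated(x):
    """Return the latest registered replacement for x, or x itself."""
    return registry.get(x, x)

def register_mutation(old, new):
    registry[old] = new

id0 = ToyIterDomain(extent="i2/i5", iter_type="Symbolic")

# First pass concretizes the extent.
register_mutation(id0, replace(id0, extent=4))

# Second pass concretizes the IterType. Building the replacement from
# maybe_mutated(id0) keeps the already-concretized extent; building it
# from id0 directly would discard it.
base = maybe_mutated(id0)
register_mutation(id0, replace(base, iter_type="Iteration"))

assert maybe_mutated(id0) == ToyIterDomain(extent=4, iter_type="Iteration")
```

Had the second call used replace(id0, iter_type="Iteration"), the final lookup would yield extent="i2/i5" again, which is exactly the "losing a previously concretized aspect" failure described above.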

Collaborator Author

It seems this means that if there's any IterDomain with a symbolic extent, all IterDomains would have symbolic extents.

Yeah, that is strange. All IterDomains in the rfactor domain are checked to determine concreteness, but then the condition is only applied to rfactor products. The purpose of this function is to make the root a "standalone" domain so that we can bind input shapes to it, so it seems like all extents, regardless of whether they are rfactor products, should be either constant ints or purely symbolic.

We always need to create a new IterDomain for rfactor products in order to cut the connection to the root domain. It seems like a simpler condition would be

auto domain_extent = id->extent()->isConstScalar()
    ? id->extent()
    : IrBuilder::create<Val>(DataType::Int);

And in light of the above, we could move this outside of the id->isRFactorProduct() check, so that we would replace IDs that are not rfactor products if they have non-constant derived extents.

Collaborator

Please create a separate issue for this. Let's merge this PR as is.

Collaborator Author

Sounds good. Merging without that change. I will do that in another PR.

@jacobhinkle jacobhinkle merged commit d25e8c5 into main Jul 21, 2023
@jacobhinkle jacobhinkle deleted the seg_lost_scalars branch July 21, 2023 22:18
@jacobhinkle jacobhinkle mentioned this pull request Sep 26, 2025

Development

Successfully merging this pull request may close these issues.

Extent root scalars can get lost during segmentation
ops.reshape errors with !fusion->hasDynamicTransform()
