Skip to content

Handle constant -1 values in dynamic reshape at definition#590

Merged
jacobhinkle merged 4 commits intomainfrom
fix_issue249
Jul 14, 2023
Merged

Handle constant -1 values in dynamic reshape at definition#590
jacobhinkle merged 4 commits intomainfrom
fix_issue249

Conversation

@jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Jul 14, 2023

This replaces scalar new sizes that evaluate to -1 at definition time in the reshape() op with a simplified expression for the new output extent. Previous to this change, the output shape would be set to the Val* that was provided, which evaluates to -1; after which it is difficult to reliably change even at concretization, since that scalar might have legitimate other uses.

Before this change:

auto tv0 = makeSymbolicTensor(4); // [ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ]
auto tv1 = reshape({tv0->axis(0)->extent(), tv0->axis(2)->extent(), IrBuilder::create<Scalar>(-1)});
// [ ?S4{i0}, ?S5{i3}, ?S6{-1} ]

After this change:

auto tv0 = makeSymbolicTensor(4); // [ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ]
auto tv1 = reshape({tv0->axis(0)->extent(), tv0->axis(2)->extent(), IrBuilder::create<Scalar>(-1)});
// [ iS4{i0}, iS5{i3}, ?S6{( i2 * i4 )} ]

In this PR we also introduce a check that only allows values of -1 when that value is known at compile time; for example passing an input scalar of -1 or an expression that evaluates to -1 to a dynamic reshape will result in an error at concretization.

Fixes #249

} else {
TORCH_INTERNAL_ASSERT(
extent_int > 0, "Invalid output domain extent: ", extent_int);
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These checks are handled by reshape() and analyzeView() and are not needed here.

other_new_numel = mul(other_new_numel, new_sizes.at(j));
}
new_size = div(numel, other_new_numel);
new_size = simplifyExpr(new_size);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that SimplifyingIrBuilder evaluates constant subexpressions but does not simplify (x * y) / y, so I used simplifyExpr here.

Comment on lines +140 to +142
for (const auto j : c10::irange(inp_dom.size())) {
numel = mul(numel, inp_dom.at(j)->extent());
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we put this in a new method Scalar TensorView::numel() const instead?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in loewr_utils? I'd prefer minimalistic designs for the IR nodes themselves.

@jacobhinkle jacobhinkle requested a review from naoyam July 14, 2023 15:51
@jacobhinkle jacobhinkle marked this pull request as ready for review July 14, 2023 15:52
@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle jacobhinkle changed the title Handle constInt -1 values in reshape at definition Handle constant -1 values in reshape at definition Jul 14, 2023
@jacobhinkle jacobhinkle changed the title Handle constant -1 values in reshape at definition Handle constant -1 values in dynamic reshape at definition Jul 14, 2023
Comment on lines +1136 to +1140
// Dynamic reshape sizes that are not constant at definition must be explicit:
// no -1 allowed
EXPECT_THROW(
fusion_executor_cache.runFusionWithInputs({at_x, 2, 4, -1}),
std::exception);
Copy link
Collaborator Author

@jacobhinkle jacobhinkle Jul 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Input scalars of -1 are not supported for dynamic reshape currently. Supporting this is not impossible, but would increase the complexity quite a bit. At first glance, it seems like we could just replace any Val that evaluates to -1 in a reshape output with the proper shape. However, that is not safe since we currently replace all uses of concretized scalars and we would run the risk of erroneously replacing intended values of -1 with the positive integer. Instead, it's safer to just disallow this explicitly and require -1 to be a constant at definition.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@jacobhinkle jacobhinkle merged commit ef13b21 into main Jul 14, 2023
@jacobhinkle jacobhinkle deleted the fix_issue249 branch July 14, 2023 20:56
jacobhinkle added a commit that referenced this pull request Jul 17, 2023
I did not commit the test change here before merging #590. This updates
the test to reflect the new behavior, which is that reshape sizes of -1
must be constant at definition time: for dynamic reshapes with input
scalars they must not be -1.
jacobhinkle added a commit that referenced this pull request Jul 17, 2023
I did not commit the test change here before merging #590. This updates
the test to reflect the new behavior, which is that reshape sizes of -1
must be constant at definition time: for dynamic reshapes with input
scalars they must not be -1.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support symbolic -1 output dimensions in reshape

2 participants