Handle constant -1 values in dynamic reshape at definition#590
Handle constant -1 values in dynamic reshape at definition#590jacobhinkle merged 4 commits intomainfrom
Conversation
| } else { | ||
| TORCH_INTERNAL_ASSERT( | ||
| extent_int > 0, "Invalid output domain extent: ", extent_int); | ||
| } |
There was a problem hiding this comment.
These checks are handled by reshape() and analyzeView() and are not needed here.
| other_new_numel = mul(other_new_numel, new_sizes.at(j)); | ||
| } | ||
| new_size = div(numel, other_new_numel); | ||
| new_size = simplifyExpr(new_size); |
There was a problem hiding this comment.
Note that SimplifyingIrBuilder evaluates constant subexpressions but does not simplify (x * y) / y, so I used simplifyExpr here.
| for (const auto j : c10::irange(inp_dom.size())) { | ||
| numel = mul(numel, inp_dom.at(j)->extent()); | ||
| } |
There was a problem hiding this comment.
Should we put this in a new method Scalar TensorView::numel() const instead?
There was a problem hiding this comment.
Maybe in loewr_utils? I'd prefer minimalistic designs for the IR nodes themselves.
|
!build |
| // Dynamic reshape sizes that are not constant at definition must be explicit: | ||
| // no -1 allowed | ||
| EXPECT_THROW( | ||
| fusion_executor_cache.runFusionWithInputs({at_x, 2, 4, -1}), | ||
| std::exception); |
There was a problem hiding this comment.
Input scalars of -1 are not supported for dynamic reshape currently. Supporting this is not impossible, but would increase the complexity quite a bit. At first glance, it seems like we could just replace any Val that evaluates to -1 in a reshape output with the proper shape. However, that is not safe since we currently replace all uses of concretized scalars and we would run the risk of erroneously replacing intended values of -1 with the positive integer. Instead, it's safer to just disallow this explicitly and require -1 to be a constant at definition.
I did not commit the test change here before merging #590. This updates the test to reflect the new behavior, which is that reshape sizes of -1 must be constant at definition time: for dynamic reshapes with input scalars they must not be -1.
I did not commit the test change here before merging #590. This updates the test to reflect the new behavior, which is that reshape sizes of -1 must be constant at definition time: for dynamic reshapes with input scalars they must not be -1.
This replaces scalar new sizes that evaluate to
-1at definition time in thereshape()op with a simplified expression for the new output extent. Previous to this change, the output shape would be set to theVal*that was provided, which evaluates to-1; after which it is difficult to reliably change even at concretization, since that scalar might have legitimate other uses.Before this change:
After this change:
In this PR we also introduce a check that only allows values of -1 when that value is known at compile time; for example passing an input scalar of -1 or an expression that evaluates to -1 to a dynamic reshape will result in an error at concretization.
Fixes #249