
Concretize all dynamic outputs of downstream expressions#420

Closed
jacobhinkle wants to merge 28 commits into main from fix_issue418

Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented May 30, 2023

Fixes #418.

Previously, when concretizing a dynamic Fusion, we find all statements between inputs and outputs and mutate every Val. In the case of Exprs with multiple outputs, if some outputs are not used then they would not be concretized. This was the case in #418 in which the count output from Welford was not used, but it must be concretized. This PR:

  • changes the concretization traversal to mutate every root Val as well as the output of every intermediate expression between inputs and outputs.
  • propagates not only IterType but also extent expressions from producer to consumer during concretization
  • introduces the fusion_ir_dynamic debug dump option, which can be used to view the fusion before concretization.
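The first change can be illustrated with a toy model (plain Python, not nvFuser code; `vals_to_concretize` and the tuple representation are illustrative only): every root Val plus all outputs of each expression are collected, even outputs with no further uses, such as Welford's count.

```python
# Toy model of the concretization traversal: each expression may have
# several outputs, and every output must be visited, used or not.

def vals_to_concretize(root_vals, exprs):
    """exprs: list of (inputs, outputs) pairs in topological order."""
    vals = set(root_vals)
    for _inputs, outputs in exprs:
        vals.update(outputs)  # unused outputs are included too
    return vals

# A Welford-like op produces avg, var, count; only avg is consumed.
exprs = [
    (["T0"], ["avg", "var", "count"]),
    (["avg"], ["T7"]),
]
print(sorted(vals_to_concretize(["T0"], exprs)))
# ['T0', 'T7', 'avg', 'count', 'var']
```

The key point is that `count` appears in the result even though no downstream expression consumes it.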

@jacobhinkle
Collaborator Author

jacobhinkle commented May 30, 2023

This seems to have fixed the immediate issue, but with this repro I am now hitting another problem where some input extents do not seem to be bound properly. This may be an entirely separate issue.

Concretized fusion:

Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ], float
  i5, int64_t
Outputs:
  T7_g[ iS53{i0}, iS54{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ], float
  T6_g[ iS57{i0}, iS58{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ], float

%kernel_math {
T8_l[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] )
T2_l[ iS47{i0}, iS48{( i2 / i5 )}, rS15{i5}, rS16{i3}, rS17{i4} ](Avg),
T3_l[ iS49{i0}, iS50{( i2 / i5 )}, rS20{i5}, rS21{i3}, rS22{i4} ](Var),                                                          
T4_l[ iS51{i0}, iS52{( i2 / i5 )}, rS25{i5}, rS26{i3}, rS27{i4} ](Count)
 = Welford ( T8_l[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ](Avg),
  allreduce = false )                                                                                                            
T7_g[ iS53{i0}, iS54{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ]
   = broadcast( T2_l[ iS47{i0}, iS48{( i2 / i5 )}, rS15{i5}, rS16{i3}, rS17{i4} ] )                                              
d16 = (double)(i5);                                                                                                              
d17 = double(1) * d16;                                                                                                           
d18 = (double)(i3);                                                                                                              
d19 = d17 * d18;                                                                                                                 
d20 = (double)(i4);                                                                                                              
d21 = d19 * d20;                                                                                                                 
d25 = reciprocal(d21);                                                                                                           
T5_l[ iS55{i0}, iS56{( i2 / i5 )} ]                                                                                                 
  = T3_l[ iS49{i0}, iS50{( i2 / i5 )}, rS20{i5}, rS21{i3}, rS22{i4} ]                                                              
    * d25;
T6_g[ iS57{i0}, iS58{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ]                                                                      
  = broadcast( T5_l[ iS55{i0}, iS56{( i2 / i5 )} ] )                                                                            
}                                                                                                                                 

winds up with this expression evaluator:

Evaluation context
--------------------
Precomputed Values:
i5 = 32
i4 = 28
i3 = 28
i76 = 32
i75 = 4
i0 = 256
--------------------

Here, i2 / i5 evaluates to 4, and the concretized reshape uses Split(4) (not Split(i2 / i5)), so the output shape is hardcoded as 4. However, downstream extents still use the value i2 / i5, so we do need to bind the value i2 = 128. I am currently looking into why this is not happening...
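The binding failure can be sketched with a toy evaluator (extents modeled as Python expressions; `eval_extent` is illustrative, not nvFuser's ExpressionEvaluator): because the concretized reshape recorded the constant 4 rather than i2 / i5, i2 never gets bound, yet downstream extents still reference it.

```python
# Toy sketch: evaluating an extent expression against a set of bindings.
def eval_extent(expr, bindings):
    try:
        return eval(expr, {}, dict(bindings))
    except NameError:
        return None  # an unbound symbol: the extent cannot be evaluated

# Bindings from the evaluation context above; note i2 is absent.
bindings = {"i5": 32, "i4": 28, "i3": 28, "i0": 256}
print(eval_extent("i2 // i5", bindings))  # None: i2 was never bound
bindings["i2"] = 128                      # binding i2 = 128 fixes it
print(eval_extent("i2 // i5", bindings))  # 4, matching the Split(4)
```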

@jacobhinkle
Collaborator Author

I understand the second problem better now. With the concretization fix, the above Fusion is segmented into a first segment that performs only the reshape, followed by a second segment that does all of the normalization:

group details:                                                                                                                   
g{(pointwise)                                                                                                                    
inputs:                                                                                                                          
T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] float                                                                                 
outputs:                                                                                                                         
T8_g[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ] float                                                
                                                                                                                                 
                                                                                                                                 
T8_g[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] ) (15)                                                                                                                             
}                                                                                                                                
                                                                                                                                 
g{(persistent)                                                                                                                   
inputs:                                                                                                                          
i3 int64_t                                                                                                                       
i4 int64_t                                                                                                                       
i5 int64_t                                                                                                                       
T8_g[ iS40{i0}, iS45{4}rf, iS46{( ceilDiv(i2, 4) )}rf, iS42{i3}, iS43{i4} ] float                                                
outputs:                                                                                                                         
T6_g[ iS57{i0}, iS58{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ] float                                                             
T7_g[ iS53{i0}, iS54{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ] float 
...

The input to the second segment is T8_g, which is the output of the reshape from the first segment. As mentioned above, the dynamic reshape is rewritten as a static reshape having a Split(4), so instead of i2 / i5, the constant 4 is inserted. That means that we never bind i2 or i2 / i5 in the FusionExecutor for the second segment.

I think we could address this by providing all input extents/scalars to every FusionExecutor in the segmented fusion.

@jacobhinkle
Collaborator Author

Pushed a change that propagates extent expressions producer-to-consumer (P2C) during concretization. This fixes cases like #418, since the "proper" extents from the static reshape get written downstream. Note that it may still be possible for those expressions to be used in another op (i.e. not as an ID extent), for example add(tv8, tv7->axis(1)->extent()), in which case that use of the old extent will not be replaced during concretization.
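The P2C propagation can be sketched as a toy replacement pass (plain strings standing in for nvFuser extent Vals; `propagate_extents` and the map literal are illustrative, not the PR's API): a map from old symbolic extents to their concretized expressions is applied to each consumer domain in topological order.

```python
# Toy sketch of producer-to-consumer extent propagation during
# concretization: rewrite any consumer extent found in the map.
# Note: an extent used as an op attribute rather than in a domain
# (like the add(tv8, tv7->axis(1)->extent()) case) never appears in
# these lists, so this pass would miss it.

def propagate_extents(consumer_domains, replacements):
    return [
        [replacements.get(ext, ext) for ext in domain]
        for domain in consumer_domains
    ]

# The static reshape's "proper" extent replaces i2 / i5 downstream.
downstream = [["i0", "i2 / i5"], ["i0", "i2 / i5", "1"]]
print(propagate_extents(downstream, {"i2 / i5": "ceilDiv(i2, 4)"}))
# [['i0', 'ceilDiv(i2, 4)'], ['i0', 'ceilDiv(i2, 4)', '1']]
```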

Note that there is a failing test, DynamicTransform5_CUDA, which is a dynamic reshape followed by a pad. It fails because of the earlier change in this PR: in concretize() we get a vector of Statements using StmtSort and then iterate over the vector, mutating each one. For some reason this "double mutates" the output of the reshape. I am currently looking for a better way to do this.

@jacobhinkle
Collaborator Author

jacobhinkle commented Jun 1, 2023

The current state of this PR rewrites the concretized reshape so that its rfactor extents match those provided to the dynamic reshape op. This introduced some complication and is failing when creating new TensorDomains during lowering, due to the validateDomainEquivalence check.

I just tried an alternative: rolling this branch back to d675161 and cherry-picking bf9aa1e. In that approach, instead of making the reshape rfactor match the dynamic values, we instead propagate the ceilDiv style expressions forward during concretization. With this change, plus fixing the bug with how we find Vals to mutate, all tests pass. I will push these changes soon.

@jacobhinkle
Collaborator Author

!build

@jacobhinkle jacobhinkle marked this pull request as ready for review June 1, 2023 19:53
@jacobhinkle jacobhinkle requested a review from naoyam June 1, 2023 19:53
@jacobhinkle
Collaborator Author

Currently, there are two test failures caused by the same subtle issue. When we create a new TensorDomain from inputs that have Symbolic IDs in only some positions, we create a Symbolic ID in the new root domain at each such position. Its extent is taken from the inputs if their extent expressions all match. In the case of cat, although we could prove they are equal, we do not currently try to prove that: we just check with sameAs. If the extents mismatch, we create a new placeholder Int for the output extent. The reasoning is that for many operations we don't know which input expression will match the output extent, since some inputs may be concretized to broadcasts. In this PR we also propagate extent expressions when we concretize symbolic extents, which tends to replace these placeholders with something that can be evaluated properly.

However, the error arises when an operation on a Symbolic ID like ?S11{i14} creates a non-Symbolic IterDomain using that placeholder extent. For example, if we reduce over that dimension, the output will be rS13{i14}, which is currently not updated during concretization since it is non-Symbolic. The solution is to build a map of extent replacements and mutate IDs whenever a replaced extent is present, regardless of whether the ID is Symbolic. I will push a fix for this soon. In the meantime, this PR should work for GroupNorm with dynamic input sizes.
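The proposed fix can be sketched as follows (IterDomains modeled as (iter-type, extent) tuples; `concretize_domains` is illustrative, not the PR's API): the extent-replacement map is applied to every IterDomain whose extent was replaced, whether or not the domain itself is Symbolic.

```python
# Toy sketch: replace extents on ALL domains, not only Symbolic ones.
def concretize_domains(domains, extent_replacements):
    return [
        (iter_type, extent_replacements.get(extent, extent))
        for iter_type, extent in domains
    ]

# A Symbolic domain ?S11{i14} and a Reduction domain rS13{i14} share the
# placeholder extent i14: both must pick up the replacement.
domains = [("Symbolic", "i14"), ("Reduction", "i14")]
print(concretize_domains(domains, {"i14": "i2 / i5"}))
# [('Symbolic', 'i2 / i5'), ('Reduction', 'i2 / i5')]
```

Filtering on the iter-type here (the previous behavior) is exactly what left rS13{i14} holding a stale placeholder extent.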

@jacobhinkle
Collaborator Author

!build

Comment on lines +635 to +648
// NOTE: ops::newOutputTV would not necessarily be able to infer that the
// padded dimensions are all of the same size. However, we know that they are
// constructed such that that is the case, so we can use
auto out_domain = ops::newOutputDomain(resized_inputs);
// Override the concatenated dimension and insert an IterDomain with the true
// extent, if needed
if (!out_domain.at(cat_dim)->extent()->sameAs(concat_ext)) {
  out_domain[cat_dim] =
      IterDomainBuilder(out_domain.at(cat_dim)).extent(concat_ext).build();
}
auto out = IrBuilder::create<TensorView>(
    IrBuilder::create<TensorDomain>(
        out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)),
    dtype);
Collaborator Author

This fixes a breakage that occurred when I made the other changes. It makes the output extent of cat look as it should (e.g. (i0 + i2) + i4), instead of creating a new symbolic extent as was done previously, which complicated concretization.

jacobhinkle added a commit that referenced this pull request Jun 22, 2023
This is in lieu of replacing all uses of symbolic extents, which cannot
be done reliably since they might appear as attributes or members of
objects which are untracked. See #420
@jacobhinkle jacobhinkle marked this pull request as draft July 10, 2023 19:15
jacobhinkle added a commit that referenced this pull request Jul 19, 2023
This change ensures that concretization visits all outputs to active
expressions. Even though unused outputs do not directly affect the
result of the Fusion, they still must not have Symbolic IterDomains
during scheduling. Note that this pre-empts the more complicated #420.

This fixes the immediate cause of #418. However, that test still fails
due to another issue where scalars get lost during segmentation. The
test included here only tests that concretization works properly;
scheduling and execution will be tested in another PR.
@jacobhinkle
Collaborator Author

Closing in favor of #591

@zasdfgbnm zasdfgbnm deleted the fix_issue418 branch July 20, 2023 14:37


Development

Successfully merging this pull request may close these issues.

ops.reshape errors with !fusion->hasDynamicTransform()
