
Perform cancellation in SimplifyingIrBuilder::addExpr #2020

Open
jacobhinkle wants to merge 3 commits into main from addExpr_cancellation

Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Apr 2, 2024

This change modifies SimplifyingIrBuilder::addExpr to cancel negative terms with positive ones. For example, the following expressions will all be simplified to x + y or y + x:

  • x - (-y)
  • (x + z) - ((-y) + z)
  • (x + y) + (a + (b + (-(b + a))))

This is a more powerful alternative to #2017, but it does introduce more complexity in SimplifyingIrBuilder.

Note that if cancellation is performed, terms may be reordered relative to the order given. However, if no cancellation is performed, IrBuilder::addExpr(lhs, rhs) is returned as-is, so no reordering takes place.

TODO: tests

@jacobhinkle jacobhinkle requested a review from naoyam April 2, 2024 14:23
@jacobhinkle
Collaborator Author

!build --diff --diff-bench

@jacobhinkle jacobhinkle marked this pull request as ready for review April 2, 2024 15:18
std::vector<bool> cancelled_pos(pos_terms.size(), false);
std::vector<bool> cancelled_neg(neg_terms.size(), false);
bool performed_cancellation = false;
for (size_t j : c10::irange(neg_terms.size())) {
@jacobhinkle
Collaborator Author
Iterate over negative terms first since they are likely rarer. If there are no negative terms, we will quickly see that no cancellation is possible and return the simple sum.

@naoyam
Collaborator

naoyam commented Apr 2, 2024

I wonder if this level of simplification should be done by the expr simplifier.

Some historical context: SimplifyingIrBuilder existed before the expr simplifier. We could replace SimplifyingIrBuilder entirely with the expr simplifier, as the latter is more powerful. However, it's also much more costly than the simple, quick simplification done by SimplifyingIrBuilder. We can use SimplifyingIrBuilder without worrying about added overhead, but that's not the case with the expr simplifier.

So, to me, SimplifyingIrBuilder is for quick low-hanging fruits only and thus can be used without worrying about overheads, whereas the expr simplifier is a heavy-weight process that's much more comprehensive, so it should be used selectively.

Any thoughts? @zasdfgbnm

@zasdfgbnm
Collaborator

So, to me, SimplifyingIrBuilder is for quick low-hanging fruits only and thus can be used without worrying about overheads, whereas the expr simplifier is a heavy-weight process that's much more comprehensive, so it should be used selectively.

I agree with this statement. However, it is not clear to me where the boundary should be. What counts as "low-hanging fruit" and what counts as "heavy-weight"? We need a clear definition here.

An alternative idea: I think the idea of "optimization fuel" that I learnt from @wujingyue's talk is really interesting. Should we just totally kill SimplifyingIrBuilder and add a double fuel argument to simplifyExpr guarding how much effort we should spend on simplification? simplifyExpr(..., fuel=0.0) means no simplification at all, simplifyExpr(..., fuel=inf) means try the best to simplify. We can use a small fuel value if we don't want to worry about overheads.

@naoyam
Collaborator

naoyam commented Apr 2, 2024

I agree with this statement. However, it is not clear to me where the boundary should be. What counts as "low-hanging fruit" and what counts as "heavy-weight"? We need a clear definition here.

Maybe, no recursive simplification?

An alternative idea: I think the idea of "optimization fuel" that I learnt from @wujingyue's talk is really interesting. Should we just totally kill SimplifyingIrBuilder and add a double fuel argument to simplifyExpr guarding how much effort we should spend on simplification? simplifyExpr(..., fuel=0.0) means no simplification at all, simplifyExpr(..., fuel=inf) means try the best to simplify. We can use a small fuel value if we don't want to worry about overheads.

Or super fast expr simplifier? 😉

@zasdfgbnm
Collaborator

Maybe, no recursive simplification?

Sounds like a good boundary.

@jacobhinkle
Collaborator Author

Yeah, I was thinking the boundary is something like "non-recursive, and not requiring inequality or divisibility proofs", since those are the heavy-weight parts of simplifyExpr. Great discussion topic, though.

@zasdfgbnm
Collaborator

In order to use simplifyExpr and get some useful simplifications, I think the minimum price we have to pay is assoc_comm::flatten + eliminateTrivialComputation, which is slightly more expensive than this PR.

@jacobhinkle
Collaborator Author

If simplifyExpr could cache its simplifications, at least for a given set of assumptions, then it could be used without assumptions in a practically non-recursive fashion and should be pretty fast. We could keep a thread-local Context that has no assumptions and use simplifyExpr(IrBuilder::addExpr(lhs, rhs)) in place of SimplifyingIrBuilder::addExpr(lhs, rhs). Then we could cache simplified expressions so that recurseDown stops when it recognizes an already-simplified expression. That's similar to what @zasdfgbnm did for inequality proofs in #1972 .

@zasdfgbnm
Collaborator

Did a quick experiment in #2022, which uses simplifyExpr (enabling only eliminateTrivialComputation, disabling all proofs, with the cache mentioned by @jacobhinkle enabled) to replace SimplifyingIrBuilder. The time for GPUTTensorCoreTest.FusionAmpereMatmul_CUDA goes up from 4012 ms to 10686 ms if I just mindlessly replace all SimplifyingIrBuilder calls with simplifyExpr. If I am smarter about it, using IrBuilder instead of SimplifyingIrBuilder wherever I know the expr will eventually be simplified anyway, I can bring the test time back to 4482 ms, still a noticeable 10% slowdown compared with the previous SimplifyingIrBuilder.

@jacobhinkle
Collaborator Author

On my workstation, I see FusionAmpereMatmul_CUDA go from 3460 ms on main (3686a96) to 3618 ms (+4.5%) with this PR. But there's a lot of noise in this measurement; there doesn't seem to be a consistently large perf hit due to this PR.

withPR2020.out                                                                                                             
63:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmul_CUDA (3618 ms)                 
69:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulBFloat16_CUDA (3416 ms)         
71:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulPipelineGmem_CUDA (4554 ms)   
75:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulRegDoubleBuffer_CUDA (6883 ms) 
83:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTNcpAsync_CUDA (1327 ms)    
89:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTNSwizzled_CUDA (1643 ms)                       
91:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulLargeLoad_CUDA (3492 ms)                                              
95:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTileCheck4warp_CUDA (47156 ms)                                        
97:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTileCheck8warp_CUDA (36610 ms)
99:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTileCheck6warp_CUDA (10190 ms)
101:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulLargeLoadLargeK_CUDA (3134 ms)
105:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpilogue_CUDA (4495 ms)   
130:[  FAILED  ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpiloguePromotionRequiredA100_CUDA (620 ms)                      
132:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpilogueCast_CUDA (4296 ms)                                      
134:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpilogueRelu_CUDA (4354 ms)
136:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSplitK_CUDA (5968 ms)        
138:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSplitKBias_CUDA (6993 ms)    
140:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulBatchSplitK_CUDA (6754 ms)   
142:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulBatchSplitKBias_CUDA (8305 ms)
                                                                                                                           
main.out                                                                                                                   
63:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmul_CUDA (3460 ms)                 
69:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulBFloat16_CUDA (3441 ms)         
71:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulPipelineGmem_CUDA (4581 ms)  
75:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulRegDoubleBuffer_CUDA (6974 ms) 
83:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTNcpAsync_CUDA (1285 ms)    
89:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTNSwizzled_CUDA (1618 ms)                       
91:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulLargeLoad_CUDA (3127 ms)                                                           
95:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTileCheck4warp_CUDA (48088 ms)                                                     
97:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTileCheck8warp_CUDA (35263 ms)                                                     
99:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulTileCheck6warp_CUDA (9878 ms)                                                      
101:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulLargeLoadLargeK_CUDA (3637 ms)                                                    
105:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpilogue_CUDA (4866 ms)                                                       
130:[  FAILED  ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpiloguePromotionRequiredA100_CUDA (673 ms)                                                                     
132:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpilogueCast_CUDA (4862 ms)                                                                                     
134:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSmemEpilogueRelu_CUDA (4290 ms)                                                                                     
136:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSplitK_CUDA (6633 ms)                                                                                               
138:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulSplitKBias_CUDA (7207 ms)                                                                                           
140:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulBatchSplitK_CUDA (6819 ms)                                                                                          
142:[       OK ] GPUTTensorCoreTest.FusionAmpereMatmulBatchSplitKBias_CUDA (8366 ms)

@jacobhinkle
Collaborator Author

I'm going to merge main now that @naoyam has merged #2017 then check codediff to see if there is any value in adding this now.

@jacobhinkle
Collaborator Author

!build --diff

@jacobhinkle
Collaborator Author

jacobhinkle commented Apr 3, 2024

Diff used old commit on main. Retrying.

@jacobhinkle
Collaborator Author

!build --diff
