-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Fix] Enhance floormod simplification rules for better expression matching #17765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
|
||
| TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); | ||
| TVM_TRY_REWRITE_IF(matches_one_of(floormod(x * c1, x * c2), floormod(c1 * x, c2 * x)), | ||
| floormod(c1, c2), c2.Eval()->value != 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit surprised to see if we need c * x case since canonical simplify will move most cases to x * c, would be good to understand why original flow fails in this case
|
Thanks for the PR, would be good to learn about your particular usecase(e.g. how did you find out about the case). How did you construct the expression/model, what if the code is changed explicitly to x * c pattern? |
|
Sorry for lately reply. from tvm import tir
from tvm.arith import Analyzer
from tvm.tir.op import floormod
# Define symbolic variable
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
# Create expressions with common factor
expr1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 64)
divisor1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 32)
# Create Analyzer
analyzer = Analyzer()
# The Logic from CanProveDivisible().
print(analyzer.can_prove_equal(expr1, divisor1) or analyzer.can_prove(floormod(expr1, divisor1) == 0))
# Expected: True, but actual: False
# Main reason is the following simplification.
print(analyzer.rewrite_simplify(floormod(expr1, divisor1)))
# Expected output: 0
# Actual output:
# T.int64(64) * (past_decoder_sequence_length + T.int64(1)) %
# (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))Detailed rewrite_simplify process(up to 2 iterations): Caused by tvm/src/arith/rewrite_simplify.cc Line 449 in f4704f2
it will cause the following case: I'm not sure if this is the correct simplification , but it seems to be a bug for x * c pattern of floormod. |
|
Thanks @Ghosts381937 . A few things to note, we indeed should simplify to make mul coefficient in rhs by default which is a convention in the code base. You made a right observation about the rule Based on your observation, #18031 should fix the issue as well as the testcase you raised. Thank you for being careful and digging into the issue, arith module is something that we need to carefully maintain so this really helps |
|
superseded by #18031 |
|
Thank you for help me fix the issue!
|
Description
Update the floormod simplification rule to correctly handle expressions of the form floormod(c1*x, c2*x) by simplifying them to floormod(c1, c2). This enhancement enables better optimization of expressions that contain common factors, which frequently appear in transformer model computations.
Test Case