@Ghosts381937
Contributor

Description

Update the floormod simplification rule to correctly handle expressions of the form floormod(c1*x, c2*x) by simplifying them to floormod(c1, c2). This enhancement enables better optimization of expressions that contain common factors, which frequently appear in transformer model computations.

Test Case

from tvm import tir
from tvm.arith import Analyzer
from tvm.tir.op import floormod

# Define symbolic variable
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")

# Create expressions with common factor
expr = tir.IntImm("int64", 64) * (past_decoder_sequence_length + tir.IntImm("int64", 1))
divisor = tir.IntImm("int64", 31) * (past_decoder_sequence_length + tir.IntImm("int64", 1))

# Create Analyzer
analyzer = Analyzer()

# Before: returns unsimplified expression
# After: correctly simplifies to 2
print(analyzer.simplify(floormod(expr, divisor)))
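As a sanity check on the arithmetic behind this rule, here is a minimal sketch on plain Python integers, whose `%` operator is floor modulo. It does not call TVM at all; the helper names `floormod_scaled` and `check_identity` are made up for illustration. On these values the remainder of `floormod(c1 * x, c2 * x)` equals `x * floormod(c1, c2)`, so it is 2 only when `x == 1`.

```python
def floormod_scaled(c1: int, c2: int, x: int) -> int:
    """floormod(c1 * x, c2 * x) on plain integers; Python's % is floor modulo."""
    return (c1 * x) % (c2 * x)


def check_identity(c1: int, c2: int, xs) -> bool:
    """True if floormod(c1*x, c2*x) == x * floormod(c1, c2) for every x in xs."""
    return all(floormod_scaled(c1, c2, x) == x * (c1 % c2) for x in xs)


# Values from the test case: c1 = 64, c2 = 31, x = n + 1 for a few non-negative n.
xs = [n + 1 for n in range(8)]
print(check_identity(64, 31, xs))                 # True
print([floormod_scaled(64, 31, x) for x in xs])   # [2, 4, 6, 8, 10, 12, 14, 16]
```

This only samples concrete integers, so it is evidence rather than a proof, and it says nothing about how the Analyzer treats symbolic bounds.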


- TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);
+ TVM_TRY_REWRITE_IF(matches_one_of(floormod(x * c1, x * c2), floormod(c1 * x, c2 * x)),
+                    floormod(c1, c2), c2.Eval()->value != 0);
Member

@tqchen tqchen Mar 19, 2025

I am a bit surprised that we need the c * x case, since canonical simplify will move most cases to x * c. It would be good to understand why the original flow fails in this case.

@tqchen
Member

tqchen commented Mar 19, 2025

Thanks for the PR. It would be good to learn about your particular use case (e.g. how did you find out about this case?).

How did you construct the expression/model, and what happens if the code is changed explicitly to the x * c pattern?

@Ghosts381937
Contributor Author

Ghosts381937 commented May 10, 2025

Sorry for the late reply.
After taking a deep dive into the code infrastructure, I have a few questions regarding the following code.
I'm not sure whether this is the correct simplification, but it seems to be a bug in the x * c pattern of floormod.

from tvm import tir
from tvm.arith import Analyzer
from tvm.tir.op import floormod

# Define symbolic variable
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")

# Create expressions with common factor
expr1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 64)
divisor1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 32)
# Create Analyzer
analyzer = Analyzer()

# The Logic from CanProveDivisible().
print(analyzer.can_prove_equal(expr1, divisor1) or analyzer.can_prove(floormod(expr1, divisor1) == 0))
# Expected: True, but actual: False

# Main reason is the following simplification.
print(analyzer.rewrite_simplify(floormod(expr1, divisor1)))
# Expected output: 0
# Actual output: 
#     T.int64(64) * (past_decoder_sequence_length + T.int64(1)) % 
#     (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))

Detailed rewrite_simplify process (up to 2 iterations):

#   iter 0: (past_decoder_sequence_length * T.int64(64) + T.int64(64)) %
#           (past_decoder_sequence_length * T.int64(32) + T.int64(32))
#   iter 1: T.int64(64) * (past_decoder_sequence_length + T.int64(1)) %
#           (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))

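This is caused by the following rule:

TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), x * (y + 1));

which binds the pattern variables as follows, moving the constant to the left-hand side of the multiplication: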

#     iter 0 -> iter 1: 
#         x: T.int64(64)
#         y: past_decoder_sequence_length
#     => x * (y + 1) = T.int64(64) * (past_decoder_sequence_length + T.int64(1))
#     => c1 * (y + 1) = T.int64(64) * (past_decoder_sequence_length + T.int64(1))
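The binding above can be reproduced with a toy model. The sketch below is not TVM's pattern matcher: the `Var`/`Const`/`Add`/`Mul` classes and the helpers `factor_shared_term` and `matches_floormod_rule` are hypothetical stand-ins, written only to show how factoring `n * 64 + 64` with `x` bound to the constant yields `64 * (n + 1)`, which a pattern expecting the constant on the right no longer matches.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"


Expr = Union[Var, Const, Add, Mul]


def factor_shared_term(e: Expr) -> Optional[Expr]:
    """Toy version of: x * y + x, y * x + x, x + y * x, x + x * y  =>  x * (y + 1)."""
    if not isinstance(e, Add):
        return None
    # Try both operand orders so the product may sit on either side of the add.
    for prod, x in ((e.a, e.b), (e.b, e.a)):
        if isinstance(prod, Mul):
            if prod.a == x:  # product has the form x * y
                return Mul(x, Add(prod.b, Const(1)))
            if prod.b == x:  # product has the form y * x
                return Mul(x, Add(prod.a, Const(1)))
    return None


def matches_floormod_rule(lhs: Expr, rhs: Expr) -> bool:
    """Toy check for the pattern floormod(x * c1, x * c2): constants on the right."""
    return (
        isinstance(lhs, Mul) and isinstance(rhs, Mul)
        and isinstance(lhs.b, Const) and isinstance(rhs.b, Const)
        and lhs.a == rhs.a
    )


n = Var("n")
num = factor_shared_term(Add(Mul(n, Const(64)), Const(64)))  # n * 64 + 64
den = factor_shared_term(Add(Mul(n, Const(32)), Const(32)))  # n * 32 + 32

print(num)                              # constant 64 ends up on the left
print(matches_floormod_rule(num, den))  # False

# With the constant kept on the right, the same pattern does match.
n_plus_1 = Add(n, Const(1))
print(matches_floormod_rule(Mul(n_plus_1, Const(64)), Mul(n_plus_1, Const(32))))  # True
```

In this toy setting, an orientation-preserving variant of the factoring step (emitting `(y + 1) * x` when `x` is a constant) would keep the floormod pattern applicable; whether that is the right fix for the real rule set is a question for the maintainers.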

I'm not sure whether this is the correct simplification, but it seems to be a bug in the x * c pattern of floormod.
Perhaps we should consider supporting the x * c + c => (x + 1) * c pattern, like the recent commit to the rewrite rules for add?

@Ghosts381937 Ghosts381937 requested a review from tqchen May 11, 2025 12:42
@Ghosts381937
Contributor Author

@tqchen

@tqchen
Member

tqchen commented Jun 2, 2025

Thanks @Ghosts381937. A few things to note: by convention in this code base, simplification should place the multiplication coefficient on the right-hand side by default. You correctly observed that the rule TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), x * (y + 1)); violates this principle.

Based on your observation, #18031 should fix the issue as well as the test case you raised. Thank you for being careful and digging into the issue; the arith module is something we need to maintain carefully, so this really helps.

@tqchen tqchen closed this Jun 2, 2025
@tqchen
Member

tqchen commented Jun 2, 2025

superseded by #18031

@Ghosts381937
Contributor Author

Thank you for helping me fix the issue!
