diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 1a48779b79e3..c904b92d8ccb 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -191,6 +191,9 @@ class Canonical::Internal : public IRMutator { ret_entry_.max_level = stack_.back().max_level; stack_.pop_back(); CHECK(expr.defined()); + if (const IntImm* op = expr.as()) { + return Mutate_(op, expr); + } return expr; } // call produce to get a cache entry. diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index c6cf79d153b4..2cc8825e37f3 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -27,6 +27,16 @@ def test_basic(): assert str(ret.value) == "(m - 1)" +def test_canonical(): + x = tvm.var("x") + z = tvm.const(3) + ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z)) + assert(tvm.ir_pass.Equal(ret, 0)) + + ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z)) + assert(tvm.ir_pass.Equal(ret, 0)) + if __name__ == "__main__": test_basic() test_simplify() + test_canonical()