diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index e9b4286edad8..c34aae23453c 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -906,6 +906,13 @@ def transform_Call(self, node): ) if node.func_name.name in self._unaryop_maker: rhs = self.transform(node.params[0]) + if node.func_name.name == ast.BuiltinOp.USub and isinstance( + node.params[0], ast.Constant + ): + # '-literal' should be parsed together for proper literal type inference + if not isinstance(rhs, (tvm.tir.IntImm, tvm.tir.FloatImm)): + self.report_error("The literal is illegal after -", node.params[0].span) + return tvm.tir.const(-rhs.value) return self._unaryop_maker[node.func_name.name]( rhs, span=tvm_span_from_synr(node.span) ) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 17622789558d..1f5871b488e2 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3381,6 +3381,15 @@ def func( return func +def minimal_i32_literal(): + @T.prim_func + def func() -> None: + T.evaluate(T.int32(-2147483648)) + T.evaluate(-T.int64(2147483648)) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3423,6 +3432,7 @@ def func( decl_buffer, allocate_and_decl_buffer, float_infinity, + minimal_i32_literal, )