diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index e2b67341dcb5..939b7e82ce61 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -221,6 +221,17 @@ def _visit(self, node: doc.AST) -> Any: return node if isinstance(node, doc.Lambda): return self._eval_lambda(node) + if isinstance(node, doc.Starred): + value = self._visit(node.value) + return doc.Starred( + value=value, + ctx=node.ctx, + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + ) + fields = {} for field in node.__class__._FIELDS: # pylint: disable=protected-access attr = getattr(node, field) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 36df55610868..c04d16008af0 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -212,5 +212,23 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32" tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic) +def test_tir_starred_expression(): + dims = (128, 128) + + @T.prim_func(private=True) + def starred(a: T.handle) -> None: + A = T.match_buffer(a, [128, *dims], "int32") + for i, j, k in T.grid(128, *dims): + A[i, j, k] = T.int32(1) + + @T.prim_func(private=True) + def non_starred(a: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], "int32") + for i, j, k in T.grid(128, 128, 128): + A[i, j, k] = T.int32(1) + + tvm.ir.assert_structural_equal(starred, non_starred) + + if __name__ == "__main__": tvm.testing.main()