diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index afe9388b9a14..efea9f1aea94 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1277,6 +1277,9 @@ def buffer_store( """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel + if not isinstance(indices, (list, tuple, ir.Array)): + indices = [indices] + expr_indices = [] for index in indices: if isinstance(index, slice): diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7032d194be2f..7b7dd066c3fa 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -460,6 +460,8 @@ def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: return vars elif isinstance(target, doc.Name): return {target.id} + elif isinstance(target, doc.Starred): + return self._duplicate_lhs_check(target.value) else: self.report_error(target, "Invalid type in assign statement") raise NotImplementedError diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 60bdb7f7924b..33b42b343628 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -91,7 +91,7 @@ def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> A res : Any The bound value. """ - if isinstance(value, (list, tuple)): + if isinstance(value, (list, tuple, tvm.ir.Array)): for i, v in enumerate(value): bind_for_value(self, node, f"{var_name}_{i}", v) return value @@ -255,7 +255,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: for index in lhs.slice.elts: indices.append(self.eval_expr(index)) else: - indices = [self.eval_expr(lhs.slice)] + indices = self.eval_expr(lhs.slice) T.buffer_store(self.eval_expr(lhs.value), rhs, indices) else: self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 2713669bd3c3..f902ebb41183 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -236,11 +236,11 @@ def invalid_loop_var() -> None: def test_inconsistent_grid(): - def inconsistent_grid() -> None: - for i in T.grid(16, 16): # error - T.evaluate(1.0) + def inconsistent_grid(A: T.Buffer(16)) -> None: + for i in T.grid(16, 16): # valid, i is a tuple (iter0, iter1) + T.evaluate(A[i]) # error - check_error(inconsistent_grid, 2) + check_error(inconsistent_grid, 3) def test_invalid_match_buffer_region(): diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index c04d16008af0..210c173141c5 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -230,5 +230,67 @@ def non_starred(a: T.handle) -> None: tvm.ir.assert_structural_equal(starred, non_starred) +def test_tir_starred_shape_expression(): + dims = (128, 128) + + @T.prim_func(private=True) + def starred(a: T.handle) -> None: + A = T.match_buffer(a, [128, *dims], "int32") + for i, j, k in T.grid(*A.shape): + A[i, j, k] = T.int32(1) + + @T.prim_func(private=True) + def non_starred(a: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], "int32") + for i, j, k in T.grid(128, 128, 128): + A[i, j, k] = T.int32(1) + + tvm.ir.assert_structural_equal(starred, non_starred) + + +def test_tir_dynamic_for_loop(): + dims = (128, 128) + + @T.prim_func(private=True) + def starred(a: T.handle) -> None: + A = T.match_buffer(a, [128, *dims], "int32") + for iters in T.grid(*A.shape): + A[iters] = T.int32(1) + + @T.prim_func(private=True) + def non_starred(a: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], "int32") + for i, j, k in T.grid(128, 128, 128): + A[i, j, k] = T.int32(1) + + tvm.ir.assert_structural_equal(starred, non_starred) + + +def test_tir_starred_for_loop(): + dims = (128, 128) + + @T.prim_func(private=True) + def starred(a: T.handle, b: T.handle): + A = T.match_buffer(a, [*dims, 128], "int32") + B = T.match_buffer(a, dims, "int32") + for *spatial, reduction in T.grid(*A.shape): + with T.block("reduce"): + with T.init(): + B[spatial] = T.int32(0) + B[spatial] = B[spatial] + A[(*spatial, reduction)] + + @T.prim_func(private=True) + def non_starred(a: T.handle, b: T.handle): + A = T.match_buffer(a, [128, 128, 128], "int32") + B = T.match_buffer(a, [128, 128], "int32") + for i, j, k in T.grid(128, 128, 128): + with T.block("reduce"): + with T.init(): + B[i, j] = T.int32(0) + B[i, j] = B[i, j] + A[i, j, k] + + tvm.ir.assert_structural_equal(starred, non_starred) + + if __name__ == "__main__": tvm.testing.main()