From 72fa713d81e95381b7274323a682733ad3344282 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 16 Mar 2023 11:57:31 -0500 Subject: [PATCH 1/3] [Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions Usually, when using TVMScript to represent a `PrimFunc` variable definition `var_name = expr` defines `LetStmt` with a variable named `var_name` bound to the expression `expr`. However, prior to this commit, if `expr` is a `tir::Var`, the TVMScript parser would instead silently omit the `LetStmt`, and rename all instances of that variable to `var_name`. The root cause was in the `VarTable.exist` check, which erroneously returned False in all cases. This was due to a `value is v` check, which checked if the value was the same as the stack of maybe-shadowing values that share the same name. Replacing the 'value is v` check with a `value in v` check resolves this issue. This bug dates to the initial implementation of the new TVMScript parser in https://github.com/apache/tvm/pull/12496. --- python/tvm/script/parser/core/parser.py | 4 +-- .../unittest/test_tvmscript_syntax_sugar.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..f440e8a9f5c1 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -186,8 +186,8 @@ def exist(self, value: Any) -> bool: res : bool The existence of the value. """ - for v in self.name2value.values(): - if v is value: + for value_stack in self.name2value.values(): + if value in value_stack: return True return False diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index e4ba1f7950ab..85a1cc0a5fb8 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -399,5 +399,30 @@ def implicit(A: T.Buffer(1, "int32")): assert_structural_equal(implicit, explicit) +def test_preserve_trivial_let_binding(): + @T.prim_func + def explicit(i: T.int32): + j = T.int32() + T.LetStmt(i, var=j) + T.evaluate(j) + + @T.prim_func + def implicit(i: T.int32): + j = i + T.evaluate(j) + + assert_structural_equal(implicit, explicit) + + +def test_preserve_parameter_name(): + @T.prim_func + def func(i: T.int32): + j = i + T.evaluate(j) + + param_name = func.params[0].name + assert param_name == "i" + + if __name__ == "__main__": tvm.testing.main() From 3aefdbf915bc3a63c8bc21aaf3aaff083ef8c638 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 09:13:21 -0500 Subject: [PATCH 2/3] Avoid implicit `PrimExpr.__bool__` from `if value in value_stack` --- python/tvm/script/parser/core/parser.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index f440e8a9f5c1..fdccabcd235d 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -186,10 +186,11 @@ def exist(self, value: Any) -> bool: res : bool The existence of the value. """ - for value_stack in self.name2value.values(): - if value in value_stack: - return True - return False + return any( + value.same_as(known_value) + for known_value_stack in self.name2value.values() + for known_value in known_value_stack + ) def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: From c9681a3886d567b856f73308092213a95a16e228 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 09:14:21 -0500 Subject: [PATCH 3/3] Use T.meta_var where variable renaming is required. --- python/tvm/tir/tensor_intrin/cuda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index da194f885d1c..3bc16f234fba 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -245,7 +245,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j, k in T.grid(M_DIM, N_DIM, k_dim): with T.block("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) - b_row_ind, b_col_ind = maybe_swap(k, j) + b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j)) thread_id_C, local_id_C = T.meta_var(index_map_C(i, j)) thread_id_A, local_id_A = T.meta_var(index_map_A(i, k)) @@ -719,7 +719,7 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j, k in T.grid(m_dim, n_dim, k_dim): with T.block(""): vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) - B_index_0, B_index_1 = maybe_swap(vkk, vjj) + B_index_0, B_index_1 = T.meta_var(maybe_swap(vkk, vjj)) C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast( B[B_index_0, B_index_1] )