From 2030d1cfd43553ab6dedf4ee1ef4cb79df221fe4 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 15 Dec 2022 20:05:45 +0800 Subject: [PATCH 1/2] Fix print round-tripable multi thread env binding --- src/printer/tvmscript_printer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 7fb1129d274e..274b9542cc92 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1045,7 +1045,6 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { << ")"; doc << Doc::NewLine() << PrintBody(op->body); } - TryDeallocVar(iter_var->var); return doc; } } From c39d7292e3765cdf9902fe7146f672965980bbf4 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 15 Dec 2022 20:26:27 +0800 Subject: [PATCH 2/2] add unittest --- tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0ead66bd609f..c0174a0671c0 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3537,6 +3537,19 @@ def func(A: T.Buffer[1, "bool"], i: T.bool, j: T.bool, k: T.bool): yield generator +def multi_env_threads(): + @T.prim_func + def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]): + B = T.alloc_buffer([128], dtype="float32") + for i in T.thread_binding(128, thread="threadIdx.x"): + B[i] = A[i] + 1.0 + for i in T.thread_binding(128, thread="threadIdx.x"): + C[i] = B[i] + 2.0 + + mod = tvm.tir.transform.LowerOpaqueBlock()(tvm.IRModule.from_expr(func)) + return mod["main"] + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3593,6 +3606,7 @@ def func(A: T.Buffer[1, "bool"], i: T.bool, j: T.bool, k: T.bool): elif_chain_without_else, elif_chain_with_else, *nested_boolean_expressions(), + multi_env_threads, )