From 15b66280f522a7e3084a33c03776f25407c6ce11 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Wed, 9 Nov 2022 19:29:17 -0800
Subject: [PATCH] [TIR] Make syntax of AST nodes different than ops

As part of the effort towards more formal TIR semantics, we want to
differentiate more explicitly between TIR AST nodes (defined in
`tir/expr.h`) and TIR ops (defined in `tir/op.h`). The naming convention
is:
- Lowercased methods, for example, `tvm.tir.mul`, denote a TIR op, which
  is eagerly constant-folded, i.e. `mul(1, 2)` returns `3` immediately
  rather than creating an AST node.
- Capitalized callables, for example, `Mul`, create an AST node without
  constant folding.

This PR makes this behavior explicit by printing `T.Mul(a, b)` directly
when `a` and `b` are both constants, rather than sugaring it into
`mul(a, b)` or `a * b`, so that the difference between an op and an AST
node is clear.
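For illustration only, a minimal sketch of the intended distinction
(assuming the public `tvm.tir` Python API; the `IntImm` operands are
spelled out just to make the inputs explicit):

    import tvm

    # Lowercased op: eagerly constant-folded, yields an IntImm of value 6.
    folded = tvm.tir.mul(2, 3)
    assert isinstance(folded, tvm.tir.IntImm) and folded.value == 6

    # Capitalized constructor: builds the Mul AST node, nothing is folded.
    node = tvm.tir.Mul(tvm.tir.IntImm("int32", 2), tvm.tir.IntImm("int32", 3))
    assert isinstance(node, tvm.tir.Mul)

With this change, TVMScript prints the unfolded node above as
`T.Mul(2, 3)` instead of sugaring it into `2 * 3`.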
Co-authored-by: Yaxing Cai
---
 python/tvm/script/tir/intrin.py               | 80 ++++++++++++++-
 src/printer/tvmscript_printer.cc              | 97 +++++++++++--------
 .../test_hexagon/test_async_dma_pipeline.py   | 17 ++--
 .../test_parallel_hvx_load_vtcm.py            | 49 +++-------
 .../unittest/test_aot_legalize_packed_call.py | 12 +--
 .../unittest/test_meta_schedule_space_cuda.py |  2 +-
 ..._tir_transform_inject_software_pipeline.py | 16 +--
 ...est_tir_transform_inject_virtual_thread.py | 17 ++--
 .../test_tir_transform_thread_sync.py         |  2 +-
 .../unittest/test_tvmscript_roundtrip.py      |  8 +-
 10 files changed, 186 insertions(+), 114 deletions(-)

diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index bd9aa1fdadfd..8e24f27325bd 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -17,12 +17,13 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
 from tvm.tir import FloatImm
-from ..registry import register
+
 from ...target import codegen
+from ..registry import register
 from ..utils import get_param_list, tvm_span_from_synr
 
 
@@ -229,3 +230,78 @@ def comm_reducer(lambda_io, identities, span):
 def llvm_lookup_intrinsic_id(name, span):  # pylint: disable=unused-argument
     return codegen.llvm_lookup_intrinsic_id(name)
+
+
+@register
+def FloorMod(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.FloorMod(x, y, span)
+
+
+@register
+def FloorDiv(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.FloorDiv(x, y, span)
+
+
+@register
+def Mul(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Mul(x, y, span)
+
+
+@register
+def Div(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Div(x, y, span)
+
+
+@register
+def Add(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Add(x, y, span)
+
+
+@register
+def Sub(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Sub(x, y, span)
+
+
+@register
+def LT(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.LT(x, y, span)
+
+
+@register
+def LE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.LE(x, y, span)
+
+
+@register
+def GT(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.GT(x, y, span)
+
+
+@register
+def GE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.GE(x, y, span)
+
+
+@register
+def EQ(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.EQ(x, y, span)
+
+
+@register
+def NE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.NE(x, y, span)
+
+
+@register
+def And(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.And(x, y, span)
+
+
+@register
+def Or(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Or(x, y, span)
+
+
+@register
+def Cast(dtype, value, span):  # pylint: disable=invalid-name
+    return tvm.tir.Cast(dtype, value, span)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 64a576ef52f5..d7a3a406e352 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -788,7 +788,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_pr
 Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
   Doc doc;
-  doc << tir_prefix_ << ".cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
+  doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
   return doc;
 }
 
@@ -798,46 +798,61 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden
   return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
 }
 
-#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpPrecedence) \
-  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
-    Doc doc; \
-    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
-    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
-    /* Get children expr out_precedence */ \
-    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
-    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
-    ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
-    ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
-    /* Update out_precedence of current node. */ \
-    *out_precedence = OpPrecedence; \
-    if (lhs_precedence > OpPrecedence) { \
-      doc << "(" << lhs_doc << ")"; \
-    } else { \
-      doc << lhs_doc; \
-    } \
-    doc << OpString; \
-    if (rhs_precedence >= OpPrecedence) { \
-      doc << "(" << rhs_doc << ")"; \
-    } else { \
-      doc << rhs_doc; \
-    } \
-    return doc; \
-  }
-
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr)
+bool WillPrintConstScalar(const PrimExpr& expr) {
+  if (const auto* imm = expr.as<IntImmNode>()) {
+    DataType dtype = imm->dtype;
+    return dtype == DataType::Int(32) || dtype == DataType::Bool();
+  }
+  return false;
+}
+
+#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \
+  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
+    Doc doc; \
+    if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \
+      *out_precedence = ExprPrecedence::kIdentity; \
+      doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
+      return doc; \
+    } \
+    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
+    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
+    /* Get children expr out_precedence */ \
+    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
+    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
+    ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
+    ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
+    /* Update out_precedence of current node. */ \
+    *out_precedence = OpPrecedence; \
+    if (lhs_precedence > OpPrecedence) { \
+      doc << "(" << lhs_doc << ")"; \
+    } else { \
+      doc << lhs_doc; \
+    } \
+    doc << OpString; \
+    if (rhs_precedence >= OpPrecedence) { \
+      doc << "(" << rhs_doc << ")"; \
+    } else { \
+      doc << rhs_doc; \
+    } \
+    return doc; \
+  }
+
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
+                                    ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
+                                    ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", ExprPrecedence::kAnd)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)
 
 Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
""" import numpy as np -import tvm import pytest - -from tvm.script import tir as T +import tvm from numpy.random import default_rng +from tvm.script import tir as T VRMPY_SIZE_B = 128 VRMPY_SIZE_INT32 = 32 @@ -126,9 +125,9 @@ def get_single_dma_schedule(size_a, size_w): @T.prim_func def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", mem_scope="global") - w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", mem_scope="global") - c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", mem_scope="global") + a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", scope="global") + w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", scope="global") + c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", scope="global") a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", mem_scope="global.vtcm") w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", mem_scope="global.vtcm") c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", mem_scope="global.vtcm") @@ -153,7 +152,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: 0, dtype="handle", ), - T.cast(a_bytes, dtype="int"), + T.Cast("int", a_bytes), dtype="int32", ) ) @@ -178,7 +177,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: 0, dtype="handle", ), - T.cast(w_bytes, dtype="int"), + T.Cast("int", w_bytes), dtype="int32", ) ) @@ -222,7 +221,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: 0, dtype="handle", ), - T.cast(a_bytes, dtype="int"), + T.Cast("int", a_bytes), dtype="int32", ) ) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index fb398f43977a..e6fc0a3c201c 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -18,9 +18,8 @@ """ Test different strategies for loading data into vtcm before running HVX workloads. 
""" import numpy as np -from numpy.random import default_rng - import tvm +from numpy.random import default_rng from tvm.script import tir as T from .infrastructure import get_hexagon_target @@ -109,17 +108,17 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: [T.cast(operations, "int32") * 128], dtype="uint8", align=128, - mem_scope="global.vtcm", + scope="global.vtcm", ) b_buffer = T.match_buffer( b, [T.cast(operations, "int32") * 128], dtype="uint8", align=128, - mem_scope="global.vtcm", + scope="global.vtcm", ) c_buffer = T.match_buffer( - c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, mem_scope="global.vtcm" + c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, scope="global.vtcm" ) for n in T.grid(operations): with T.block("c_buffer"): @@ -149,21 +148,13 @@ def operator( a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, c_v: T.handle ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - a_buffer = T.match_buffer( - a, [operations, 128], dtype="uint8", align=128, mem_scope="global" - ) - b_buffer = T.match_buffer( - b, [operations, 128], dtype="uint8", align=128, mem_scope="global" - ) - c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global") - a_global_vtcm = T.match_buffer( - a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm" - ) - b_global_vtcm = T.match_buffer( - b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm" - ) + a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global") + b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global") + c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global") + a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm") + b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm") c_global_vtcm = T.match_buffer( - c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm" + c_v, [out_size], dtype="int32", align=128, scope="global.vtcm" ) for n, i in T.grid(operations, 128): with T.block("a_buffer_global.vtcm"): @@ -212,21 +203,13 @@ def operator( c_v: T.handle, ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - a_buffer = T.match_buffer( - a, [operations, 128], dtype="uint8", align=128, mem_scope="global" - ) - b_buffer = T.match_buffer( - b, [operations, 128], dtype="uint8", align=128, mem_scope="global" - ) - c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global") - a_global_vtcm = T.match_buffer( - a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm" - ) - b_global_vtcm = T.match_buffer( - b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm" - ) + a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global") + b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global") + c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global") + a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm") + b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm") c_global_vtcm = T.match_buffer( - c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm" + c_v, [out_size], dtype="int32", align=128, scope="global.vtcm" ) T.evaluate( T.tvm_call_packed( diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py 
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index 9c597a55e5cc..cd0114d46428 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
 import tvm
-from tvm.script import tir as T
-from tvm import tir
 import tvm.testing
-import pytest
+from tvm import tir
+from tvm.script import tir as T
 
 
 @tvm.script.ir_module
@@ -85,7 +85,7 @@ def tir_packed_call() -> None:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -94,7 +94,7 @@ def tir_packed_call() -> None:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -103,7 +103,7 @@ def tir_packed_call() -> None:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 324d8a9ec4f8..0a518c840d11 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -856,7 +856,7 @@ def nrm_1(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> N
             for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
                 with T.block("D"):
                     b = T.axis.spatial(1, i0_1)
-                    T.where(0 * 128 + i0_1 < 1)
+                    T.where(T.Mul(0, 128) + i0_1 < 1)
                     T.reads(C_shared[b])
                     T.writes(D[b])
                     D[b] = T.sqrt(C_shared[b], dtype="float32")
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2a4cabc541c6..c70525b05712 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -14,16 +14,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
 import sys
-import numpy as np
 
+import numpy as np
+import pytest
 import tvm
 import tvm.testing
 import tvm.tir.tensor_intrin.cuda
-from tvm import tir, te, TVMError
-from tvm.script import tir as T
+from tvm import TVMError, te, tir
 from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_DYN_INTRIN,
@@ -1116,7 +1116,7 @@ def test_simple_compute_async():
     mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
 
     @T.prim_func
-    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
+    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
         for tx in T.thread_binding(16, thread="threadIdx.x"):
             with T.block():
                 T.reads(A[tx, 0:16])
@@ -1127,7 +1127,7 @@ def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> N
                     T.writes(B[0, tx, 0])
                     with T.attr(0, "async_commit_queue_scope", 0):
                         with T.attr(0, "async_scope", 1):
-                            B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+                            B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2)
                 with T.block():
                     T.reads(A[tx, 1:16], B[0:2, tx, 0])
                     T.writes(B[0:2, tx, 0], C[tx, 0:15])
@@ -1147,11 +1147,11 @@ def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> N
                         with T.attr(0, "async_wait_inflight_count", 1):
                             C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1)
                 with T.block():
-                    T.reads(B[15 % 2, tx, 0])
+                    T.reads(B[T.FloorMod(15, 2), tx, 0])
                     T.writes(C[tx, 15])
                     with T.attr(0, "async_wait_queue_scope", 0):
                         with T.attr(0, "async_wait_inflight_count", 0):
-                            C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+                            C[tx, 15] = B[T.FloorMod(15, 2), tx, 0] + T.float32(1)
 
     tvm.ir.assert_structural_equal(mod["main"], ref, True)
 
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index 548f3bc8d1d2..b4ea4e712d19 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -16,7 +16,6 @@
 # under the License.
 import tvm
 from tvm import te
-
 from tvm.script import tir as T
 
 vthread_name = tvm.testing.parameter("vthread", "cthread")
@@ -155,10 +154,10 @@ def expected_func():
         B = T.buffer_decl([16], "int32", data=B_data, scope="shared")
         # The indices for B should each be a single Ramp node, and
         # should not be the sum of a Ramp and Broadcast node.
-        B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4)
-        B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4)
-        B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4)
-        B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4)
+        B[T.Mul(0, 4) : T.Mul(0, 4) + 4] = T.broadcast(0, 4)
+        B[T.Mul(1, 4) : T.Mul(1, 4) + 4] = T.broadcast(1, 4)
+        B[T.Mul(2, 4) : T.Mul(2, 4) + 4] = T.broadcast(2, 4)
+        B[T.Mul(3, 4) : T.Mul(3, 4) + 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
@@ -182,10 +181,10 @@ def before_func():
     def expected_func():
         B_data = T.allocate([4], "int32x4", "shared")
         B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared")
-        B[0 * 4 / 4] = T.broadcast(0, 4)
-        B[1 * 4 / 4] = T.broadcast(1, 4)
-        B[2 * 4 / 4] = T.broadcast(2, 4)
-        B[3 * 4 / 4] = T.broadcast(3, 4)
+        B[T.Mul(0, 4) / 4] = T.broadcast(0, 4)
+        B[T.Mul(1, 4) / 4] = T.broadcast(1, 4)
+        B[T.Mul(2, 4) / 4] = T.broadcast(2, 4)
+        B[T.Mul(3, 4) / 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 18607ca1a005..c80cd55ea27e 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -102,9 +102,9 @@ def func(p0: T.Buffer[2, "float32"], p1: T.Buffer[2, "float32"]) -> None:
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
         T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data)
-        T.launch_thread(blockIdx_x, 8)
         result_local = T.alloc_buffer([1], dtype="float32", scope="local")
         temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+        T.launch_thread(blockIdx_x, 8)
         T.launch_thread(threadIdx_x, 4)
         result_local[0] = T.float32(0)
         if threadIdx_x < 1:
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index dd6706762dc3..f22e61e1838d 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -90,9 +90,9 @@ class Module:
     def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "mmult", "tir.noalias": True})
-        A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=64, offset_factor=1)
+        A_1 = T.match_buffer(A, [1048576], elem_offset=0, align=64, offset_factor=1)
         B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=64, offset_factor=1)
-        C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=64, offset_factor=1)
+        C_1 = T.match_buffer(C, [1048576], elem_offset=0, align=64, offset_factor=1)
         # body
         packedB_data = T.allocate([32768], "float32", "global")
         packedB = T.buffer_decl(
@@ -3008,7 +3008,7 @@ def comm_reducer_single_reduce_group():
     def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         threadIdx_x = T.env_thread("threadIdx.x")
-        A = T.match_buffer(a, [128 * 128], dtype="float32")
+        A = T.match_buffer(a, [16384], dtype="float32")
         for i in T.serial(0, 128):
             T.launch_thread(threadIdx_x, 128)
             reduce_temp0_data = T.allocate([1], "float32", "local")
@@ -3024,7 +3024,7 @@ def comm_reducer_multiple_reduce_groups():
     def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         threadIdx_x = T.env_thread("threadIdx.x")
-        A = T.match_buffer(a, [128 * 128], dtype="float32")
+        A = T.match_buffer(a, [16384], dtype="float32")
         for i in T.serial(0, 128):
             T.launch_thread(threadIdx_x, 128)
             reduce_temp0_data = T.allocate([1], "float32", "local")