Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import builtins
from typing import List, Any
from typing import Any, List

import tvm.tir
from tvm.tir import FloatImm
from ..registry import register

from ...target import codegen
from ..registry import register
from ..utils import get_param_list, tvm_span_from_synr


Expand Down Expand Up @@ -229,3 +230,78 @@ def comm_reducer(lambda_io, identities, span):
def llvm_lookup_intrinsic_id(name, span):
    """Return ``codegen.llvm_lookup_intrinsic_id(name)``.

    ``span`` is accepted to match the calling convention of the other
    intrinsics in this module but is not used here.
    """
    # pylint: disable=unused-argument
    intrinsic_id = codegen.llvm_lookup_intrinsic_id(name)
    return intrinsic_id


@register
def FloorMod(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.FloorMod`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.FloorMod(x, y, span)
    return node


@register
def FloorDiv(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.FloorDiv`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.FloorDiv(x, y, span)
    return node


@register
def Mul(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.Mul`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.Mul(x, y, span)
    return node


@register
def Div(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.Div`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.Div(x, y, span)
    return node


@register
def Add(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.Add`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.Add(x, y, span)
    return node


@register
def Sub(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.Sub`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.Sub(x, y, span)
    return node


@register
def LT(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.LT`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.LT(x, y, span)
    return node


@register
def LE(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.LE`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.LE(x, y, span)
    return node


@register
def GT(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.GT`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.GT(x, y, span)
    return node


@register
def GE(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.GE`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.GE(x, y, span)
    return node


@register
def EQ(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.EQ`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.EQ(x, y, span)
    return node


@register
def NE(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.NE`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.NE(x, y, span)
    return node


@register
def And(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.And`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.And(x, y, span)
    return node


@register
def Or(x, y, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.Or`` node from operands ``x`` and ``y``.

    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.Or(x, y, span)
    return node


@register
def Cast(dtype, value, span):  # pylint: disable=invalid-name
    """Build a ``tvm.tir.Cast`` node from ``dtype`` and ``value``.

    Note the argument order: ``dtype`` comes first, then ``value``.
    ``span`` is forwarded unchanged as the third constructor argument.
    """
    node = tvm.tir.Cast(dtype, value, span)
    return node
97 changes: 56 additions & 41 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_pr
Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
doc << tir_prefix_ << ".cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
return doc;
}

Expand All @@ -798,46 +798,61 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden
return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
}

#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpPrecedence) \
Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
Doc doc; \
ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
/* Get children expr out_precedence */ \
Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
/* Update out_precedence of current node. */ \
*out_precedence = OpPrecedence; \
if (lhs_precedence > OpPrecedence) { \
doc << "(" << lhs_doc << ")"; \
} else { \
doc << lhs_doc; \
} \
doc << OpString; \
if (rhs_precedence >= OpPrecedence) { \
doc << "(" << rhs_doc << ")"; \
} else { \
doc << rhs_doc; \
} \
return doc; \
}

TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr)
bool WillPrintConstScalar(const PrimExpr& expr) {
if (const auto* imm = expr.as<IntImmNode>()) {
DataType dtype = imm->dtype;
return dtype == DataType::Int(32) || dtype == DataType::Bool();
}
return false;
}

#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \
Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
Doc doc; \
if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \
*out_precedence = ExprPrecedence::kIdentity; \
doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
return doc; \
} \
ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
/* Get children expr out_precedence */ \
Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
/* Update out_precedence of current node. */ \
*out_precedence = OpPrecedence; \
if (lhs_precedence > OpPrecedence) { \
doc << "(" << lhs_doc << ")"; \
} else { \
doc << lhs_doc; \
} \
doc << OpString; \
if (rhs_precedence >= OpPrecedence) { \
doc << "(" << rhs_doc << ")"; \
} else { \
doc << rhs_doc; \
} \
return doc; \
}

TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", ExprPrecedence::kAnd)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)

Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Expand Down
17 changes: 8 additions & 9 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
""" Test different strategies for loading data into vtcm before running HVX workloads. """

import numpy as np
import tvm
import pytest

from tvm.script import tir as T
import tvm
from numpy.random import default_rng
from tvm.script import tir as T

VRMPY_SIZE_B = 128
VRMPY_SIZE_INT32 = 32
Expand Down Expand Up @@ -126,9 +125,9 @@ def get_single_dma_schedule(size_a, size_w):
@T.prim_func
def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", mem_scope="global")
w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", mem_scope="global")
c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", mem_scope="global")
a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", scope="global")
w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", scope="global")
c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", scope="global")
a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", mem_scope="global.vtcm")
w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", mem_scope="global.vtcm")
c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", mem_scope="global.vtcm")
Expand All @@ -153,7 +152,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None:
0,
dtype="handle",
),
T.cast(a_bytes, dtype="int"),
T.Cast("int", a_bytes),
dtype="int32",
)
)
Expand All @@ -178,7 +177,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None:
0,
dtype="handle",
),
T.cast(w_bytes, dtype="int"),
T.Cast("int", w_bytes),
dtype="int32",
)
)
Expand Down Expand Up @@ -222,7 +221,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None:
0,
dtype="handle",
),
T.cast(a_bytes, dtype="int"),
T.Cast("int", a_bytes),
dtype="int32",
)
)
Expand Down
49 changes: 16 additions & 33 deletions tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
""" Test different strategies for loading data into vtcm before running HVX workloads. """

import numpy as np
from numpy.random import default_rng

import tvm
from numpy.random import default_rng
from tvm.script import tir as T

from .infrastructure import get_hexagon_target
Expand Down Expand Up @@ -109,17 +108,17 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None:
[T.cast(operations, "int32") * 128],
dtype="uint8",
align=128,
mem_scope="global.vtcm",
scope="global.vtcm",
)
b_buffer = T.match_buffer(
b,
[T.cast(operations, "int32") * 128],
dtype="uint8",
align=128,
mem_scope="global.vtcm",
scope="global.vtcm",
)
c_buffer = T.match_buffer(
c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, mem_scope="global.vtcm"
c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, scope="global.vtcm"
)
for n in T.grid(operations):
with T.block("c_buffer"):
Expand Down Expand Up @@ -149,21 +148,13 @@ def operator(
a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, c_v: T.handle
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
a_buffer = T.match_buffer(
a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
)
b_buffer = T.match_buffer(
b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
)
c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global")
a_global_vtcm = T.match_buffer(
a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
)
b_global_vtcm = T.match_buffer(
b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
)
a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global")
b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global")
c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global")
a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm")
b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm")
c_global_vtcm = T.match_buffer(
c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
)
for n, i in T.grid(operations, 128):
with T.block("a_buffer_global.vtcm"):
Expand Down Expand Up @@ -212,21 +203,13 @@ def operator(
c_v: T.handle,
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
a_buffer = T.match_buffer(
a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
)
b_buffer = T.match_buffer(
b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
)
c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, mem_scope="global")
a_global_vtcm = T.match_buffer(
a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
)
b_global_vtcm = T.match_buffer(
b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
)
a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global")
b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global")
c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global")
a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm")
b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm")
c_global_vtcm = T.match_buffer(
c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
)
T.evaluate(
T.tvm_call_packed(
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_aot_legalize_packed_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest
import tvm
from tvm.script import tir as T
from tvm import tir
import tvm.testing
import pytest
from tvm import tir
from tvm.script import tir as T


@tvm.script.ir_module
Expand Down Expand Up @@ -85,7 +85,7 @@ def tir_packed_call() -> None:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
T.cast(0, dtype="float32"),
T.Cast("float32", 0),
0,
dtype="handle",
),
Expand All @@ -94,7 +94,7 @@ def tir_packed_call() -> None:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
T.cast(0, dtype="float32"),
T.Cast("float32", 0),
0,
dtype="handle",
),
Expand All @@ -103,7 +103,7 @@ def tir_packed_call() -> None:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
T.cast(0, dtype="float32"),
T.Cast("float32", 0),
0,
dtype="handle",
),
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def nrm_1(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> N
for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
with T.block("D"):
b = T.axis.spatial(1, i0_1)
T.where(0 * 128 + i0_1 < 1)
T.where(T.Mul(0, 128) + i0_1 < 1)
T.reads(C_shared[b])
T.writes(D[b])
D[b] = T.sqrt(C_shared[b], dtype="float32")
Expand Down
Loading