From 3a1ccf8bacddd18b61bf7acec5c1d9e1f2c3a87f Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 11:33:18 -0700 Subject: [PATCH 01/32] add init --- python/tvm/script/tir/__init__.pyi | 196 ++++++ python/tvm/tir/__init__.pyi | 612 +++++++++++++++++++ tests/python/unittest/test_tvmscript_type.py | 46 ++ 3 files changed, 854 insertions(+) create mode 100644 python/tvm/script/tir/__init__.pyi create mode 100644 python/tvm/tir/__init__.pyi create mode 100644 tests/python/unittest/test_tvmscript_type.py diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi new file mode 100644 index 000000000000..8358f9c2b2fb --- /dev/null +++ b/python/tvm/script/tir/__init__.pyi @@ -0,0 +1,196 @@ +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Iterable, + Object, + Optional, + Tuple, + Union, + Sequence, + List, + Mapping, +) +from tvm.ir import Span +from tvm.tir.function import PrimFunc +from tvm.tir import PrimExpr, Buffer, IterVar, Var, Ptr +from .node import BufferSlice + +""" +Variables and constants +""" + +def bool(imm: int, span: Span) -> PrimExpr: ... +def int8(imm: int, span: Span) -> PrimExpr: ... +def int16(imm: int, span: Span) -> PrimExpr: ... +def int32(imm: int, span: Span) -> PrimExpr: ... +def int64(imm: int, span: Span) -> PrimExpr: ... +def uint8(imm: int, span: Span) -> PrimExpr: ... +def uint16(imm: int, span: Span) -> PrimExpr: ... +def uint32(imm: int, span: Span) -> PrimExpr: ... +def uint64(imm: int, span: Span) -> PrimExpr: ... +def float8(imm: int, span: Span) -> PrimExpr: ... +def float16(imm: int, span: Span) -> PrimExpr: ... +def float32(imm: int, span: Span) -> PrimExpr: ... +def float64(imm: int, span: Span) -> PrimExpr: ... + +""" +Intrinsic +""" + +def min_value(dtype, span: Span): ... +def max_value(dtype, span: Span): ... +def floordiv(x: PrimExpr, y: PrimExpr, span: Span): ... +def floormod(x: PrimExpr, y: PrimExpr, span: Span): ... +def abs(x, span: Span): ... +def load(dtype, var, index, predicate=None, span: Span = None): ... +def cast(value, dtype, span: Span): ... +def ramp(base, stride, lanes, span: Span): ... +def broadcast(value, lanes, span: Span): ... +def iter_var(var, dom, iter_type, thread_tag, span: Span): ... +def max(a, b, span: Span): ... +def min(a, b, span: Span): ... +def get_axis(begin, end, iter_type, span: Span): ... +def range(begin, end, span: Span): ... +def reduce_axis(begin, end, span: Span): ... +def scan_axis(begin, end, span: Span): ... +def opaque_axis(begin, end, span: Span): ... +def Select(cond, if_body, else_body, span: Span): ... +def evaluate(value, span: Span): ... +def store(var, index, value, predicate=True, span: Span = None): ... +def comm_reducer(lambda_io, identities, span: Span): ... + +""" +Unary operator +""" + +def exp2(x: PrimExpr) -> PrimExpr: ... +def exp10(x: PrimExpr) -> PrimExpr: ... +def erf(x: PrimExpr) -> PrimExpr: ... +def tanh(x: PrimExpr) -> PrimExpr: ... +def sigmoid(x: PrimExpr) -> PrimExpr: ... +def log(x: PrimExpr) -> PrimExpr: ... +def log2(x: PrimExpr) -> PrimExpr: ... +def log10(x: PrimExpr) -> PrimExpr: ... +def log1p(x: PrimExpr) -> PrimExpr: ... +def tan(x: PrimExpr) -> PrimExpr: ... +def cos(x: PrimExpr) -> PrimExpr: ... +def cosh(x: PrimExpr) -> PrimExpr: ... +def acos(x: PrimExpr) -> PrimExpr: ... +def acosh(x: PrimExpr) -> PrimExpr: ... +def sin(x: PrimExpr) -> PrimExpr: ... +def sinh(x: PrimExpr) -> PrimExpr: ... +def asin(x: PrimExpr) -> PrimExpr: ... +def asinh(x: PrimExpr) -> PrimExpr: ... +def atan(x: PrimExpr) -> PrimExpr: ... +def atanh(x: PrimExpr) -> PrimExpr: ... +def atan2(x: PrimExpr) -> PrimExpr: ... +def sqrt(x: PrimExpr) -> PrimExpr: ... +def rsqrt(x: PrimExpr) -> PrimExpr: ... + +""" +Loops +""" + +def serial(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... +def parallel(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... +def vectorize(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... +def unroll(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... +def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... +def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... +def thread_binding( + begin: Union[PrimExpr, int], end: Union[PrimExpr, int], thread: str +) -> Iterable[IterVar]: ... + +""" +Axis +""" + +def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... +def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... +def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... +def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... + +""" +Buffers +""" + +def match_buffer( + param: Union[Var, BufferSlice], + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data=None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + span: Span = None, +) -> Buffer: ... +def buffer_decl( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data=None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + span: Span = None, +) -> Buffer: ... +def alloc_buffer( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data=None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + span: Span = None, +) -> Buffer: ... + +""" +Reads/Writes +""" + +def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None) -> None: ... +def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None) -> None: ... +def block_attr(attrs: Mapping[str, Object], span: Span = None) -> None: ... + +""" +Scope handler +""" + +class block(ContextManager): + def __init__(self, axes: Sequence[Union[int, PrimExpr, slice]], name: str = "") -> None: ... + def __enter__(self) -> Sequence[IterVar]: ... + +class init(ContextManager): + def __init__(self) -> None: ... + +class let(ContextManager): + def __init__(self, var: Var, value: PrimExpr) -> None: ... + +def where(cond: PrimExpr) -> None: ... +def realize(x: Buffer, scope: str, condition: bool = True) -> None: ... + +""" +Threads and Bindings +""" + +def env_thread(thread: str) -> IterVar: ... +def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... + +""" +Annotations +""" + +def func_attr(attrs: Dict) -> None: ... +def block_attr(attrs: Dict) -> None: ... +def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... +def prim_func(input_func: Callable) -> PrimFunc: ... \ No newline at end of file diff --git a/python/tvm/tir/__init__.pyi b/python/tvm/tir/__init__.pyi new file mode 100644 index 000000000000..c7bccf5d9bec --- /dev/null +++ b/python/tvm/tir/__init__.pyi @@ -0,0 +1,612 @@ +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + Mapping, + Sequence, + overload, +) +from enum import IntEnum +from tvm import ir +from tvm.ir import BaseFunc, Range, Span +from tvm.runtime import ( + DataType, + DataTypeCode, + Object, + ObjectGeneric, +) + +""" +Redefine types +""" + +class PrimExpr: + def __init__(self: PrimExpr) -> None: ... + @overload + def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + +""" +buffer +""" + +class Buffer(Object): + READ: int + WRITE: int + def access_ptr( + self, access_mask, ptr_type: str = ..., content_lanes: int = ..., offset: int = ... + ): ... + def vload(self, begin, dtype: Any | None = ...): ... + def vstore(self, begin, value): ... + def scope(self): ... + +def decl_buffer( + shape, + dtype: Any | None = ..., + name: str = ..., + data: Any | None = ..., + strides: Any | None = ..., + elem_offset: Any | None = ..., + scope: str = ..., + data_alignment: int = ..., + offset_factor: int = ..., + buffer_type: str = ..., + span: Any | None = ..., +): ... + +class DataProducer(Object): ... + +""" +data layout +""" + +class Layout(Object): + def __len__(self): ... + def __contains__(self, axis): ... + def __getitem__(self, index): ... + def index_of(self, axis): ... + def factor_of(self, axis): ... + +class BijectiveLayout(Object): + def forward_index(self, index): ... + def backward_index(self, index): ... + def forward_shape(self, shape): ... + def backward_shape(self, shape): ... + +def layout(layout_str: str) -> Layout: ... +def bijective_layout( + src_layout: Union[str, Layout], dst_layout: Union[str, Layout] +) -> BijectiveLayout: ... + +""" +expr +""" +from typing import Any as _Any + +def div_ambiguity_error(): ... + +class ExprOp: + def __add__(self, other): ... + def __radd__(self, other): ... + def __sub__(self, other): ... + def __rsub__(self, other): ... + def __mul__(self, other): ... + def __rmul__(self, other): ... + def __div__(self, other): ... + def __rdiv__(self, other): ... + def __truediv__(self, other): ... + def __rtruediv__(self, other): ... + def __floordiv__(self, other): ... + def __rfloordiv__(self, other): ... + def __mod__(self, other): ... + def __rmod__(self, other): ... + def __neg__(self): ... + def __lshift__(self, other): ... + def __rlshift__(self, other): ... + def __rshift__(self, other): ... + def __rrshift__(self, other): ... + def __and__(self, other): ... + def __rand__(self, other): ... + def __or__(self, other): ... + def __ror__(self, other): ... + def __xor__(self, other): ... + def __rxor__(self, other): ... + def __invert__(self): ... + def __lt__(self, other): ... + def __le__(self, other): ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def __gt__(self, other): ... + def __ge__(self, other): ... + def __nonzero__(self) -> None: ... + def __bool__(self): ... + def equal(self, other, span: _Any | None = ...): ... + def astype(self, dtype: str, span: Optional[Span] = ...): ... + +class EqualOp(ObjectGeneric, ExprOp): + same_as: _Any + a: _Any + b: _Any + span: _Any + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + def __nonzero__(self): ... + def __bool__(self): ... + def asobject(self): ... + +class NotEqualOp(ObjectGeneric, ExprOp): + same_as: _Any + a: _Any + b: _Any + span: _Any + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + def __nonzero__(self): ... + def __bool__(self): ... + def asobject(self): ... + +class IntImmEnum(ObjectGeneric): + value: _Any + span: _Any + def __init__(self, value, span: _Any | None = ...) -> None: ... + def asobject(self): ... + +class PrimExprWithOp(ExprOp, PrimExpr): + __hash__: _Any + +class ConstExpr(PrimExprWithOp): ... +class BinaryOpExpr(PrimExprWithOp): ... +class CmpExpr(PrimExprWithOp): ... +class LogicalExpr(PrimExprWithOp): ... + +class Var(PrimExprWithOp): + def __init__( + self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = ... + ) -> None: ... + +class SizeVar(Var): + def __init__(self, name, dtype, span: _Any | None = ...) -> None: ... + +class IterVar(Object, ExprOp): + DataPar: int + ThreadIndex: int + CommReduce: int + Ordered: int + DimInfo: int + Unrolled: int + Vectorized: int + Parallelized: int + Tensorized: int + def __init__( + self, dom, var, iter_type, thread_tag: str = ..., span: _Any | None = ... + ) -> None: ... + +class CommReducer(Object): + def __init__(self, lhs, rhs, result, identity_element, span: _Any | None = ...) -> None: ... + +class Reduce(PrimExprWithOp): + def __init__( + self, + combiner, + src, + rdom, + condition, + value_index, + init: _Any | None = ..., + span: _Any | None = ..., + ) -> None: ... + +class FloatImm(ConstExpr): + def __init__(self, dtype, value, span: _Any | None = ...) -> None: ... + def __float__(self): ... + +class IntImm(ConstExpr): + def __init__(self, dtype, value, span: _Any | None = ...) -> None: ... + def __hash__(self): ... + def __int__(self): ... + def __nonzero__(self): ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def __bool__(self): ... + +class StringImm(ConstExpr): + def __init__(self, value, span: _Any | None = ...) -> None: ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def __hash__(self): ... + +class Cast(PrimExprWithOp): + def __init__(self, dtype, value, span: _Any | None = ...) -> None: ... + +class Add(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Sub(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Mul(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Div(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Mod(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class FloorDiv(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class FloorMod(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Min(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Max(BinaryOpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class EQ(CmpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class NE(CmpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class LT(CmpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class LE(CmpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class GT(CmpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class GE(CmpExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class And(LogicalExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Or(LogicalExpr): + def __init__(self, a, b, span: _Any | None = ...) -> None: ... + +class Not(LogicalExpr): + def __init__(self, a, span: _Any | None = ...) -> None: ... + +class Select(PrimExprWithOp): + def __init__(self, condition, true_value, false_value, span: _Any | None = ...) -> None: ... + +class Load(PrimExprWithOp): + def __init__( + self, dtype, buffer_var, index, predicate: _Any | None = ..., span: _Any | None = ... + ) -> None: ... + +class BufferLoad(PrimExprWithOp): + def __init__(self, buffer, indices, span: _Any | None = ...) -> None: ... + +class ProducerLoad(PrimExprWithOp): + def __init__(self, producer, indices, span: _Any | None = ...) -> None: ... + +class Ramp(PrimExprWithOp): + def __init__(self, base, stride, lanes, span: _Any | None = ...) -> None: ... + +class Broadcast(PrimExprWithOp): + def __init__(self, value, lanes, span: _Any | None = ...) -> None: ... + +class Shuffle(PrimExprWithOp): + def __init__(self, vectors, indices, span: _Any | None = ...) -> None: ... + +class CallEffectKind: + ExprAnnotation: _Any + Pure: _Any + ReadState: _Any + UpdateState: _Any + Opaque: _Any + +class Call(PrimExprWithOp): + def __init__(self, dtype, op, args, span: _Any | None = ...) -> None: ... + +class Let(PrimExprWithOp): + def __init__(self, var, value, body, span: _Any | None = ...) -> None: ... + +class Any(PrimExprWithOp): + def __init__(self, span: _Any | None = ...) -> None: ... + +""" +function +""" + +class PrimFunc(BaseFunc): + def __init__( + self, + params, + body, + ret_type: Any | None = ..., + buffer_map: Any | None = ..., + attrs: Any | None = ..., + span: Any | None = ..., + ) -> None: ... + def with_body(self, new_body, span: Any | None = ...): ... + def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): ... + def script(self, tir_prefix: str = ..., show_meta: bool = ...) -> str: ... + +""" +generic +""" + +def add(lhs, rhs, span: Any | None = ...): ... +def subtract(lhs, rhs, span: Any | None = ...): ... +def multiply(lhs, rhs, span: Any | None = ...): ... +def divide(lhs, rhs, span: Any | None = ...): ... +def floordiv(lhs, rhs, span: Any | None = ...): ... +def cast(src, dtype, span: Any | None = ...): ... + +""" +ir_builder +""" + +class WithScope: + def __init__(self, enter_value, exit_cb) -> None: ... + def __enter__(self): ... + def __exit__(self, ptype, value, trace) -> None: ... + +class BufferVar(ObjectGeneric): + def __init__(self, builder, buffer_var, shape, content_type) -> None: ... + def asobject(self): ... + @property + def dtype(self): ... + def __getitem__(self, index): ... + def __setitem__(self, index, value) -> None: ... + +class IRBuilder: + nidx: int + def __init__(self) -> None: ... + def emit(self, stmt) -> None: ... + def scope_attr(self, node, attr_key, value): ... + def for_range(self, begin, end, name: str = ..., dtype: str = ..., kind: str = ...): ... + def while_loop(self, condition): ... + def if_scope(self, cond): ... + def else_scope(self): ... + def new_scope(self): ... + def let(self, var_name, value): ... + def allocate(self, dtype, shape, name: str = ..., scope: str = ...): ... + def pointer(self, content_type, name: str = ..., scope: str = ...): ... + def buffer_ptr(self, buf, shape: Any | None = ...): ... + def likely(self, expr): ... + def get(self): ... + +def create(): ... + +""" +op +""" + +def call_packed(*args, span: Any | None = ...): ... +def call_intrin(dtype, func_name, *args, span: Any | None = ...): ... +def call_pure_extern(dtype, func_name, *args, span: Any | None = ...): ... +def call_extern(dtype, func_name, *args, span: Any | None = ...): ... +def call_llvm_intrin(dtype, name, *args, span: Any | None = ...): ... +def call_llvm_pure_intrin(dtype, name, *args, span: Any | None = ...): ... +def ret(val): ... +def any(*args, span: Any | None = ...): ... +def all(*args, span: Any | None = ...): ... +def trace(args, trace_action: str = ...): ... +def min_value(dtype, span: Any | None = ...): ... +def max_value(dtype: str, span: Optional[Span] = ...) -> Any: ... +def exp(x): ... +def exp2(x): ... +def exp10(x): ... +def erf(x): ... +def tanh(x): ... +def sigmoid(x): ... +def log(x): ... +def log2(x): ... +def log10(x): ... +def log1p(x): ... +def tan(x): ... +def cos(x): ... +def cosh(x): ... +def acos(x): ... +def acosh(x): ... +def sin(x): ... +def sinh(x): ... +def asin(x): ... +def asinh(x): ... +def atan(x): ... +def atanh(x): ... +def atan2(x1, x2): ... +def sqrt(x): ... +def rsqrt(x): ... +def clz(x): ... +def floor(x: PrimExprWithOp, span: Any | None = ...): ... +def ceil(x, span: Any | None = ...): ... +def trunc(x, span: Any | None = ...): ... +def abs(x, span: Any | None = ...): ... +def round(x, span: Any | None = ...): ... +def nearbyint(x, span: Any | None = ...): ... +def nextafter(x1, x2): ... +def hypot(x1, x2): ... +def copysign(x1, x2): ... +def ldexp(x1, x2): ... +def isnan(x, span: Any | None = ...): ... +def isfinite(x, span: Any | None = ...): ... +def isinf(x, span: Any | None = ...): ... +def power(x, y, span: Any | None = ...): ... +def popcount(x): ... +def q_multiply_shift(x, y, q, s): ... +def fmod(x, y): ... +def if_then_else(cond, t, f, span: Any | None = ...): ... +def div(a, b, span: Any | None = ...): ... +def indexdiv(a, b, span: Any | None = ...): ... +def indexmod(a, b, span: Any | None = ...): ... +def truncdiv(a, b, span: Any | None = ...): ... +def truncmod(a, b, span: Any | None = ...): ... +def floordiv(a, b, span: Any | None = ...): ... +def floormod(a, b, span: Any | None = ...): ... +def comm_reducer(fcombine, fidentity, name: str = ...): ... + +""" +stmt_functor +""" + +def ir_transform(stmt, preorder, postorder, only_enable: Any | None = ...): ... +def post_order_visit(stmt, fvisit): ... +def substitute(node, vmap): ... + +""" +stmt +""" + +class Stmt(Object): ... + +class LetStmt(Stmt): + def __init__(self, var, value, body, span: Any | None = ...) -> None: ... + +class AssertStmt(Stmt): + def __init__(self, condition, message, body, span: Any | None = ...) -> None: ... + +class ForKind(IntEnum): + SERIAL: int + PARALLEL: int + VECTORIZED: int + UNROLLED: int + THREAD_BINDING: int + +class For(Stmt): + def __init__( + self, + loop_var, + min_val, + extent, + kind, + body, + thread_binding: Any | None = ..., + annotations: Any | None = ..., + span: Any | None = ..., + ) -> None: ... + +class While(Stmt): + def __init__(self, condition, body, span: Any | None = ...) -> None: ... + +class Store(Stmt): + def __init__( + self, buffer_var, value, index, predicate: Any | None = ..., span: Any | None = ... + ) -> None: ... + +class BufferStore(Stmt): + def __init__(self, buffer, value, indices, span: Any | None = ...) -> None: ... + +class BufferRealize(Stmt): + def __init__(self, buffer, bounds, condition, body, span: Any | None = ...) -> None: ... + +class ProducerStore(Stmt): + def __init__(self, producer, value, indices, span: Any | None = ...) -> None: ... + +class Allocate(Stmt): + def __init__( + self, + buffer_var, + dtype, + extents, + condition, + body, + annotations: Any | None = ..., + span: Any | None = ..., + ) -> None: ... + +class AttrStmt(Stmt): + def __init__(self, node, attr_key, value, body, span: Any | None = ...) -> None: ... + +class ProducerRealize(Stmt): + def __init__( + self, producer, bounds, condition, body, storage_scope: str = ..., span: Any | None = ... + ) -> None: ... + +class SeqStmt(Stmt): + def __init__(self, seq, span: Any | None = ...) -> None: ... + def __getitem__(self, i): ... + def __len__(self): ... + +class IfThenElse(Stmt): + def __init__(self, condition, then_case, else_case, span: Any | None = ...) -> None: ... + +class Evaluate(Stmt): + def __init__(self, value, span: Any | None = ...) -> None: ... + +class Prefetch(Stmt): + def __init__(self, buffer, bounds, span: Any | None = ...) -> None: ... + +class BufferRegion(Object): + buffer: Buffer + region: List[Range] + def __init__(self, buffer: Buffer, region: List[Range]) -> None: ... + +class MatchBufferRegion(Object): + buffer: Buffer + source: BufferRegion + def __init__(self, buffer: Buffer, source: BufferRegion) -> None: ... + +class Block(Stmt): + iter_vars: List[IterVar] + reads: List[BufferRegion] + writes: List[BufferRegion] + name_hint: str + body: Stmt + init: Optional[Stmt] + alloc_buffers: Optional[List[Buffer]] + match_buffers: Optional[List[MatchBufferRegion]] + annotations: Optional[Mapping[str, Object]] + span: Optional[Span] + def __init__( + self, + iter_vars: List[IterVar], + reads: List[BufferRegion], + writes: List[BufferRegion], + name_hint: str, + body: Stmt, + init: Optional[Stmt] = ..., + alloc_buffers: Optional[List[Buffer]] = ..., + match_buffers: Optional[List[MatchBufferRegion]] = ..., + annotations: Optional[Mapping[str, Object]] = ..., + span: Optional[Span] = ..., + ) -> None: ... + +class BlockRealize(Stmt): + iter_values: List[PrimExpr] + predicate: PrimExpr + block: Block + span: Optional[Span] + def __init__( + self, + iter_values: List[PrimExpr], + predicate: Union[PrimExpr, bool], + block: Block, + span: Optional[Span] = ..., + ) -> None: ... + +def stmt_seq(*args): ... +def stmt_list(stmt): ... \ No newline at end of file diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py new file mode 100644 index 000000000000..67de0b550ad7 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_type.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest +import tvm +from tvm import tir +from tvm.script import tir as T + + +@pytest.mark.mypy_testing +def test_mypy_use_reveal_type() -> None: + @tvm.script.ir_module + class Module: + @T.prim_func + def func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + C = T.match_buffer(c, [128, 128], dtype="float32") + reveal_type(A) # R: + for i, j, k in T.grid(128, 128, T.reduce_axis(0, 128)): + with T.block("C"): + C[i, j] = T.if_then_else( + i == 0 and j == 0 and k == 0, + 0.0, + C[i, j] + A[i, k] * B[k, j], + dtype="float32", + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 54eceb5abf79a08f6fb9ab8a563289446132e197 Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 13:36:59 -0700 Subject: [PATCH 02/32] get rid of span --- python/tvm/script/tir/__init__.pyi | 80 ++++++++++---------- tests/python/unittest/test_tvmscript_type.py | 33 ++++---- 2 files changed, 54 insertions(+), 59 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 8358f9c2b2fb..534c6befd96f 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -12,54 +12,53 @@ from typing import ( List, Mapping, ) -from tvm.ir import Span from tvm.tir.function import PrimFunc -from tvm.tir import PrimExpr, Buffer, IterVar, Var, Ptr +from tvm.tir import PrimExpr, Buffer, IterVar, Var from .node import BufferSlice """ Variables and constants """ -def bool(imm: int, span: Span) -> PrimExpr: ... -def int8(imm: int, span: Span) -> PrimExpr: ... -def int16(imm: int, span: Span) -> PrimExpr: ... -def int32(imm: int, span: Span) -> PrimExpr: ... -def int64(imm: int, span: Span) -> PrimExpr: ... -def uint8(imm: int, span: Span) -> PrimExpr: ... -def uint16(imm: int, span: Span) -> PrimExpr: ... -def uint32(imm: int, span: Span) -> PrimExpr: ... -def uint64(imm: int, span: Span) -> PrimExpr: ... -def float8(imm: int, span: Span) -> PrimExpr: ... -def float16(imm: int, span: Span) -> PrimExpr: ... -def float32(imm: int, span: Span) -> PrimExpr: ... -def float64(imm: int, span: Span) -> PrimExpr: ... +def bool(imm: int) -> PrimExpr: ... +def int8(imm: int) -> PrimExpr: ... +def int16(imm: int) -> PrimExpr: ... +def int32(imm: int) -> PrimExpr: ... +def int64(imm: int) -> PrimExpr: ... +def uint8(imm: int) -> PrimExpr: ... +def uint16(imm: int) -> PrimExpr: ... +def uint32(imm: int) -> PrimExpr: ... +def uint64(imm: int) -> PrimExpr: ... +def float8(imm: int) -> PrimExpr: ... +def float16(imm: int) -> PrimExpr: ... +def float32(imm: int) -> PrimExpr: ... +def float64(imm: int) -> PrimExpr: ... """ Intrinsic """ -def min_value(dtype, span: Span): ... -def max_value(dtype, span: Span): ... -def floordiv(x: PrimExpr, y: PrimExpr, span: Span): ... -def floormod(x: PrimExpr, y: PrimExpr, span: Span): ... -def abs(x, span: Span): ... -def load(dtype, var, index, predicate=None, span: Span = None): ... -def cast(value, dtype, span: Span): ... -def ramp(base, stride, lanes, span: Span): ... -def broadcast(value, lanes, span: Span): ... -def iter_var(var, dom, iter_type, thread_tag, span: Span): ... -def max(a, b, span: Span): ... -def min(a, b, span: Span): ... -def get_axis(begin, end, iter_type, span: Span): ... -def range(begin, end, span: Span): ... -def reduce_axis(begin, end, span: Span): ... -def scan_axis(begin, end, span: Span): ... -def opaque_axis(begin, end, span: Span): ... -def Select(cond, if_body, else_body, span: Span): ... -def evaluate(value, span: Span): ... -def store(var, index, value, predicate=True, span: Span = None): ... -def comm_reducer(lambda_io, identities, span: Span): ... +def min_value(dtype): ... +def max_value(dtype): ... +def floordiv(x: PrimExpr, y: PrimExpr): ... +def floormod(x: PrimExpr, y: PrimExpr): ... +def abs(x): ... +def load(dtype, var, index, predicate=None): ... +def cast(value, dtype): ... +def ramp(base, stride, lanes): ... +def broadcast(value, lanes): ... +def iter_var(var, dom, iter_type, thread_tag): ... +def max(a, b): ... +def min(a, b): ... +def get_axis(begin, end, iter_type): ... +def range(begin, end): ... +def reduce_axis(begin, end): ... +def scan_axis(begin, end): ... +def opaque_axis(begin, end): ... +def Select(cond, if_body, else_body): ... +def evaluate(value): ... +def store(var, index, value, predicate=True): ... +def comm_reducer(lambda_io, identities): ... """ Unary operator @@ -127,7 +126,6 @@ def match_buffer( align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - span: Span = None, ) -> Buffer: ... def buffer_decl( shape: Sequence[Union[PrimExpr, int]], @@ -139,7 +137,6 @@ def buffer_decl( align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - span: Span = None, ) -> Buffer: ... def alloc_buffer( shape: Sequence[Union[PrimExpr, int]], @@ -151,16 +148,15 @@ def alloc_buffer( align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - span: Span = None, ) -> Buffer: ... """ Reads/Writes """ -def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None) -> None: ... -def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None) -> None: ... -def block_attr(attrs: Mapping[str, Object], span: Span = None) -> None: ... +def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def block_attr(attrs: Mapping[str, Object]) -> None: ... """ Scope handler diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 67de0b550ad7..437674cb5c77 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -23,23 +23,22 @@ @pytest.mark.mypy_testing -def test_mypy_use_reveal_type() -> None: - @tvm.script.ir_module - class Module: - @T.prim_func - def func(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128], dtype="float32") - B = T.match_buffer(b, [128, 128], dtype="float32") - C = T.match_buffer(c, [128, 128], dtype="float32") - reveal_type(A) # R: - for i, j, k in T.grid(128, 128, T.reduce_axis(0, 128)): - with T.block("C"): - C[i, j] = T.if_then_else( - i == 0 and j == 0 and k == 0, - 0.0, - C[i, j] + A[i, k] * B[k, j], - dtype="float32", - ) +@tvm.script.ir_module +class Module: + @T.prim_func + def func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + C = T.match_buffer(c, [128, 128], dtype="float32") + reveal_type(A) # R: + for i, j, k in T.grid(128, 128, T.reduce_axis(0, 128)): + with T.block("C"): + C[i, j] = T.if_then_else( + i == 0 and j == 0 and k == 0, + 0.0, + C[i, j] + A[i, k] * B[k, j], + dtype="float32", + ) if __name__ == "__main__": From 8e7309fa12fcb07c2c33696ad792702c63ee096c Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 13:47:50 -0700 Subject: [PATCH 03/32] afs header --- python/tvm/script/tir/__init__.pyi | 17 +++++++++++++++++ python/tvm/tir/__init__.pyi | 19 ++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 534c6befd96f..ea034f9d6e95 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin from typing import ( Any, Callable, diff --git a/python/tvm/tir/__init__.pyi b/python/tvm/tir/__init__.pyi index c7bccf5d9bec..df7cf100d5f9 100644 --- a/python/tvm/tir/__init__.pyi +++ b/python/tvm/tir/__init__.pyi @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin from typing import ( Any, Callable, @@ -609,4 +626,4 @@ class BlockRealize(Stmt): ) -> None: ... def stmt_seq(*args): ... -def stmt_list(stmt): ... \ No newline at end of file +def stmt_list(stmt): ... From 47615e1de88b8ba32277610552bded18117bed89 Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 14:11:36 -0700 Subject: [PATCH 04/32] update scope_handler --- python/tvm/script/tir/__init__.pyi | 60 ++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index ea034f9d6e95..d21db4631511 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -105,20 +105,6 @@ def atan2(x: PrimExpr) -> PrimExpr: ... def sqrt(x: PrimExpr) -> PrimExpr: ... def rsqrt(x: PrimExpr) -> PrimExpr: ... -""" -Loops -""" - -def serial(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... -def parallel(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... -def vectorize(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... -def unroll(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... -def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... -def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> Iterable[IterVar]: ... -def thread_binding( - begin: Union[PrimExpr, int], end: Union[PrimExpr, int], thread: str -) -> Iterable[IterVar]: ... - """ Axis """ @@ -190,7 +176,51 @@ class let(ContextManager): def __init__(self, var: Var, value: PrimExpr) -> None: ... def where(cond: PrimExpr) -> None: ... -def realize(x: Buffer, scope: str, condition: bool = True) -> None: ... +def allocate(extents, dtype, scope: str, condition: bool = True, annotations=None) -> None: ... +def launch_thread(env_var, extent): ... +def realize(buffer_slice: BufferSlice, scope: str, condition: bool = True) -> None: ... +def attr(attr_node, attr_key, value) -> None: ... +def Assert(condition, message): ... +def let(var, value): ... +def block(name_hint: str = ""): ... +def init(): ... + +""" +Scope handler - Loops +""" + +def serial( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, +) -> None: ... +def parallel( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, +) -> None: ... +def vectorized( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, +) -> None: ... +def unroll( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, +) -> None: ... +def thread_binding( + begin: PrimExpr, + end: PrimExpr, + thread: str, + annotations: Optional[Mapping[str, Object]] = None, +) -> None: ... +def for_range( + begin: PrimExpr, + end: PrimExpr = None, + annotations: Optional[Mapping[str, Object]] = None, +) -> None: ... +def grid(*extents: List[PrimExpr]) -> None: ... """ Threads and Bindings From 981f287bdce9faa6bdbd99653f4a8b371610e8e8 Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 14:12:40 -0700 Subject: [PATCH 05/32] rm tir/__init__.pyi --- python/tvm/tir/__init__.pyi | 629 ------------------------------------ 1 file changed, 629 deletions(-) delete mode 100644 python/tvm/tir/__init__.pyi diff --git a/python/tvm/tir/__init__.pyi b/python/tvm/tir/__init__.pyi deleted file mode 100644 index df7cf100d5f9..000000000000 --- a/python/tvm/tir/__init__.pyi +++ /dev/null @@ -1,629 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-builtin -from typing import ( - Any, - Callable, - ContextManager, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, - Mapping, - Sequence, - overload, -) -from enum import IntEnum -from tvm import ir -from tvm.ir import BaseFunc, Range, Span -from tvm.runtime import ( - DataType, - DataTypeCode, - Object, - ObjectGeneric, -) - -""" -Redefine types -""" - -class PrimExpr: - def __init__(self: PrimExpr) -> None: ... - @overload - def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - -""" -buffer -""" - -class Buffer(Object): - READ: int - WRITE: int - def access_ptr( - self, access_mask, ptr_type: str = ..., content_lanes: int = ..., offset: int = ... - ): ... - def vload(self, begin, dtype: Any | None = ...): ... - def vstore(self, begin, value): ... - def scope(self): ... - -def decl_buffer( - shape, - dtype: Any | None = ..., - name: str = ..., - data: Any | None = ..., - strides: Any | None = ..., - elem_offset: Any | None = ..., - scope: str = ..., - data_alignment: int = ..., - offset_factor: int = ..., - buffer_type: str = ..., - span: Any | None = ..., -): ... - -class DataProducer(Object): ... - -""" -data layout -""" - -class Layout(Object): - def __len__(self): ... - def __contains__(self, axis): ... - def __getitem__(self, index): ... - def index_of(self, axis): ... - def factor_of(self, axis): ... - -class BijectiveLayout(Object): - def forward_index(self, index): ... - def backward_index(self, index): ... - def forward_shape(self, shape): ... - def backward_shape(self, shape): ... - -def layout(layout_str: str) -> Layout: ... -def bijective_layout( - src_layout: Union[str, Layout], dst_layout: Union[str, Layout] -) -> BijectiveLayout: ... - -""" -expr -""" -from typing import Any as _Any - -def div_ambiguity_error(): ... - -class ExprOp: - def __add__(self, other): ... - def __radd__(self, other): ... - def __sub__(self, other): ... - def __rsub__(self, other): ... - def __mul__(self, other): ... - def __rmul__(self, other): ... - def __div__(self, other): ... - def __rdiv__(self, other): ... - def __truediv__(self, other): ... - def __rtruediv__(self, other): ... - def __floordiv__(self, other): ... - def __rfloordiv__(self, other): ... - def __mod__(self, other): ... - def __rmod__(self, other): ... - def __neg__(self): ... - def __lshift__(self, other): ... - def __rlshift__(self, other): ... - def __rshift__(self, other): ... - def __rrshift__(self, other): ... - def __and__(self, other): ... - def __rand__(self, other): ... - def __or__(self, other): ... - def __ror__(self, other): ... - def __xor__(self, other): ... - def __rxor__(self, other): ... - def __invert__(self): ... - def __lt__(self, other): ... - def __le__(self, other): ... - def __eq__(self, other): ... - def __ne__(self, other): ... - def __gt__(self, other): ... - def __ge__(self, other): ... - def __nonzero__(self) -> None: ... - def __bool__(self): ... - def equal(self, other, span: _Any | None = ...): ... - def astype(self, dtype: str, span: Optional[Span] = ...): ... - -class EqualOp(ObjectGeneric, ExprOp): - same_as: _Any - a: _Any - b: _Any - span: _Any - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - def __nonzero__(self): ... - def __bool__(self): ... - def asobject(self): ... - -class NotEqualOp(ObjectGeneric, ExprOp): - same_as: _Any - a: _Any - b: _Any - span: _Any - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - def __nonzero__(self): ... - def __bool__(self): ... - def asobject(self): ... - -class IntImmEnum(ObjectGeneric): - value: _Any - span: _Any - def __init__(self, value, span: _Any | None = ...) -> None: ... - def asobject(self): ... - -class PrimExprWithOp(ExprOp, PrimExpr): - __hash__: _Any - -class ConstExpr(PrimExprWithOp): ... -class BinaryOpExpr(PrimExprWithOp): ... -class CmpExpr(PrimExprWithOp): ... -class LogicalExpr(PrimExprWithOp): ... - -class Var(PrimExprWithOp): - def __init__( - self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = ... - ) -> None: ... - -class SizeVar(Var): - def __init__(self, name, dtype, span: _Any | None = ...) -> None: ... - -class IterVar(Object, ExprOp): - DataPar: int - ThreadIndex: int - CommReduce: int - Ordered: int - DimInfo: int - Unrolled: int - Vectorized: int - Parallelized: int - Tensorized: int - def __init__( - self, dom, var, iter_type, thread_tag: str = ..., span: _Any | None = ... - ) -> None: ... - -class CommReducer(Object): - def __init__(self, lhs, rhs, result, identity_element, span: _Any | None = ...) -> None: ... - -class Reduce(PrimExprWithOp): - def __init__( - self, - combiner, - src, - rdom, - condition, - value_index, - init: _Any | None = ..., - span: _Any | None = ..., - ) -> None: ... - -class FloatImm(ConstExpr): - def __init__(self, dtype, value, span: _Any | None = ...) -> None: ... - def __float__(self): ... - -class IntImm(ConstExpr): - def __init__(self, dtype, value, span: _Any | None = ...) -> None: ... - def __hash__(self): ... - def __int__(self): ... - def __nonzero__(self): ... - def __eq__(self, other): ... - def __ne__(self, other): ... - def __bool__(self): ... - -class StringImm(ConstExpr): - def __init__(self, value, span: _Any | None = ...) -> None: ... - def __eq__(self, other): ... - def __ne__(self, other): ... - def __hash__(self): ... - -class Cast(PrimExprWithOp): - def __init__(self, dtype, value, span: _Any | None = ...) -> None: ... - -class Add(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Sub(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Mul(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Div(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Mod(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class FloorDiv(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class FloorMod(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Min(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Max(BinaryOpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class EQ(CmpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class NE(CmpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class LT(CmpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class LE(CmpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class GT(CmpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class GE(CmpExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class And(LogicalExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Or(LogicalExpr): - def __init__(self, a, b, span: _Any | None = ...) -> None: ... - -class Not(LogicalExpr): - def __init__(self, a, span: _Any | None = ...) -> None: ... - -class Select(PrimExprWithOp): - def __init__(self, condition, true_value, false_value, span: _Any | None = ...) -> None: ... - -class Load(PrimExprWithOp): - def __init__( - self, dtype, buffer_var, index, predicate: _Any | None = ..., span: _Any | None = ... - ) -> None: ... - -class BufferLoad(PrimExprWithOp): - def __init__(self, buffer, indices, span: _Any | None = ...) -> None: ... - -class ProducerLoad(PrimExprWithOp): - def __init__(self, producer, indices, span: _Any | None = ...) -> None: ... - -class Ramp(PrimExprWithOp): - def __init__(self, base, stride, lanes, span: _Any | None = ...) -> None: ... - -class Broadcast(PrimExprWithOp): - def __init__(self, value, lanes, span: _Any | None = ...) -> None: ... - -class Shuffle(PrimExprWithOp): - def __init__(self, vectors, indices, span: _Any | None = ...) -> None: ... - -class CallEffectKind: - ExprAnnotation: _Any - Pure: _Any - ReadState: _Any - UpdateState: _Any - Opaque: _Any - -class Call(PrimExprWithOp): - def __init__(self, dtype, op, args, span: _Any | None = ...) -> None: ... - -class Let(PrimExprWithOp): - def __init__(self, var, value, body, span: _Any | None = ...) -> None: ... - -class Any(PrimExprWithOp): - def __init__(self, span: _Any | None = ...) -> None: ... - -""" -function -""" - -class PrimFunc(BaseFunc): - def __init__( - self, - params, - body, - ret_type: Any | None = ..., - buffer_map: Any | None = ..., - attrs: Any | None = ..., - span: Any | None = ..., - ) -> None: ... - def with_body(self, new_body, span: Any | None = ...): ... - def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): ... - def script(self, tir_prefix: str = ..., show_meta: bool = ...) -> str: ... - -""" -generic -""" - -def add(lhs, rhs, span: Any | None = ...): ... -def subtract(lhs, rhs, span: Any | None = ...): ... -def multiply(lhs, rhs, span: Any | None = ...): ... -def divide(lhs, rhs, span: Any | None = ...): ... -def floordiv(lhs, rhs, span: Any | None = ...): ... -def cast(src, dtype, span: Any | None = ...): ... - -""" -ir_builder -""" - -class WithScope: - def __init__(self, enter_value, exit_cb) -> None: ... - def __enter__(self): ... - def __exit__(self, ptype, value, trace) -> None: ... - -class BufferVar(ObjectGeneric): - def __init__(self, builder, buffer_var, shape, content_type) -> None: ... - def asobject(self): ... - @property - def dtype(self): ... - def __getitem__(self, index): ... - def __setitem__(self, index, value) -> None: ... - -class IRBuilder: - nidx: int - def __init__(self) -> None: ... - def emit(self, stmt) -> None: ... - def scope_attr(self, node, attr_key, value): ... - def for_range(self, begin, end, name: str = ..., dtype: str = ..., kind: str = ...): ... - def while_loop(self, condition): ... - def if_scope(self, cond): ... - def else_scope(self): ... - def new_scope(self): ... - def let(self, var_name, value): ... - def allocate(self, dtype, shape, name: str = ..., scope: str = ...): ... - def pointer(self, content_type, name: str = ..., scope: str = ...): ... - def buffer_ptr(self, buf, shape: Any | None = ...): ... - def likely(self, expr): ... - def get(self): ... - -def create(): ... - -""" -op -""" - -def call_packed(*args, span: Any | None = ...): ... -def call_intrin(dtype, func_name, *args, span: Any | None = ...): ... -def call_pure_extern(dtype, func_name, *args, span: Any | None = ...): ... -def call_extern(dtype, func_name, *args, span: Any | None = ...): ... -def call_llvm_intrin(dtype, name, *args, span: Any | None = ...): ... -def call_llvm_pure_intrin(dtype, name, *args, span: Any | None = ...): ... -def ret(val): ... -def any(*args, span: Any | None = ...): ... -def all(*args, span: Any | None = ...): ... -def trace(args, trace_action: str = ...): ... -def min_value(dtype, span: Any | None = ...): ... -def max_value(dtype: str, span: Optional[Span] = ...) -> Any: ... -def exp(x): ... -def exp2(x): ... -def exp10(x): ... -def erf(x): ... -def tanh(x): ... -def sigmoid(x): ... -def log(x): ... -def log2(x): ... -def log10(x): ... -def log1p(x): ... -def tan(x): ... -def cos(x): ... -def cosh(x): ... -def acos(x): ... -def acosh(x): ... -def sin(x): ... -def sinh(x): ... -def asin(x): ... -def asinh(x): ... -def atan(x): ... -def atanh(x): ... -def atan2(x1, x2): ... -def sqrt(x): ... -def rsqrt(x): ... -def clz(x): ... -def floor(x: PrimExprWithOp, span: Any | None = ...): ... -def ceil(x, span: Any | None = ...): ... -def trunc(x, span: Any | None = ...): ... -def abs(x, span: Any | None = ...): ... -def round(x, span: Any | None = ...): ... -def nearbyint(x, span: Any | None = ...): ... -def nextafter(x1, x2): ... -def hypot(x1, x2): ... -def copysign(x1, x2): ... -def ldexp(x1, x2): ... -def isnan(x, span: Any | None = ...): ... -def isfinite(x, span: Any | None = ...): ... -def isinf(x, span: Any | None = ...): ... -def power(x, y, span: Any | None = ...): ... -def popcount(x): ... -def q_multiply_shift(x, y, q, s): ... -def fmod(x, y): ... -def if_then_else(cond, t, f, span: Any | None = ...): ... -def div(a, b, span: Any | None = ...): ... -def indexdiv(a, b, span: Any | None = ...): ... -def indexmod(a, b, span: Any | None = ...): ... -def truncdiv(a, b, span: Any | None = ...): ... -def truncmod(a, b, span: Any | None = ...): ... -def floordiv(a, b, span: Any | None = ...): ... -def floormod(a, b, span: Any | None = ...): ... -def comm_reducer(fcombine, fidentity, name: str = ...): ... - -""" -stmt_functor -""" - -def ir_transform(stmt, preorder, postorder, only_enable: Any | None = ...): ... -def post_order_visit(stmt, fvisit): ... -def substitute(node, vmap): ... - -""" -stmt -""" - -class Stmt(Object): ... - -class LetStmt(Stmt): - def __init__(self, var, value, body, span: Any | None = ...) -> None: ... - -class AssertStmt(Stmt): - def __init__(self, condition, message, body, span: Any | None = ...) -> None: ... - -class ForKind(IntEnum): - SERIAL: int - PARALLEL: int - VECTORIZED: int - UNROLLED: int - THREAD_BINDING: int - -class For(Stmt): - def __init__( - self, - loop_var, - min_val, - extent, - kind, - body, - thread_binding: Any | None = ..., - annotations: Any | None = ..., - span: Any | None = ..., - ) -> None: ... - -class While(Stmt): - def __init__(self, condition, body, span: Any | None = ...) -> None: ... - -class Store(Stmt): - def __init__( - self, buffer_var, value, index, predicate: Any | None = ..., span: Any | None = ... - ) -> None: ... - -class BufferStore(Stmt): - def __init__(self, buffer, value, indices, span: Any | None = ...) -> None: ... - -class BufferRealize(Stmt): - def __init__(self, buffer, bounds, condition, body, span: Any | None = ...) -> None: ... - -class ProducerStore(Stmt): - def __init__(self, producer, value, indices, span: Any | None = ...) -> None: ... - -class Allocate(Stmt): - def __init__( - self, - buffer_var, - dtype, - extents, - condition, - body, - annotations: Any | None = ..., - span: Any | None = ..., - ) -> None: ... - -class AttrStmt(Stmt): - def __init__(self, node, attr_key, value, body, span: Any | None = ...) -> None: ... - -class ProducerRealize(Stmt): - def __init__( - self, producer, bounds, condition, body, storage_scope: str = ..., span: Any | None = ... - ) -> None: ... - -class SeqStmt(Stmt): - def __init__(self, seq, span: Any | None = ...) -> None: ... - def __getitem__(self, i): ... - def __len__(self): ... - -class IfThenElse(Stmt): - def __init__(self, condition, then_case, else_case, span: Any | None = ...) -> None: ... - -class Evaluate(Stmt): - def __init__(self, value, span: Any | None = ...) -> None: ... - -class Prefetch(Stmt): - def __init__(self, buffer, bounds, span: Any | None = ...) -> None: ... - -class BufferRegion(Object): - buffer: Buffer - region: List[Range] - def __init__(self, buffer: Buffer, region: List[Range]) -> None: ... - -class MatchBufferRegion(Object): - buffer: Buffer - source: BufferRegion - def __init__(self, buffer: Buffer, source: BufferRegion) -> None: ... - -class Block(Stmt): - iter_vars: List[IterVar] - reads: List[BufferRegion] - writes: List[BufferRegion] - name_hint: str - body: Stmt - init: Optional[Stmt] - alloc_buffers: Optional[List[Buffer]] - match_buffers: Optional[List[MatchBufferRegion]] - annotations: Optional[Mapping[str, Object]] - span: Optional[Span] - def __init__( - self, - iter_vars: List[IterVar], - reads: List[BufferRegion], - writes: List[BufferRegion], - name_hint: str, - body: Stmt, - init: Optional[Stmt] = ..., - alloc_buffers: Optional[List[Buffer]] = ..., - match_buffers: Optional[List[MatchBufferRegion]] = ..., - annotations: Optional[Mapping[str, Object]] = ..., - span: Optional[Span] = ..., - ) -> None: ... - -class BlockRealize(Stmt): - iter_values: List[PrimExpr] - predicate: PrimExpr - block: Block - span: Optional[Span] - def __init__( - self, - iter_values: List[PrimExpr], - predicate: Union[PrimExpr, bool], - block: Block, - span: Optional[Span] = ..., - ) -> None: ... - -def stmt_seq(*args): ... -def stmt_list(stmt): ... From 63e2ec2a67ab2236d4a9e8c8346d78adeb1eb7f0 Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 15:48:54 -0700 Subject: [PATCH 06/32] fix linting --- python/tvm/script/tir/__init__.pyi | 2 +- python/tvm/tir/schedule/testing.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index d21db4631511..2913e3a4873c 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -236,4 +236,4 @@ Annotations def func_attr(attrs: Dict) -> None: ... def block_attr(attrs: Dict) -> None: ... def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... -def prim_func(input_func: Callable) -> PrimFunc: ... \ No newline at end of file +def prim_func(input_func: Callable) -> PrimFunc: ... diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 66ede31f4103..586b0c17b312 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -20,15 +20,15 @@ from tvm import tir from tvm.ir import IRModule, structural_equal from tvm.tir import PrimFunc -from tvm.tir.schedule import Trace +from tvm.tir.schedule import Trace, Schedule def verify_trace_roundtrip( - sch: tir.Schedule, + sch: Schedule, mod: Union[PrimFunc, IRModule], *, debug_mask: Union[str, int] = "all", -) -> tir.Schedule: +) -> Schedule: """Serialize a traced schedule to JSON, then replay the JSON trace by applying to a fresh new schedule, verifying the reproducibility of scheduling. @@ -51,7 +51,7 @@ def verify_trace_roundtrip( assert trace is not None json_obj = trace.as_json() # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling - new_sch = tir.Schedule(mod=mod, debug_mask=debug_mask) + new_sch = Schedule(mod=mod, debug_mask=debug_mask) Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) assert structural_equal(new_sch.mod, sch.mod) # Step 3. Check the consistency of the text format between the old and new traces From a44071d0cf53328cff21ba289cd77f6761e675d9 Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 15:51:50 -0700 Subject: [PATCH 07/32] fix lint --- python/tvm/tir/schedule/testing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 586b0c17b312..04cbffcd4d87 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -17,7 +17,6 @@ """Testing utilities for the TensorIR schedule API""" from typing import Union -from tvm import tir from tvm.ir import IRModule, structural_equal from tvm.tir import PrimFunc from tvm.tir.schedule import Trace, Schedule From 57a63463feaecaf6f346fd17e814983381e98535 Mon Sep 17 00:00:00 2001 From: shingjan Date: Tue, 2 Nov 2021 16:53:31 -0700 Subject: [PATCH 08/32] new test case --- tests/python/unittest/test_tvmscript_type.py | 42 ++++++++++++++------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 437674cb5c77..b97252e93294 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -21,24 +21,40 @@ from tvm import tir from tvm.script import tir as T +""" +This module tests the type of +T.prim_func, T.handle, T.match_buffer, T.block +T.reads, T.writes, T.alloc_buffer, T.serial +T.block_attr, T.float32 +""" + @pytest.mark.mypy_testing @tvm.script.ir_module class Module: @T.prim_func - def func(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128], dtype="float32") - B = T.match_buffer(b, [128, 128], dtype="float32") - C = T.match_buffer(c, [128, 128], dtype="float32") - reveal_type(A) # R: - for i, j, k in T.grid(128, 128, T.reduce_axis(0, 128)): - with T.block("C"): - C[i, j] = T.if_then_else( - i == 0 and j == 0 and k == 0, - 0.0, - C[i, j] + A[i, k] * B[k, j], - dtype="float32", - ) + def element_wise_storage_align(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i1 in T.serial(0, 128): + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) if __name__ == "__main__": From d14fd696c0213a91adf5bbed56c532d1ab5419fa Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 3 Nov 2021 12:25:38 -0700 Subject: [PATCH 09/32] add axis module --- mypy.ini | 8 ++ python/tvm/script/tir/__init__.pyi | 116 +++++++++++++++++------------ python/tvm/script/tir/axis.py | 17 +++++ python/tvm/script/tir/axis.pyi | 27 +++++++ 4 files changed, 121 insertions(+), 47 deletions(-) create mode 100644 python/tvm/script/tir/axis.py create mode 100644 python/tvm/script/tir/axis.pyi diff --git a/mypy.ini b/mypy.ini index 02564a85469e..50a4c5e31820 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,6 +23,14 @@ follow_imports = skip ignore_errors = False strict_optional = False +# +# Note: not all tests under .tests/ are typed +# Therefore include test files that should be +# checked by mypy here +# +files = + tests/python/unittest/test_tvmscript_type.py + [mypy-python.tvm.auto_scheduler.*] ignore_errors = True diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 2913e3a4873c..fb93674ce27b 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -21,7 +21,6 @@ from typing import ( ContextManager, Dict, Iterable, - Object, Optional, Tuple, Union, @@ -29,9 +28,24 @@ from typing import ( List, Mapping, ) + from tvm.tir.function import PrimFunc -from tvm.tir import PrimExpr, Buffer, IterVar, Var +from tvm.tir import PrimExpr, Range, IterVar, Var +from tvm.runtime import Object from .node import BufferSlice +from . import axis +from .ty import ConcreteType + +""" +redefine types +""" + +class Buffer(Var): + def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ... + @property + def data(self: Buffer) -> Ptr: ... + +class Ptr: ... """ Variables and constants @@ -68,10 +82,6 @@ def iter_var(var, dom, iter_type, thread_tag): ... def max(a, b): ... def min(a, b): ... def get_axis(begin, end, iter_type): ... -def range(begin, end): ... -def reduce_axis(begin, end): ... -def scan_axis(begin, end): ... -def opaque_axis(begin, end): ... def Select(cond, if_body, else_body): ... def evaluate(value): ... def store(var, index, value, predicate=True): ... @@ -115,7 +125,7 @@ def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... """ -Buffers +special_stmt - Buffers """ def match_buffer( @@ -154,19 +164,44 @@ def alloc_buffer( ) -> Buffer: ... """ -Reads/Writes +special_stmt - Reads/Writes """ def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... def block_attr(attrs: Mapping[str, Object]) -> None: ... +""" +special_stmt - Axis +""" + +def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ... + +""" +special_stmt - Annotations +""" + +def buffer_var(dtype, storage_scope) -> IterVar: ... +def func_attr(attrs: Dict) -> None: ... +def prim_func(input_func: Callable) -> PrimFunc: ... + +""" +special_stmt - Threads and Bindings +""" + +def env_thread(env_name: str) -> IterVar: ... +def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... + """ Scope handler """ class block(ContextManager): - def __init__(self, axes: Sequence[Union[int, PrimExpr, slice]], name: str = "") -> None: ... + def __init__(self, name_hint: str = "") -> None: ... def __enter__(self) -> Sequence[IterVar]: ... class init(ContextManager): @@ -176,64 +211,51 @@ class let(ContextManager): def __init__(self, var: Var, value: PrimExpr) -> None: ... def where(cond: PrimExpr) -> None: ... -def allocate(extents, dtype, scope: str, condition: bool = True, annotations=None) -> None: ... +def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ... def launch_thread(env_var, extent): ... -def realize(buffer_slice: BufferSlice, scope: str, condition: bool = True) -> None: ... -def attr(attr_node, attr_key, value) -> None: ... +def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ... +def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... def Assert(condition, message): ... -def let(var, value): ... -def block(name_hint: str = ""): ... -def init(): ... """ Scope handler - Loops """ def serial( - begin: PrimExpr, - end: PrimExpr, + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, -) -> None: ... +) -> Iterable[IterVar]: ... def parallel( - begin: PrimExpr, - end: PrimExpr, + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, -) -> None: ... +) -> Iterable[IterVar]: ... def vectorized( - begin: PrimExpr, - end: PrimExpr, + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, -) -> None: ... +) -> Iterable[IterVar]: ... def unroll( - begin: PrimExpr, - end: PrimExpr, + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, -) -> None: ... +) -> Iterable[IterVar]: ... def thread_binding( - begin: PrimExpr, - end: PrimExpr, + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], thread: str, annotations: Optional[Mapping[str, Object]] = None, -) -> None: ... +) -> Iterable[IterVar]: ... def for_range( - begin: PrimExpr, - end: PrimExpr = None, + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int] = None, annotations: Optional[Mapping[str, Object]] = None, -) -> None: ... -def grid(*extents: List[PrimExpr]) -> None: ... +) -> Iterable[IterVar]: ... +def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ... """ -Threads and Bindings +ty """ - -def env_thread(thread: str) -> IterVar: ... -def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... - -""" -Annotations -""" - -def func_attr(attrs: Dict) -> None: ... -def block_attr(attrs: Dict) -> None: ... -def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... -def prim_func(input_func: Callable) -> PrimFunc: ... +boolean = ConcreteType("bool") +handle = ConcreteType("handle") \ No newline at end of file diff --git a/python/tvm/script/tir/axis.py b/python/tvm/script/tir/axis.py new file mode 100644 index 000000000000..becec8d68517 --- /dev/null +++ b/python/tvm/script/tir/axis.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin \ No newline at end of file diff --git a/python/tvm/script/tir/axis.pyi b/python/tvm/script/tir/axis.pyi new file mode 100644 index 000000000000..0d6a5239406a --- /dev/null +++ b/python/tvm/script/tir/axis.pyi @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +from typing import Tuple, Union, List +from tvm.tir import PrimExpr, IterVar, Var + +def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], iter_value) -> IterVar: ... +def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... +def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ... \ No newline at end of file From f52c252b0490880317ae890c370859824f1fc0b9 Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 3 Nov 2021 13:40:37 -0700 Subject: [PATCH 10/32] address comments --- python/tvm/script/tir/__init__.pyi | 57 ++++++++++---------- python/tvm/script/tir/axis.py | 17 ------ python/tvm/script/tir/axis.pyi | 27 ---------- tests/python/unittest/test_tvmscript_type.py | 1 - 4 files changed, 29 insertions(+), 73 deletions(-) delete mode 100644 python/tvm/script/tir/axis.py delete mode 100644 python/tvm/script/tir/axis.pyi diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index fb93674ce27b..2f3523e2af49 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -28,12 +28,12 @@ from typing import ( List, Mapping, ) +from numbers import Number from tvm.tir.function import PrimFunc from tvm.tir import PrimExpr, Range, IterVar, Var from tvm.runtime import Object from .node import BufferSlice -from . import axis from .ty import ConcreteType """ @@ -51,19 +51,19 @@ class Ptr: ... Variables and constants """ -def bool(imm: int) -> PrimExpr: ... -def int8(imm: int) -> PrimExpr: ... -def int16(imm: int) -> PrimExpr: ... -def int32(imm: int) -> PrimExpr: ... -def int64(imm: int) -> PrimExpr: ... -def uint8(imm: int) -> PrimExpr: ... -def uint16(imm: int) -> PrimExpr: ... -def uint32(imm: int) -> PrimExpr: ... -def uint64(imm: int) -> PrimExpr: ... -def float8(imm: int) -> PrimExpr: ... -def float16(imm: int) -> PrimExpr: ... -def float32(imm: int) -> PrimExpr: ... -def float64(imm: int) -> PrimExpr: ... +def bool(imm) -> PrimExpr: ... +def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... """ Intrinsic @@ -115,15 +115,6 @@ def atan2(x: PrimExpr) -> PrimExpr: ... def sqrt(x: PrimExpr) -> PrimExpr: ... def rsqrt(x: PrimExpr) -> PrimExpr: ... -""" -Axis -""" - -def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... -def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... -def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... -def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ... - """ special_stmt - Buffers """ @@ -175,11 +166,21 @@ def block_attr(attrs: Mapping[str, Object]) -> None: ... special_stmt - Axis """ -def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ... +class axis: + @staticmethod + def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + @staticmethod + def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + @staticmethod + def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + @staticmethod + def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + @staticmethod + def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + @staticmethod + def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + @staticmethod + def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ... """ special_stmt - Annotations diff --git a/python/tvm/script/tir/axis.py b/python/tvm/script/tir/axis.py deleted file mode 100644 index becec8d68517..000000000000 --- a/python/tvm/script/tir/axis.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-builtin \ No newline at end of file diff --git a/python/tvm/script/tir/axis.pyi b/python/tvm/script/tir/axis.pyi deleted file mode 100644 index 0d6a5239406a..000000000000 --- a/python/tvm/script/tir/axis.pyi +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-builtin -from typing import Tuple, Union, List -from tvm.tir import PrimExpr, IterVar, Var - -def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], iter_value) -> IterVar: ... -def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... -def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ... \ No newline at end of file diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index b97252e93294..97c9a5a78180 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -29,7 +29,6 @@ """ -@pytest.mark.mypy_testing @tvm.script.ir_module class Module: @T.prim_func From 2424d0ff987c3e06b7768f862a3bb0c86769a7f2 Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 3 Nov 2021 13:45:25 -0700 Subject: [PATCH 11/32] redefine ty types --- python/tvm/script/tir/__init__.pyi | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 2f3523e2af49..022bd5a83fc5 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -45,8 +45,6 @@ class Buffer(Var): @property def data(self: Buffer) -> Ptr: ... -class Ptr: ... - """ Variables and constants """ @@ -256,7 +254,18 @@ def for_range( def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ... """ -ty +ty - redefine types """ -boolean = ConcreteType("bool") -handle = ConcreteType("handle") \ No newline at end of file + +class boolean: ... +class handle: ... +class int8: ... +class int16: ... +class int32: ... +class int64: ... +class float16: ... +class float32: ... +class float64: ... +class int16: ... +class Ptr: ... +class Tuple: ... \ No newline at end of file From 82e5ed4e4d6e4f1d22ee8fae89257758eca215aa Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 3 Nov 2021 14:44:12 -0700 Subject: [PATCH 12/32] lint --- python/tvm/script/tir/__init__.pyi | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 022bd5a83fc5..c1cfe8e6df79 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -34,7 +34,6 @@ from tvm.tir.function import PrimFunc from tvm.tir import PrimExpr, Range, IterVar, Var from tvm.runtime import Object from .node import BufferSlice -from .ty import ConcreteType """ redefine types @@ -268,4 +267,4 @@ class float32: ... class float64: ... class int16: ... class Ptr: ... -class Tuple: ... \ No newline at end of file +class Tuple: ... From bdb40c484ee774b6a885630ec060a963a8f6e378 Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 3 Nov 2021 15:55:48 -0700 Subject: [PATCH 13/32] address comments --- mypy.ini | 8 -------- tests/scripts/task_mypy.sh | 3 +++ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/mypy.ini b/mypy.ini index 50a4c5e31820..02564a85469e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,14 +23,6 @@ follow_imports = skip ignore_errors = False strict_optional = False -# -# Note: not all tests under .tests/ are typed -# Therefore include test files that should be -# checked by mypy here -# -files = - tests/python/unittest/test_tvmscript_type.py - [mypy-python.tvm.auto_scheduler.*] ignore_errors = True diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index aba4663d5931..081b4c305293 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -32,6 +32,9 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ +echo "Checking MyPy Type defs in the TIR package with unittest" +mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py + #TODO(@mikepapadim): This is failing atm # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." # mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ From cc4f319b7fafb6b47f853da65d5d906280f79c27 Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 3 Nov 2021 23:53:38 -0700 Subject: [PATCH 14/32] address comments --- python/tvm/script/tir/__init__.pyi | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index c1cfe8e6df79..6eca4fe60046 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -29,6 +29,7 @@ from typing import ( Mapping, ) from numbers import Number +import builtins from tvm.tir.function import PrimFunc from tvm.tir import PrimExpr, Range, IterVar, Var @@ -48,7 +49,7 @@ class Buffer(Var): Variables and constants """ -def bool(imm) -> PrimExpr: ... +def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ... def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... @@ -209,9 +210,11 @@ class let(ContextManager): def __init__(self, var: Var, value: PrimExpr) -> None: ... def where(cond: PrimExpr) -> None: ... -def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ... +def allocate( + extents, dtype, scope: str, condition: builtins.bool = True, annotations=None +) -> None: ... def launch_thread(env_var, extent): ... -def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ... +def realize(buffer_slice: BufferSlice, scope: str, condition: builtins.bool = True) -> None: ... def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... def Assert(condition, message): ... From 55f548a1a19a9188f751803e442a2fca3b3d3fdf Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 01:15:36 -0700 Subject: [PATCH 15/32] fix ci --- python/tvm/script/tir/__init__.pyi | 12 ++---------- tests/python/unittest/test_tvmscript_type.py | 1 - tests/scripts/task_mypy.sh | 3 ++- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 6eca4fe60046..59a8ffa73fc0 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -27,6 +27,7 @@ from typing import ( Sequence, List, Mapping, + overload, ) from numbers import Number import builtins @@ -178,7 +179,7 @@ class axis: @staticmethod def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... @staticmethod - def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ... + def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... """ special_stmt - Annotations @@ -261,13 +262,4 @@ ty - redefine types class boolean: ... class handle: ... -class int8: ... -class int16: ... -class int32: ... -class int64: ... -class float16: ... -class float32: ... -class float64: ... -class int16: ... class Ptr: ... -class Tuple: ... diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 97c9a5a78180..861f469dac62 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -18,7 +18,6 @@ import sys import pytest import tvm -from tvm import tir from tvm.script import tir as T """ diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 081b4c305293..0e20fc22cfb2 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -19,6 +19,7 @@ set -e set -u set -o pipefail +source tests/scripts/setup-pytest-env.sh echo "Checking MyPy Type defs in the TensorIR schedule package." mypy --check-untyped-defs python/tvm/tir/schedule @@ -33,7 +34,7 @@ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ echo "Checking MyPy Type defs in the TIR package with unittest" -mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py +MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py #TODO(@mikepapadim): This is failing atm # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." From 3ea3edd6ee34268dcfb8ddf65e819a7af970b344 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 01:29:21 -0700 Subject: [PATCH 16/32] add test cases --- python/tvm/script/tir/__init__.pyi | 2 +- tests/python/unittest/test_tvmscript_type.py | 84 +++++++++++++------- 2 files changed, 58 insertions(+), 28 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 59a8ffa73fc0..309333d42964 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -214,7 +214,7 @@ def where(cond: PrimExpr) -> None: ... def allocate( extents, dtype, scope: str, condition: builtins.bool = True, annotations=None ) -> None: ... -def launch_thread(env_var, extent): ... +def launch_thread(env_var, extent) -> None: ... def realize(buffer_slice: BufferSlice, scope: str, condition: builtins.bool = True) -> None: ... def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... def Assert(condition, message): ... diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 861f469dac62..7ff632444c68 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -21,38 +21,68 @@ from tvm.script import tir as T """ -This module tests the type of +This prim_func tests the type of T.prim_func, T.handle, T.match_buffer, T.block T.reads, T.writes, T.alloc_buffer, T.serial -T.block_attr, T.float32 +T.block_attr, T.float32, T.axis.remap """ -@tvm.script.ir_module -class Module: - @T.prim_func - def element_wise_storage_align(a: T.handle, c: T.handle) -> None: - C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) - A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) - # body - with T.block("root"): - T.reads([]) - T.writes([]) - B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) - for i0 in T.serial(0, 128): - for ax1 in T.serial(0, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i0, ax1]) - T.reads([A[vi, vj]]) - T.writes([B[vi, vj]]) - T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) - B[vi, vj] = A[vi, vj] * T.float32(2) - for i1 in T.serial(0, 128): - with T.block("C"): - vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) - T.reads([B[vi_1, vj_1]]) - T.writes([C[vi_1, vj_1]]) - C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) +@T.prim_func +def element_wise_storage_align(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i1 in T.serial(0, 128): + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) + + +""" +This prim_func tests the type of +T.env_thread, T.launch_thread, T.thread_binding +""" + + +@T.prim_func +def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + j1_0 = T.env_thread("threadIdx.x") + j0_0 = T.env_thread("threadIdx.x") + i = T.env_thread("blockIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + T.launch_thread(i, 128) + T.launch_thread(j0_0, 4) + T.launch_thread(j1_0, 4) + + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) if __name__ == "__main__": From ddf594bf53198df26efbf622db38f1dd965c39ee Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 10:47:26 -0700 Subject: [PATCH 17/32] fix CI --- python/tvm/script/tir/__init__.pyi | 4 +++- tests/python/unittest/test_tvmscript_type.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 309333d42964..a374d1764ebe 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -42,7 +42,9 @@ redefine types """ class Buffer(Var): - def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ... + def __getitem__( + self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]] + ) -> Buffer: ... @property def data(self: Buffer) -> Ptr: ... diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 7ff632444c68..3c5702ee7249 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -16,8 +16,6 @@ # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring import sys -import pytest -import tvm from tvm.script import tir as T """ @@ -85,5 +83,6 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: ) +# Not running any test as we only want to type-check here if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + pass From 5b8ff23db753b1350e7fb653e403dafbd1db8711 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 14:12:56 -0700 Subject: [PATCH 18/32] address comments --- python/tvm/script/tir/__init__.pyi | 61 ++++++++++++-------- tests/python/unittest/test_tvmscript_type.py | 1 - 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index a374d1764ebe..65c1e4c2a20f 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -42,9 +42,7 @@ redefine types """ class Buffer(Var): - def __getitem__( - self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]] - ) -> Buffer: ... + def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ... @property def data(self: Buffer) -> Ptr: ... @@ -70,23 +68,25 @@ def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... Intrinsic """ -def min_value(dtype): ... -def max_value(dtype): ... +def min_value(dtype: str): ... +def max_value(dtype: str): ... def floordiv(x: PrimExpr, y: PrimExpr): ... def floormod(x: PrimExpr, y: PrimExpr): ... -def abs(x): ... -def load(dtype, var, index, predicate=None): ... -def cast(value, dtype): ... -def ramp(base, stride, lanes): ... -def broadcast(value, lanes): ... -def iter_var(var, dom, iter_type, thread_tag): ... -def max(a, b): ... -def min(a, b): ... -def get_axis(begin, end, iter_type): ... -def Select(cond, if_body, else_body): ... -def evaluate(value): ... -def store(var, index, value, predicate=True): ... -def comm_reducer(lambda_io, identities): ... +def abs(x: PrimExpr): ... +def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None): ... +def cast(value: PrimExpr, dtype: str): ... +def ramp(base: PrimExpr, stride: Any, lanes: int): ... +def broadcast(value: PrimExpr, lanes: int): ... +def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str): ... +def max(a: PrimExpr, b: PrimExpr): ... +def min(a: PrimExpr, b: PrimExpr): ... +def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int): ... +def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr): ... +def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... +def evaluate(value: PrimExpr): ... +def reinterpret(value: PrimExpr, dtype: str): ... +def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True): ... +def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]): ... """ Unary operator @@ -187,7 +187,7 @@ class axis: special_stmt - Annotations """ -def buffer_var(dtype, storage_scope) -> IterVar: ... +def buffer_var(dtype: str, storage_scope: str) -> Var: ... def func_attr(attrs: Dict) -> None: ... def prim_func(input_func: Callable) -> PrimFunc: ... @@ -214,12 +214,18 @@ class let(ContextManager): def where(cond: PrimExpr) -> None: ... def allocate( - extents, dtype, scope: str, condition: builtins.bool = True, annotations=None + extents: List[PrimExpr], + dtype: str, + scope: str, + condition: Union[PrimExpr, builtins.bool] = True, + annotations: Optional[Mapping[str, Object]] = None, +) -> Var: ... +def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ... +def realize( + buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True ) -> None: ... -def launch_thread(env_var, extent) -> None: ... -def realize(buffer_slice: BufferSlice, scope: str, condition: builtins.bool = True) -> None: ... def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... -def Assert(condition, message): ... +def Assert(condition: Union[PrimExpr, builtins.bool], message: str): ... """ Scope handler - Loops @@ -256,12 +262,17 @@ def for_range( end: Union[PrimExpr, int] = None, annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... -def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ... +def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... """ ty - redefine types """ class boolean: ... -class handle: ... + +class handle: + def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ... + @property + def data(self: handle) -> Ptr: ... + class Ptr: ... diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 3c5702ee7249..3fdd33dcbe15 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys from tvm.script import tir as T """ From 1a9ed897ff52c90923fbd539f5cf4d24ee4c9c75 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 14:15:29 -0700 Subject: [PATCH 19/32] add types --- python/tvm/script/tir/__init__.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 65c1e4c2a20f..26a0cf56a6af 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -124,7 +124,7 @@ def match_buffer( param: Union[Var, BufferSlice], shape: Sequence[Union[PrimExpr, int]], dtype: str = "float32", - data=None, + data: Var = None, strides: Optional[Sequence[int]] = None, elem_offset: Optional[int] = None, scope: str = "global", @@ -135,7 +135,7 @@ def match_buffer( def buffer_decl( shape: Sequence[Union[PrimExpr, int]], dtype: str = "float32", - data=None, + data: Var = None, strides: Optional[Sequence[int]] = None, elem_offset: Optional[int] = None, scope: str = "global", @@ -146,7 +146,7 @@ def buffer_decl( def alloc_buffer( shape: Sequence[Union[PrimExpr, int]], dtype: str = "float32", - data=None, + data: Var = None, strides: Optional[Sequence[int]] = None, elem_offset: Optional[int] = None, scope: str = "global", From e89ae32d1766b5fe78d82f4c970623f9d5243c32 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 17:02:28 -0700 Subject: [PATCH 20/32] mypy --strict --- python/tvm/script/tir/__init__.pyi | 47 +++++++++++--------- python/tvm/script/tir/ty.py | 3 ++ tests/python/unittest/test_tvmscript_type.py | 2 +- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 26a0cf56a6af..f318ee47a483 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -68,25 +68,26 @@ def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... Intrinsic """ -def min_value(dtype: str): ... -def max_value(dtype: str): ... -def floordiv(x: PrimExpr, y: PrimExpr): ... -def floormod(x: PrimExpr, y: PrimExpr): ... -def abs(x: PrimExpr): ... -def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None): ... -def cast(value: PrimExpr, dtype: str): ... -def ramp(base: PrimExpr, stride: Any, lanes: int): ... -def broadcast(value: PrimExpr, lanes: int): ... -def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str): ... -def max(a: PrimExpr, b: PrimExpr): ... -def min(a: PrimExpr, b: PrimExpr): ... -def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int): ... -def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr): ... +def min_value(dtype: str) -> PrimExpr: ... +def max_value(dtype: str) -> PrimExpr: ... +def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def abs(x: PrimExpr) -> PrimExpr: ... +def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ... +def cast(value: PrimExpr, dtype: str) -> PrimExpr: ... +def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ... +def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ... +def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ... +def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... +def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... +def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ... def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... -def evaluate(value: PrimExpr): ... -def reinterpret(value: PrimExpr, dtype: str): ... -def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True): ... -def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]): ... +def evaluate(value: PrimExpr) -> PrimExpr: ... +def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... +def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ... +def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... """ Unary operator @@ -183,12 +184,14 @@ class axis: @staticmethod def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... +def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ... + """ special_stmt - Annotations """ def buffer_var(dtype: str, storage_scope: str) -> Var: ... -def func_attr(attrs: Dict) -> None: ... +def func_attr(attrs: Mapping[str, Object]) -> None: ... def prim_func(input_func: Callable) -> PrimFunc: ... """ @@ -225,7 +228,7 @@ def realize( buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True ) -> None: ... def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... -def Assert(condition: Union[PrimExpr, builtins.bool], message: str): ... +def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ... """ Scope handler - Loops @@ -275,4 +278,8 @@ class handle: @property def data(self: handle) -> Ptr: ... +# class float32: +# def __new__(self, imm: Union[PrimExpr, Number]) -> PrimExpr: ... +# def __init__(self, imm: Union[PrimExpr, Number]) -> None: ... + class Ptr: ... diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 6a4f7bc00cb6..d72e971beae2 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -30,6 +30,9 @@ def evaluate(self): """Return an actual ir.Type Object that this Generic class wraps""" raise TypeError("Cannot get tvm.Type from a generic type") + def __call__(self): + raise NotImplementedError + class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods """TVM script typing class for uniform Type objects""" diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 3fdd33dcbe15..951599c7df3e 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-function-docstring,missing-module-docstring +# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement from tvm.script import tir as T """ From b6ed67d2afb09c461514f2b31fc4404ac051fdec Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 17:08:51 -0700 Subject: [PATCH 21/32] comments --- python/tvm/script/tir/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index f318ee47a483..253135cf43be 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -42,7 +42,7 @@ redefine types """ class Buffer(Var): - def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ... + def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ... @property def data(self: Buffer) -> Ptr: ... From 77e7839fae7f85438babb37527ebfec847e9b4b7 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 17:13:07 -0700 Subject: [PATCH 22/32] update test comments --- tests/python/unittest/test_tvmscript_type.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 951599c7df3e..5c8fe2a9cff1 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -17,16 +17,13 @@ # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement from tvm.script import tir as T -""" -This prim_func tests the type of -T.prim_func, T.handle, T.match_buffer, T.block -T.reads, T.writes, T.alloc_buffer, T.serial -T.block_attr, T.float32, T.axis.remap -""" - @T.prim_func def element_wise_storage_align(a: T.handle, c: T.handle) -> None: + """ + This prim func include necessary buffer types that need to be checked + e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc. + """ C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body @@ -58,6 +55,10 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: @T.prim_func def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + """ + This prim func include necessary thread types that need to be checked + e.g. env_thread, launch_thread, thread_binding etc. + """ j1_0 = T.env_thread("threadIdx.x") j0_0 = T.env_thread("threadIdx.x") i = T.env_thread("blockIdx.x") From bbcd16b168ea4c6badd2c5d33602c66fc1b74311 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 4 Nov 2021 17:29:24 -0700 Subject: [PATCH 23/32] linting fix --- python/tvm/script/tir/ty.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index d72e971beae2..e3af53e4a5e0 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -34,7 +34,7 @@ def __call__(self): raise NotImplementedError -class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods +class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method """TVM script typing class for uniform Type objects""" def __init__(self, vtype): @@ -44,7 +44,7 @@ def evaluate(self): return tvm.ir.PrimType(self.type) -class GenericPtrType(TypeGeneric): +class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for PtrType [] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType @@ -54,7 +54,7 @@ def __getitem__(self, vtype): return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) -class GenericTupleType(TypeGeneric): +class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for TupleType [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType From 9fda96e6e972d15a9a7fad73eacee2656884df3c Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 5 Nov 2021 14:03:35 -0700 Subject: [PATCH 24/32] address comments --- python/tvm/script/tir/__init__.pyi | 12 ++++++------ python/tvm/script/tir/ty.py | 4 +++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 253135cf43be..254de1d7db85 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -41,8 +41,9 @@ from .node import BufferSlice redefine types """ -class Buffer(Var): - def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ... +class Buffer: + def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ... + def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ... @property def data(self: Buffer) -> Ptr: ... @@ -275,11 +276,10 @@ class boolean: ... class handle: def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ... + def __setitem__( + self: handle, pos: Tuple[Union[Number, PrimExpr, slice]], value: Buffer + ) -> Buffer: ... @property def data(self: handle) -> Ptr: ... -# class float32: -# def __new__(self, imm: Union[PrimExpr, Number]) -> PrimExpr: ... -# def __init__(self, imm: Union[PrimExpr, Number]) -> None: ... - class Ptr: ... diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index e3af53e4a5e0..9140310d4733 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -30,8 +30,10 @@ def evaluate(self): """Return an actual ir.Type Object that this Generic class wraps""" raise TypeError("Cannot get tvm.Type from a generic type") + # This function is added here to avoid a pylint error + # for T.int/float below not being callable def __call__(self): - raise NotImplementedError + raise NotImplementedError() class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method From e24d2219f0e45b263d6743fd3df6dd7ef07c08fe Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 5 Nov 2021 14:11:46 -0700 Subject: [PATCH 25/32] add pylint for tir type check --- tests/lint/pylint.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh index a96c2672c01f..e47d576ced2f 100755 --- a/tests/lint/pylint.sh +++ b/tests/lint/pylint.sh @@ -19,3 +19,4 @@ python3 -m pylint python/tvm --rcfile=$(dirname "$0")/pylintrc python3 -m pylint vta/python/vta --rcfile=$(dirname "$0")/pylintrc +python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile=$(dirname "$0")/pylintrc From d7503a5965d6af9497c8fcb44ef3515e5fdbac31 Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 5 Nov 2021 14:12:25 -0700 Subject: [PATCH 26/32] address comments --- tests/python/unittest/test_tvmscript_type.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 5c8fe2a9cff1..80c80dc13fcb 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -47,12 +47,6 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) -""" -This prim_func tests the type of -T.env_thread, T.launch_thread, T.thread_binding -""" - - @T.prim_func def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: """ From 8f989ca85072418fe30212764bd5db83eb016722 Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 5 Nov 2021 15:23:28 -0700 Subject: [PATCH 27/32] move doc string --- tests/python/unittest/test_tvmscript_type.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 80c80dc13fcb..2c5e2e697947 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -17,13 +17,14 @@ # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement from tvm.script import tir as T +""" +This prim func include necessary buffer types that need to be checked +e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc. +""" + @T.prim_func def element_wise_storage_align(a: T.handle, c: T.handle) -> None: - """ - This prim func include necessary buffer types that need to be checked - e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc. - """ C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body @@ -47,12 +48,14 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) +""" +This prim func include necessary thread types that need to be checked +e.g. env_thread, launch_thread, thread_binding etc. +""" + + @T.prim_func def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: - """ - This prim func include necessary thread types that need to be checked - e.g. env_thread, launch_thread, thread_binding etc. - """ j1_0 = T.env_thread("threadIdx.x") j0_0 = T.env_thread("threadIdx.x") i = T.env_thread("blockIdx.x") From 4331f04d2e4abb07dc9404686e0fe9ea31a45aa4 Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 5 Nov 2021 15:46:33 -0700 Subject: [PATCH 28/32] comments --- python/tvm/script/tir/__init__.pyi | 2 +- tests/python/unittest/test_tvmscript_type.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 254de1d7db85..57a96cb0e928 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -85,7 +85,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ... def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... -def evaluate(value: PrimExpr) -> PrimExpr: ... +def evaluate(value: PrimExpr) -> None: ... def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ... def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 2c5e2e697947..44ea04b5ed36 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -35,7 +35,8 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): with T.block("B"): - vi, vj = T.axis.remap("SS", [i0, ax1]) + vi = T.axis.S(128, i0) + vj = T.axis.S(128, ax1) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) From 872371859b33930fde4f58dd6e483e91e920f680 Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 8 Nov 2021 11:37:19 -0800 Subject: [PATCH 29/32] getter setter --- python/tvm/script/tir/__init__.pyi | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 57a96cb0e928..d2fe5b8ec86c 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -42,8 +42,14 @@ redefine types """ class Buffer: - def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ... - def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ... + @overload + def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ... + @overload + def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... + @overload + def __setitem__(self: Buffer, pos: List[Union[PrimExpr, int]], value: PrimExpr) -> None: ... + @overload + def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... @property def data(self: Buffer) -> Ptr: ... From 05aa7dd2e3ffdfaf853f05c7ed0ecba01ce6c3be Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 8 Nov 2021 12:22:43 -0800 Subject: [PATCH 30/32] add PrimExpr, IterVar and Var --- python/tvm/script/tir/__init__.pyi | 120 ++++++++++++++++++++++------- 1 file changed, 93 insertions(+), 27 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index d2fe5b8ec86c..f97cc7484de3 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -33,7 +33,7 @@ from numbers import Number import builtins from tvm.tir.function import PrimFunc -from tvm.tir import PrimExpr, Range, IterVar, Var +from tvm.tir import Range from tvm.runtime import Object from .node import BufferSlice @@ -41,13 +41,43 @@ from .node import BufferSlice redefine types """ +class PrimExpr: + def __init__(self: PrimExpr) -> None: ... + @overload + def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + +class Var(PrimExpr): ... +class IterVar(Var): ... + class Buffer: @overload - def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ... + def __getitem__( + self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]] + ) -> PrimExpr: ... @overload def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... @overload - def __setitem__(self: Buffer, pos: List[Union[PrimExpr, int]], value: PrimExpr) -> None: ... + def __setitem__( + self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> None: ... @overload def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... @property @@ -57,19 +87,19 @@ class Buffer: Variables and constants """ -def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ... -def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ... -def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ... +def bool(imm: Union[PrimExpr, builtins.bool, builtins.int]) -> PrimExpr: ... +def int8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int64(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint64(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float64(imm: Union[PrimExpr, int]) -> PrimExpr: ... """ Intrinsic @@ -82,7 +112,9 @@ def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... def abs(x: PrimExpr) -> PrimExpr: ... -def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ... +def load( + dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None +) -> PrimExpr: ... def cast(value: PrimExpr, dtype: str) -> PrimExpr: ... def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ... def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ... @@ -93,7 +125,9 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... def evaluate(value: PrimExpr) -> None: ... def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... -def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ... +def store( + var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True +) -> None: ... def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... """ @@ -176,18 +210,50 @@ special_stmt - Axis """ class axis: + @overload + @staticmethod + def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload @staticmethod - def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + def spatial( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... + @overload @staticmethod - def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload @staticmethod - def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + def reduce( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload @staticmethod - def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload @staticmethod - def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def scan( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload @staticmethod - def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ... + def opaque( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... @staticmethod def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... @@ -230,7 +296,7 @@ def allocate( condition: Union[PrimExpr, builtins.bool] = True, annotations: Optional[Mapping[str, Object]] = None, ) -> Var: ... -def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ... +def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... def realize( buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True ) -> None: ... @@ -281,9 +347,9 @@ ty - redefine types class boolean: ... class handle: - def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ... + def __getitem__(self: handle, pos: Tuple[Union[int, PrimExpr, slice]]) -> Buffer: ... def __setitem__( - self: handle, pos: Tuple[Union[Number, PrimExpr, slice]], value: Buffer + self: handle, pos: Tuple[Union[int, PrimExpr, slice]], value: Buffer ) -> Buffer: ... @property def data(self: handle) -> Ptr: ... From 543caecd891c6eaaffb869ea07cdd53eb7bad5aa Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 8 Nov 2021 12:54:01 -0800 Subject: [PATCH 31/32] add sequence --- python/tvm/script/tir/__init__.pyi | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index f97cc7484de3..8370b8282592 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -69,15 +69,11 @@ class IterVar(Var): ... class Buffer: @overload - def __getitem__( - self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]] - ) -> PrimExpr: ... + def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ... @overload def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... @overload - def __setitem__( - self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> None: ... + def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ... @overload def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... @property From 8fbbc1b5cb85bf5a4baeab32ecb4f1300444d188 Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 8 Nov 2021 13:33:38 -0800 Subject: [PATCH 32/32] change for handle --- python/tvm/script/tir/__init__.pyi | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 8370b8282592..fba026d414f6 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -343,10 +343,16 @@ ty - redefine types class boolean: ... class handle: - def __getitem__(self: handle, pos: Tuple[Union[int, PrimExpr, slice]]) -> Buffer: ... + @overload + def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ... + @overload + def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ... + @overload def __setitem__( - self: handle, pos: Tuple[Union[int, PrimExpr, slice]], value: Buffer - ) -> Buffer: ... + self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer + ) -> None: ... + @overload + def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ... @property def data(self: handle) -> Ptr: ...