diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi new file mode 100644 index 000000000000..fba026d414f6 --- /dev/null +++ b/python/tvm/script/tir/__init__.pyi @@ -0,0 +1,359 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Iterable, + Optional, + Tuple, + Union, + Sequence, + List, + Mapping, + overload, +) +from numbers import Number +import builtins + +from tvm.tir.function import PrimFunc +from tvm.tir import Range +from tvm.runtime import Object +from .node import BufferSlice + +""" +redefine types +""" + +class PrimExpr: + def __init__(self: PrimExpr) -> None: ... + @overload + def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + +class Var(PrimExpr): ... +class IterVar(Var): ... + +class Buffer: + @overload + def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ... + @overload + def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... + @overload + def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ... + @overload + def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... + @property + def data(self: Buffer) -> Ptr: ... + +""" +Variables and constants +""" + +def bool(imm: Union[PrimExpr, builtins.bool, builtins.int]) -> PrimExpr: ... +def int8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int64(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint64(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float64(imm: Union[PrimExpr, int]) -> PrimExpr: ... + +""" +Intrinsic +""" + +def min_value(dtype: str) -> PrimExpr: ... +def max_value(dtype: str) -> PrimExpr: ... +def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def abs(x: PrimExpr) -> PrimExpr: ... +def load( + dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None +) -> PrimExpr: ... +def cast(value: PrimExpr, dtype: str) -> PrimExpr: ... +def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ... +def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ... +def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ... +def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... +def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... +def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ... +def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... +def evaluate(value: PrimExpr) -> None: ... +def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... +def store( + var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True +) -> None: ... +def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... + +""" +Unary operator +""" + +def exp2(x: PrimExpr) -> PrimExpr: ... +def exp10(x: PrimExpr) -> PrimExpr: ... +def erf(x: PrimExpr) -> PrimExpr: ... +def tanh(x: PrimExpr) -> PrimExpr: ... +def sigmoid(x: PrimExpr) -> PrimExpr: ... +def log(x: PrimExpr) -> PrimExpr: ... +def log2(x: PrimExpr) -> PrimExpr: ... +def log10(x: PrimExpr) -> PrimExpr: ... +def log1p(x: PrimExpr) -> PrimExpr: ... +def tan(x: PrimExpr) -> PrimExpr: ... +def cos(x: PrimExpr) -> PrimExpr: ... +def cosh(x: PrimExpr) -> PrimExpr: ... +def acos(x: PrimExpr) -> PrimExpr: ... +def acosh(x: PrimExpr) -> PrimExpr: ... +def sin(x: PrimExpr) -> PrimExpr: ... +def sinh(x: PrimExpr) -> PrimExpr: ... +def asin(x: PrimExpr) -> PrimExpr: ... +def asinh(x: PrimExpr) -> PrimExpr: ... +def atan(x: PrimExpr) -> PrimExpr: ... +def atanh(x: PrimExpr) -> PrimExpr: ... +def atan2(x: PrimExpr) -> PrimExpr: ... +def sqrt(x: PrimExpr) -> PrimExpr: ... +def rsqrt(x: PrimExpr) -> PrimExpr: ... + +""" +special_stmt - Buffers +""" + +def match_buffer( + param: Union[Var, BufferSlice], + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", +) -> Buffer: ... +def buffer_decl( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", +) -> Buffer: ... +def alloc_buffer( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", +) -> Buffer: ... + +""" +special_stmt - Reads/Writes +""" + +def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def block_attr(attrs: Mapping[str, Object]) -> None: ... + +""" +special_stmt - Axis +""" + +class axis: + @overload + @staticmethod + def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def spatial( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def reduce( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def scan( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def opaque( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @staticmethod + def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... + +def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ... + +""" +special_stmt - Annotations +""" + +def buffer_var(dtype: str, storage_scope: str) -> Var: ... +def func_attr(attrs: Mapping[str, Object]) -> None: ... +def prim_func(input_func: Callable) -> PrimFunc: ... + +""" +special_stmt - Threads and Bindings +""" + +def env_thread(env_name: str) -> IterVar: ... +def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... + +""" +Scope handler +""" + +class block(ContextManager): + def __init__(self, name_hint: str = "") -> None: ... + def __enter__(self) -> Sequence[IterVar]: ... + +class init(ContextManager): + def __init__(self) -> None: ... + +class let(ContextManager): + def __init__(self, var: Var, value: PrimExpr) -> None: ... + +def where(cond: PrimExpr) -> None: ... +def allocate( + extents: List[PrimExpr], + dtype: str, + scope: str, + condition: Union[PrimExpr, builtins.bool] = True, + annotations: Optional[Mapping[str, Object]] = None, +) -> Var: ... +def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... +def realize( + buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True +) -> None: ... +def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... +def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ... + +""" +Scope handler - Loops +""" + +def serial( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def parallel( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def vectorized( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def unroll( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def thread_binding( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + thread: str, + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def for_range( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int] = None, + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... + +""" +ty - redefine types +""" + +class boolean: ... + +class handle: + @overload + def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ... + @overload + def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ... + @overload + def __setitem__( + self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer + ) -> None: ... + @overload + def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ... + @property + def data(self: handle) -> Ptr: ... + +class Ptr: ... diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 6a4f7bc00cb6..9140310d4733 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -30,8 +30,13 @@ def evaluate(self): """Return an actual ir.Type Object that this Generic class wraps""" raise TypeError("Cannot get tvm.Type from a generic type") + # This function is added here to avoid a pylint error + # for T.int/float below not being callable + def __call__(self): + raise NotImplementedError() -class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods + +class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method """TVM script typing class for uniform Type objects""" def __init__(self, vtype): @@ -41,7 +46,7 @@ def evaluate(self): return tvm.ir.PrimType(self.type) -class GenericPtrType(TypeGeneric): +class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for PtrType [] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType @@ -51,7 +56,7 @@ def __getitem__(self, vtype): return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) -class GenericTupleType(TypeGeneric): +class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for TupleType [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 66ede31f4103..04cbffcd4d87 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -17,18 +17,17 @@ """Testing utilities for the TensorIR schedule API""" from typing import Union -from tvm import tir from tvm.ir import IRModule, structural_equal from tvm.tir import PrimFunc -from tvm.tir.schedule import Trace +from tvm.tir.schedule import Trace, Schedule def verify_trace_roundtrip( - sch: tir.Schedule, + sch: Schedule, mod: Union[PrimFunc, IRModule], *, debug_mask: Union[str, int] = "all", -) -> tir.Schedule: +) -> Schedule: """Serialize a traced schedule to JSON, then replay the JSON trace by applying to a fresh new schedule, verifying the reproducibility of scheduling. @@ -51,7 +50,7 @@ def verify_trace_roundtrip( assert trace is not None json_obj = trace.as_json() # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling - new_sch = tir.Schedule(mod=mod, debug_mask=debug_mask) + new_sch = Schedule(mod=mod, debug_mask=debug_mask) Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) assert structural_equal(new_sch.mod, sch.mod) # Step 3. Check the consistency of the text format between the old and new traces diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh index a96c2672c01f..e47d576ced2f 100755 --- a/tests/lint/pylint.sh +++ b/tests/lint/pylint.sh @@ -19,3 +19,4 @@ python3 -m pylint python/tvm --rcfile=$(dirname "$0")/pylintrc python3 -m pylint vta/python/vta --rcfile=$(dirname "$0")/pylintrc +python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile=$(dirname "$0")/pylintrc diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py new file mode 100644 index 000000000000..44ea04b5ed36 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_type.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement +from tvm.script import tir as T + +""" +This prim func include necessary buffer types that need to be checked +e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc. +""" + + +@T.prim_func +def element_wise_storage_align(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block("B"): + vi = T.axis.S(128, i0) + vj = T.axis.S(128, ax1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i1 in T.serial(0, 128): + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) + + +""" +This prim func include necessary thread types that need to be checked +e.g. env_thread, launch_thread, thread_binding etc. +""" + + +@T.prim_func +def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + j1_0 = T.env_thread("threadIdx.x") + j0_0 = T.env_thread("threadIdx.x") + i = T.env_thread("blockIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + T.launch_thread(i, 128) + T.launch_thread(j0_0, 4) + T.launch_thread(j1_0, 4) + + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) + + +# Not running any test as we only want to type-check here +if __name__ == "__main__": + pass diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index aba4663d5931..0e20fc22cfb2 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -19,6 +19,7 @@ set -e set -u set -o pipefail +source tests/scripts/setup-pytest-env.sh echo "Checking MyPy Type defs in the TensorIR schedule package." mypy --check-untyped-defs python/tvm/tir/schedule @@ -32,6 +33,9 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ +echo "Checking MyPy Type defs in the TIR package with unittest" +MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py + #TODO(@mikepapadim): This is failing atm # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." # mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/