From f8d1835bd46a8638fa043fbd393ceeade1fc1f61 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 25 Sep 2022 18:41:43 -0700 Subject: [PATCH 1/2] [TVMScript] Import TIR methods into the IRBuilder This PR introduces remaining TIR methods into IRBuilder Co-authored-by: yongwww --- include/tvm/script/ir_builder/tir/ir.h | 8 + python/tvm/script/ir_builder/tir/ir.py | 396 +++++++++++++++++- src/script/ir_builder/tir/ir.cc | 11 + .../unittest/test_tvmscript_ir_builder_tir.py | 15 + 4 files changed, 428 insertions(+), 2 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index dd289b691502..7460099f9448 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -435,6 +435,14 @@ void Prefetch(Buffer buffer, Array bounds); */ void Evaluate(PrimExpr value); +/*! + * \brief The pointer declaration function. + * \param dtype The data type of the pointer. + * \param storage_scope The storage scope of the pointer. + * \return The pointer. + */ +PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global"); + #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ inline PrimExpr FuncName(Optional expr = NullOpt) { \ DataType dtype = DType; \ diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 625e1291ff20..3dba15c63d9b 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -17,24 +17,35 @@ # pylint: disable=missing-docstring """IRBuilder for TIR""" +import inspect +import functools from numbers import Integral -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import numpy as np # type: ignore from tvm.ir import Range, Type from tvm.runtime import convert, ndarray +from tvm.target.codegen import llvm_lookup_intrinsic_id from tvm.tir import ( Buffer, BufferLoad, BufferRegion, + Cast, + CommReducer, IntImm, IterVar, Let, PrimExpr, + Select, + Shuffle, StringImm, + type_annotation, Var, ) +from tvm.tir import Broadcast as broadcast from tvm.tir import Ramp as ramp +from tvm.tir import op as _tir_op +from tvm.tir.generic import cast from . import _ffi_api, frame @@ -1501,7 +1512,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr: return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint: disable=no-member -def var(dtype, name="") -> Var: +def var(dtype: str, name: str = "") -> Var: """Construct a new tir.Var. Parameters @@ -1520,6 +1531,268 @@ def var(dtype, name="") -> Var: return Var(name, dtype) # pylint: disable=no-member +def ptr(dtype: str, storage_scope: str = "global") -> Var: + """The pointer declaration function. + + Parameters + ---------- + dtype : str + The data type of the pointer. + + storage_scope : str + The storage scope of the pointer. + + Returns + ------- + res : Var + The pointer. + """ + return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member + + +def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin + """Compute the minimum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api.min(a, b) # type: ignore[attr-defined] # pylint: disable=no-member + + +def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin + """Compute the maximum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member + + +def iter_var(v: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: + """The iteration variable. + + Parameters + ---------- + var : Union[Var, str] + The internal variable that is used for iteration. + + dom : Range + The domain of the iteration. + + iter_type : int + The iteration type. + + thread_tag : str + The thread type tag. + + Returns + ------- + res : IterVar + The iteration variable. + """ + iter_type = getattr(IterVar, iter_type) + return IterVar(dom, v, iter_type, thread_tag) + + +def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: + """ + Create a CommReducer from lambda inputs/outputs and the identities + + Parameters + ---------- + combiner : Callable + A binary function which takes two PrimExpr as input to return a PrimExpr. + + identity : List[PrimExpr] + A list of types of output PrimExpr. + + Returns + ------- + res : CommReducer + The CommReducer. + """ + params = inspect.signature(combiner).parameters + num_args = len(params) + args = [] + for name, i in zip(params.keys(), identity + identity): + if isinstance(i, int): + args.append(Var(name, "int32")) + else: + args.append(Var(name, i.dtype)) + res = combiner(*args) + if not isinstance(res, tuple): + res = (res,) + return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) + + +def _op_wrapper(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped + + +def _dtype_forward(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + args = (kwargs.pop("dtype"),) + args + return func(*args, **kwargs) + + return wrapped + + +# pylint: disable=invalid-name + +buffer_var = ptr +abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin +fabs = abs +acos = _op_wrapper(_tir_op.acos) +acosh = _op_wrapper(_tir_op.acosh) +address_of = _op_wrapper(_tir_op.address_of) +asin = _op_wrapper(_tir_op.asin) +asinh = _op_wrapper(_tir_op.asinh) +atan = _op_wrapper(_tir_op.atan) +atan2 = _op_wrapper(_tir_op.atan2) +atanh = _op_wrapper(_tir_op.atanh) +ceil = _op_wrapper(_tir_op.ceil) +clz = _op_wrapper(_tir_op.clz) +copysign = _op_wrapper(_tir_op.copysign) +cos = _op_wrapper(_tir_op.cos) +cosh = _op_wrapper(_tir_op.cosh) +erf = _op_wrapper(_tir_op.erf) +exp = _op_wrapper(_tir_op.exp) +exp2 = _op_wrapper(_tir_op.exp2) +exp10 = _op_wrapper(_tir_op.exp10) +floor = _op_wrapper(_tir_op.floor) +ceildiv = _op_wrapper(_tir_op.ceildiv) +floordiv = _op_wrapper(_tir_op.floordiv) +floormod = _op_wrapper(_tir_op.floormod) +fmod = _op_wrapper(_tir_op.fmod) +hypot = _op_wrapper(_tir_op.hypot) +if_then_else = _op_wrapper(_tir_op.if_then_else) +infinity = _op_wrapper(_tir_op.infinity) +isfinite = _op_wrapper(_tir_op.isfinite) +isinf = _op_wrapper(_tir_op.isinf) +isnan = _op_wrapper(_tir_op.isnan) +isnullptr = _op_wrapper(_tir_op.isnullptr) +ldexp = _op_wrapper(_tir_op.ldexp) +likely = _op_wrapper(_tir_op.likely) +log = _op_wrapper(_tir_op.log) +log1p = _op_wrapper(_tir_op.log1p) +log2 = _op_wrapper(_tir_op.log2) +log10 = _op_wrapper(_tir_op.log10) +lookup_param = _op_wrapper(_tir_op.lookup_param) +max_value = _op_wrapper(_tir_op.max_value) +min_value = _op_wrapper(_tir_op.min_value) +nearbyint = _op_wrapper(_tir_op.nearbyint) +nextafter = _op_wrapper(_tir_op.nextafter) +popcount = _op_wrapper(_tir_op.popcount) +power = _op_wrapper(_tir_op.power) +q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) +ret = _op_wrapper(_tir_op.ret) +reinterpret = _dtype_forward(_tir_op.reinterpret) +round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin +rsqrt = _op_wrapper(_tir_op.rsqrt) +shift_left = _op_wrapper(_tir_op.shift_left) +shift_right = _op_wrapper(_tir_op.shift_right) +sigmoid = _op_wrapper(_tir_op.sigmoid) +sin = _op_wrapper(_tir_op.sin) +sinh = _op_wrapper(_tir_op.sinh) +sqrt = _op_wrapper(_tir_op.sqrt) +tan = _op_wrapper(_tir_op.tan) +tanh = _op_wrapper(_tir_op.tanh) +trunc = _op_wrapper(_tir_op.trunc) +truncdiv = _op_wrapper(_tir_op.truncdiv) +truncmod = _op_wrapper(_tir_op.truncmod) +tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) +tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) +tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) +tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) +call_packed = _op_wrapper(_tir_op.call_packed) +call_cpacked = _op_wrapper(_tir_op.call_cpacked) +call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) +call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered) +call_extern = _dtype_forward(_tir_op.call_extern) +call_intrin = _dtype_forward(_tir_op.call_intrin) +call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) +call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) +call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) +tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +tvm_tuple = _op_wrapper(_tir_op.tvm_tuple) +tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set) +tvm_struct_get = _tir_op.tvm_struct_get +tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce) +tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync) +tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync) +tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) +tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) +tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) +ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) +ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) +mma_store = _dtype_forward(_tir_op.mma_store) +mma_fill = _dtype_forward(_tir_op.mma_fill) +vectorlow = _dtype_forward(_tir_op.vectorlow) +vectorhigh = _dtype_forward(_tir_op.vectorhigh) +vectorcombine = _dtype_forward(_tir_op.vectorcombine) +assume = _op_wrapper(_tir_op.assume) +undef = _op_wrapper(_tir_op.undef) +tvm_call_packed = call_packed +tvm_call_cpacked = call_cpacked +tvm_call_packed_lowered = call_packed_lowered +tvm_call_cpacked_lowered = call_cpacked_lowered +TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) +TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) + + +class inline: + """Inline function for meta-programming. + + Parameters + ---------- + value: Any + The value to be inlined. + """ + + def __init__(self, value: Any) -> None: + self.value = value + + def __iter__(self): + def f(): + for i in self.value: + yield inline(i) + + return f() + + # pylint: enable=invalid-name @@ -1581,4 +1854,123 @@ def var(dtype, name="") -> Var: "handle", "void", "var", + "ptr", + "min", + "max", + "iter_var", + "comm_reducer", + "buffer_var", + "abs", + "fabs", + "acos", + "acosh", + "address_of", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "ceil", + "clz", + "copysign", + "cos", + "cosh", + "erf", + "exp", + "exp2", + "exp10", + "floor", + "ceildiv", + "floordiv", + "floormod", + "fmod", + "hypot", + "if_then_else", + "infinity", + "isfinite", + "isinf", + "isnan", + "isnullptr", + "ldexp", + "likely", + "log", + "log1p", + "log2", + "log10", + "lookup_param", + "max_value", + "min_value", + "nearbyint", + "nextafter", + "popcount", + "power", + "q_multiply_shift", + "ret", + "reinterpret", + "round", + "rsqrt", + "shift_left", + "shift_right", + "sigmoid", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "trunc", + "truncdiv", + "truncmod", + "tvm_access_ptr", + "tvm_throw_last_error", + "tvm_stack_alloca", + "tvm_stack_make_shape", + "tvm_stack_make_array", + "call_packed", + "call_cpacked", + "call_packed_lowered", + "call_cpacked_lowered", + "call_extern", + "call_intrin", + "call_llvm_intrin", + "call_llvm_pure_intrin", + "call_pure_extern", + "tvm_access_ptr", + "tvm_tuple", + "tvm_struct_set", + "tvm_struct_get", + "tvm_thread_allreduce", + "tvm_load_matrix_sync", + "tvm_mma_sync", + "tvm_bmma_sync", + "tvm_fill_fragment", + "tvm_store_matrix_sync", + "ptx_mma", + "ptx_mma_sp", + "ptx_ldmatrix", + "ptx_cp_async", + "ptx_wait_group", + "ptx_commit_group", + "mma_store", + "mma_fill", + "vectorlow", + "vectorhigh", + "vectorcombine", + "assume", + "undef", + "tvm_call_packed", + "tvm_call_cpacked", + "tvm_call_packed_lowered", + "tvm_call_cpacked_lowered", + "TVMBackendAllocWorkspace", + "TVMBackendFreeWorkspace", + "inline", + "llvm_lookup_intrinsic_id", + "Cast", + "Let", + "Select", + "Shuffle", + "type_annotation", + "broadcast", + "ramp", + "cast", ] diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 28c3d69861fa..6be6e2619fea 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -534,6 +534,10 @@ DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_ void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } +PrimExpr Ptr(runtime::DataType dtype, String storage_scope) { + return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope)); +} + using tvm::script::ir_builder::details::Namer; TVM_STATIC_IR_FUNCTOR(Namer, vtable) @@ -632,6 +636,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferSt TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32); @@ -650,6 +656,11 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.min") + .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.max") + .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); } // namespace tir } // namespace ir_builder } // namespace script diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 40e13a2fbe2f..dbc9b594fb87 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -476,5 +476,20 @@ def test_ir_builder_tir_decl_buffer(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) +def test_ir_builder_tir_inline(): + with IRBuilder() as ib: + m, n = T.inline(1), T.inline(2) + a, b = T.inline([3, 4]) + T.evaluate(m.value + n.value + a.value + b.value) + # the evaluate generated by IRBuilder + eval_actual = ib.get() + + # the expected evaluate + eval_expected = tir.Evaluate(10) + + # Check if the generated ir is expected + assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main() From 033078e9ae9744b321cc12ff446932e7f121b3c5 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 25 Sep 2022 18:49:49 -0700 Subject: [PATCH 2/2] fix type annotation --- python/tvm/script/ir_builder/tir/ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 3dba15c63d9b..4ec1511f2907 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1588,7 +1588,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member -def iter_var(v: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: +def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar: """The iteration variable. Parameters @@ -1599,7 +1599,7 @@ def iter_var(v: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> dom : Range The domain of the iteration. - iter_type : int + iter_type : str The iteration type. thread_tag : str