diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 82b1089ac197..7ea8c02bed85 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -51,6 +51,7 @@
 from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
 from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
 from .op import address_of, lookup_param, assume, undef
+from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
 from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
 from .op import sin, sinh, asin, asinh
@@ -62,6 +63,7 @@
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
 from .op import comm_reducer, min, max, sum
 from .op import q_multiply_shift
+from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
 
 from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
 
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 3e8dc529357d..19ce4f4bc10b 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=redefined-builtin, invalid-name
 """Operators used in TIR expression."""
+import warnings
 from typing import Any, Optional
 import tvm._ffi
 from tvm.ir.base import Span
@@ -262,10 +263,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -294,8 +307,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
@@ -504,6 +525,77 @@ def lookup_param(param_name, span=None):
     return call_intrin("handle", "tir.lookup_param", param_name, span=span)
 
 
+def tvm_thread_allreduce(*freduce_args):
+    """Create a call to the ``tir.tvm_thread_allreduce`` intrinsic.
+
+    Parameters
+    ----------
+    freduce_args : Expr
+        The arguments, forwarded unchanged to the intrinsic call.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def type_annotation(dtype):
+    """Create a type annotation expression.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type to annotate; it is used as the dtype of the call.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(dtype, "tir.type_annotation")
+
+
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+    """Get head access address with memory access pattern info.
+
+    Parameters
+    ----------
+    ptype : Expr
+        The data type of the pointer.
+
+    data : DType*
+        The data of the pointer.
+
+    offset : int
+        The offset of the pointer.
+
+    extent : int
+        The extent of the pointer.
+
+    rw_mask : int
+        The read/write mask.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+    """Throw TVMGetLastError().
+
+    Returns
+    -------
+    ret : PrimExpr
+        The return expression.
+    """
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -1857,6 +1949,64 @@ def reducer(expr, axis, where=None, init=None, *args):
     return reducer
 
 
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+    """Backend function to allocate temporary workspace.
+
+    Parameters
+    ----------
+    device_type : int
+        The device type on which the space will be allocated.
+
+    device_id : int
+        The device id on which the space will be allocated.
+
+    nbytes : int
+        The size of the space requested.
+
+    dtype_code_hint : int
+        The type code of the array elements. Only used in certain backends such as OpenGL.
+
+    dtype_bits_hint : int
+        The type bits of the array elements. Only used in certain backends such as OpenGL.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "handle",
+        "tir.TVMBackendAllocWorkspace",
+        device_type,
+        device_id,
+        nbytes,
+        dtype_code_hint,
+        dtype_bits_hint,
+    )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+    """Backend function to free temporary workspace.
+
+    Parameters
+    ----------
+    device_type : int
+        The device type on which the space was allocated.
+
+    device_id : int
+        The device id on which the space was allocated.
+
+    ptr : Var
+        The pointer to the allocated space that is to be freed.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index e36eecf2ecc2..ffee3b3b57c9 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -79,6 +79,42 @@ def test_tir_op_call_likely():
     assert expr.op.name == "tir.likely"
 
 
+def test_tir_op_tvm_thread_allreduce():
+    x = tir.Var("x", "int32")
+    buffer = tir.decl_buffer((128,), "float32")
+    y = tir.Var("y", "handle")
+    z = tir.Var("z", "int32")
+    expr = tir.tvm_thread_allreduce(x, buffer[0], True, y, z)
+    assert expr.op.name == "tir.tvm_thread_allreduce"
+
+
+def test_tir_op_type_annotation():
+    expr = tir.type_annotation("int32")
+    assert expr.op.name == "tir.type_annotation"
+
+
+def test_tir_op_tvm_access_ptr():
+    buffer = tir.decl_buffer((128,), "float32")
+    expr = tir.tvm_access_ptr("float32", buffer.data, 0, 1, 2)
+    assert expr.op.name == "tir.tvm_access_ptr"
+
+
+def test_tir_op_tvm_throw_last_error():
+    expr = tir.tvm_throw_last_error()
+    assert expr.op.name == "tir.tvm_throw_last_error"
+
+
+def test_tir_op_TVMBackendAllocWorkspace():
+    expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4)
+    assert expr.op.name == "tir.TVMBackendAllocWorkspace"
+
+
+def test_tir_op_TVMBackendFreeWorkspace():
+    buffer = tir.decl_buffer((128,), "float32")
+    expr = tir.TVMBackendFreeWorkspace(0, 1, buffer.data)
+    assert expr.op.name == "tir.TVMBackendFreeWorkspace"
+
+
 if __name__ == "__main__":
     test_tir_op_tvm_tuple()
     test_tir_op_tvm_struct_get()
@@ -90,3 +126,9 @@ def test_tir_op_call_likely():
     test_tir_op_call_assume()
     test_tir_op_call_undef()
     test_tir_op_call_likely()
+    test_tir_op_tvm_thread_allreduce()
+    test_tir_op_type_annotation()
+    test_tir_op_tvm_access_ptr()
+    test_tir_op_tvm_throw_last_error()
+    test_tir_op_TVMBackendAllocWorkspace()
+    test_tir_op_TVMBackendFreeWorkspace()