diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 65012c6c0f0f..1930f88b4c48 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -50,6 +50,21 @@ TVM_DLL const Op& ret(); */ TVM_DLL const Op& reinterpret(); +/*! + * \brief Zero extend the value using the target type. + */ +TVM_DLL const Op& zextend(); + +/*! + * \brief Sign extend the value using the target type. + */ +TVM_DLL const Op& sextend(); + +/*! + * \brief Truncate the value using the target type. + */ +TVM_DLL const Op& truncate(); + /*! * \brief Marks a condition is likely going to happen. */ @@ -769,9 +784,10 @@ TVM_DLL const Op& vectorlow(); TVM_DLL const Op& vectorcombine(); /*! - * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA + * \brief Atomic add instruction. */ TVM_DLL const Op& atomic_add(); + /*! * \brief Create an Nd memory allocation with storage scope */ diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5471288878f5..d638520b76af 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1869,6 +1869,9 @@ def wrapped(*args, **kwargs): reinterpret = _dtype_forward(_tir_op.reinterpret) +sextend = _dtype_forward(_tir_op.sextend) +zextend = _dtype_forward(_tir_op.zextend) +truncate = _dtype_forward(_tir_op.truncate) call_extern = _dtype_forward(_tir_op.call_extern) call_intrin = _dtype_forward(_tir_op.call_intrin) call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) @@ -2072,6 +2075,9 @@ def wrapped(*args, **kwargs): "q_multiply_shift_per_axis", "ret", "reinterpret", + "sextend", + "zextend", + "truncate", "round", "rsqrt", "shift_left", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0500290b888..55d2c760f090 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -74,7 +74,7 @@ create_barriers, ) from .op import vectorlow, vectorhigh, vectorcombine -from .op import infinity, reinterpret +from .op import infinity, reinterpret, zextend, sextend, truncate from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh @@ -88,6 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic +from .op import atomic_add from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 905d14296d98..c8b07e7a1e7e 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1609,6 +1609,28 @@ def vectorcombine(dtype, vec1, vec2): return call_intrin(dtype, "tir.vectorcombine", vec1, vec2) +def atomic_add(dtype, vec0, vec1): + """Atomic add instruction. + + Parameters + ---------- + vec0 : list + The input vector. + + Parameters + ---------- + vec1 : list + The input vector. + + Returns + ------- + call : PrimExpr + The call expression. + """ + assert vec0.dtype == vec1.dtype == dtype + return call_intrin(dtype, "tir.atomic_add", vec0, vec1) + + def ret(val): """Create a tir return expression @@ -1775,7 +1797,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any: def reinterpret(dtype, value) -> Any: - """infinity value of dtype + """Reinterpret of the value Parameters ---------- @@ -1796,6 +1818,72 @@ def reinterpret(dtype, value) -> Any: return call_intrin(dtype, "tir.reinterpret", value) +def zextend(dtype, value) -> Any: + """Zero extend the value + + Parameters + ---------- + dtype : str + The target data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The zero extended value of dtype. + """ + return call_intrin(dtype, "tir.zextend", value) + + +def sextend(dtype, value) -> Any: + """Sign extend the value + + Parameters + ---------- + dtype : str + The target data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The sign extended value of dtype. + """ + return call_intrin(dtype, "tir.sextend", value) + + +def truncate(dtype, value) -> Any: + """Truncate the value + + Parameters + ---------- + dtype : str + The target data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The truncated value of dtype. + """ + return call_intrin(dtype, "tir.truncate", value) + + def exp(x): """Take exponential of input x. diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3d4d3def2411..455ea0461b63 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1441,6 +1441,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); + } else if (op->op.same_as(builtin::zextend())) { + llvm::Type* target = DTypeToLLVMType(op->dtype); + return builder_->CreateZExt(MakeValue(op->args[0]), target); + } else if (op->op.same_as(builtin::sextend())) { + llvm::Type* target = DTypeToLLVMType(op->dtype); + return builder_->CreateSExt(MakeValue(op->args[0]), target); + } else if (op->op.same_as(builtin::truncate())) { + llvm::Type* target = DTypeToLLVMType(op->dtype); + return builder_->CreateTrunc(MakeValue(op->args[0]), target); } else if (op->op.same_as(builtin::isnan())) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); @@ -1467,8 +1476,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } return builder_->CreateShuffleVector(v0, v1, indices); } else if (op->op.same_as(builtin::atomic_add())) { - // TODO(masahi): Support atomic for CPU backend - LOG(FATAL) << "CPU backend does not support atomic add yet."; + llvm::Value* v0 = MakeValue(op->args[0]); + llvm::Value* v1 = MakeValue(op->args[1]); + return builder_->CreateAdd(v0, v1); } else if (op->op.same_as(builtin::start_profile_intrinsic()) || op->op.same_as(builtin::end_profile_intrinsic())) { LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 1b80959b5705..dc4f03c78121 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -44,6 +44,24 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret) Integer(ScriptDtypePrintLocation::kFirst)) .set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(zextend) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) + .set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(sextend) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) + .set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(truncate) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) + .set_num_inputs(1); + TIR_DEFINE_BUILTIN_FUNC(ret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) .set_num_inputs(1); diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py index 7398ee781b9e..d5638f59bf5c 100644 --- a/tests/python/unittest/test_tir_op_types.py +++ b/tests/python/unittest/test_tir_op_types.py @@ -57,6 +57,27 @@ def test_tir_op_reinterpret(): assert expr.op.name == "tir.reinterpret" +def test_tir_op_zextend(): + buffer = tir.decl_buffer((4, 4), "uint8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint8x16") + expr = tir.zextend("uint16x8", vec) + assert expr.op.name == "tir.zextend" + + +def test_tir_op_sextend(): + buffer = tir.decl_buffer((4, 4), "uint8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint8x16") + expr = tir.sextend("int16x8", vec) + assert expr.op.name == "tir.sextend" + + +def test_tir_op_truncate(): + buffer = tir.decl_buffer((4, 4), "uint16", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint16x16") + expr = tir.truncate("uint8x32", vec) + assert expr.op.name == "tir.truncate" + + def test_tir_op_isnullptr(): x = tir.Var("x", dtype="int32") expr = tir.isnullptr(x) @@ -302,6 +323,15 @@ def test_tir_op_vectorcombine(): assert expr.op.name == "tir.vectorcombine" +def test_tir_op_atomic_add(): + buffer0 = tir.decl_buffer((2, 2), "uint32", offset_factor=1) + buffer1 = tir.decl_buffer((2, 2), "uint32", offset_factor=1) + vec0 = buffer0.vload([0, 0], dtype="uint32x4") + vec1 = buffer1.vload([0, 0], dtype="uint32x4") + expr = tir.atomic_add("uint32x4", vec0, vec1) + assert expr.op.name == "tir.atomic_add" + + def test_tir_op_shift_left(): x = tir.Var("x", dtype="int32") y = tir.Var("x", dtype="int32")