Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ TVM_DLL const Op& ret();
*/
TVM_DLL const Op& reinterpret();

/*!
* \brief Zero extend the value using the target type.
*/
TVM_DLL const Op& zextend();

/*!
* \brief Sign extend the value using the target type.
*/
TVM_DLL const Op& sextend();

/*!
* \brief Truncate the value using the target type.
*/
TVM_DLL const Op& truncate();

/*!
* \brief Marks a condition is likely going to happen.
*/
Expand Down Expand Up @@ -769,9 +784,10 @@ TVM_DLL const Op& vectorlow();
TVM_DLL const Op& vectorcombine();

/*!
* \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
* \brief Atomic add instruction.
*/
TVM_DLL const Op& atomic_add();

/*!
* \brief Create an Nd memory allocation with storage scope
*/
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,9 @@ def wrapped(*args, **kwargs):


reinterpret = _dtype_forward(_tir_op.reinterpret)
sextend = _dtype_forward(_tir_op.sextend)
zextend = _dtype_forward(_tir_op.zextend)
truncate = _dtype_forward(_tir_op.truncate)
call_extern = _dtype_forward(_tir_op.call_extern)
call_intrin = _dtype_forward(_tir_op.call_intrin)
call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
Expand Down Expand Up @@ -2072,6 +2075,9 @@ def wrapped(*args, **kwargs):
"q_multiply_shift_per_axis",
"ret",
"reinterpret",
"sextend",
"zextend",
"truncate",
"round",
"rsqrt",
"shift_left",
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
create_barriers,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import infinity, reinterpret, zextend, sextend, truncate
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand All @@ -88,6 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import atomic_add
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
90 changes: 89 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,28 @@ def vectorcombine(dtype, vec1, vec2):
return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)


def atomic_add(dtype, vec0, vec1):
"""Atomic add instruction.

Parameters
----------
vec0 : list
The input vector.

Parameters
----------
vec1 : list
The input vector.

Returns
-------
call : PrimExpr
The call expression.
"""
assert vec0.dtype == vec1.dtype == dtype
return call_intrin(dtype, "tir.atomic_add", vec0, vec1)


def ret(val):
"""Create a tir return expression

Expand Down Expand Up @@ -1775,7 +1797,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:


def reinterpret(dtype, value) -> Any:
"""infinity value of dtype
"""Reinterpret of the value

Parameters
----------
Expand All @@ -1796,6 +1818,72 @@ def reinterpret(dtype, value) -> Any:
return call_intrin(dtype, "tir.reinterpret", value)


def zextend(dtype, value) -> Any:
"""Zero extend the value

Parameters
----------
dtype : str
The target data type.

value : PrimExpr
The input value.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
value : tvm.Expr
The zero extended value of dtype.
"""
return call_intrin(dtype, "tir.zextend", value)


def sextend(dtype, value) -> Any:
"""Sign extend the value

Parameters
----------
dtype : str
The target data type.

value : PrimExpr
The input value.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
value : tvm.Expr
The sign extended value of dtype.
"""
return call_intrin(dtype, "tir.sextend", value)


def truncate(dtype, value) -> Any:
"""Truncate the value

Parameters
----------
dtype : str
The target data type.

value : PrimExpr
The input value.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
value : tvm.Expr
The truncated value of dtype.
"""
return call_intrin(dtype, "tir.truncate", value)


def exp(x):
"""Take exponential of input x.

Expand Down
14 changes: 12 additions & 2 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::reinterpret())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->op.same_as(builtin::zextend())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateZExt(MakeValue(op->args[0]), target);
} else if (op->op.same_as(builtin::sextend())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateSExt(MakeValue(op->args[0]), target);
} else if (op->op.same_as(builtin::truncate())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateTrunc(MakeValue(op->args[0]), target);
} else if (op->op.same_as(builtin::isnan())) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
Expand All @@ -1467,8 +1476,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else if (op->op.same_as(builtin::atomic_add())) {
// TODO(masahi): Support atomic for CPU backend
LOG(FATAL) << "CPU backend does not support atomic add yet.";
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
return builder_->CreateAdd(v0, v1);
} else if (op->op.same_as(builtin::start_profile_intrinsic()) ||
op->op.same_as(builtin::end_profile_intrinsic())) {
LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op;
Expand Down
18 changes: 18 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret)
Integer(ScriptDtypePrintLocation::kFirst))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(zextend)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(sextend)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(truncate)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(ret)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
.set_num_inputs(1);
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_tir_op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ def test_tir_op_reinterpret():
assert expr.op.name == "tir.reinterpret"


def test_tir_op_zextend():
buffer = tir.decl_buffer((4, 4), "uint8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="uint8x16")
expr = tir.zextend("uint16x8", vec)
assert expr.op.name == "tir.zextend"


def test_tir_op_sextend():
buffer = tir.decl_buffer((4, 4), "uint8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="uint8x16")
expr = tir.sextend("int16x8", vec)
assert expr.op.name == "tir.sextend"


def test_tir_op_truncate():
buffer = tir.decl_buffer((4, 4), "uint16", offset_factor=1)
vec = buffer.vload([0, 0], dtype="uint16x16")
expr = tir.truncate("uint8x32", vec)
assert expr.op.name == "tir.truncate"


def test_tir_op_isnullptr():
x = tir.Var("x", dtype="int32")
expr = tir.isnullptr(x)
Expand Down Expand Up @@ -302,6 +323,15 @@ def test_tir_op_vectorcombine():
assert expr.op.name == "tir.vectorcombine"


def test_tir_op_atomic_add():
buffer0 = tir.decl_buffer((2, 2), "uint32", offset_factor=1)
buffer1 = tir.decl_buffer((2, 2), "uint32", offset_factor=1)
vec0 = buffer0.vload([0, 0], dtype="uint32x4")
vec1 = buffer1.vload([0, 0], dtype="uint32x4")
expr = tir.atomic_add("uint32x4", vec0, vec1)
assert expr.op.name == "tir.atomic_add"


def test_tir_op_shift_left():
x = tir.Var("x", dtype="int32")
y = tir.Var("x", dtype="int32")
Expand Down