diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index c057422a0266..6b31324fa596 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -45,6 +45,10 @@ namespace builtin {
  * \brief Return value.
  */
 TVM_DLL const Op& ret();
+/*!
+ * \brief Return from a GPU thread.
+ */
+TVM_DLL const Op& thread_return();
 /*!
  * \brief Reinterpret the value using the target type.
  */
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 99139f83b297..3dda3f7c63c5 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -91,6 +91,14 @@ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
  */
 TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
 
+/*!
+ * \brief Return from a thread.
+ *
+ * \param span The location of this operation in the source.
+ * \return The return expression.
+ */
+TVM_DLL PrimExpr thread_return(Span span = Span());
+
 /*!
  * Query the maximum possible value of dtype.
  * \param dtype The data type.
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 5864de2cac77..c6549ad104c3 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1927,6 +1927,7 @@ def wrapped(*args, **kwargs):
 sqrt = _op_wrapper(_tir_op.sqrt)
 tan = _op_wrapper(_tir_op.tan)
 tanh = _op_wrapper(_tir_op.tanh)
+thread_return = _op_wrapper(_tir_op.thread_return)
 trunc = _op_wrapper(_tir_op.trunc)
 truncdiv = _op_wrapper(_tir_op.truncdiv)
 truncmod = _op_wrapper(_tir_op.truncmod)
@@ -2205,6 +2206,7 @@ def wrapped(*args, **kwargs):
     "sqrt",
     "tan",
     "tanh",
+    "thread_return",
     "trunc",
     "truncdiv",
     "truncmod",
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 155c7e10de60..54c70ede7a9b 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1882,6 +1882,23 @@ def ret(val, span=None):
     return _ffi_api.ret(val, span)
 
 
+def thread_return(span=None):
+    """Return from a GPU thread.
+
+    Parameters
+    ----------
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    ret : PrimExpr
+        The return expression.
+    """
+
+    return _ffi_api.thread_return(span)
+
+
 def any(*args, span=None):
     """Create a new experssion of the union of all conditions in the arguments
 
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 35e3d3cb8da7..951415c3b353 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1334,6 +1334,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes;
     }
     EndScope(ssa_scope);
+  } else if (op->op.same_as(builtin::thread_return())) {
+    os << "return";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 70614dfeebd7..12c7c8d33c7f 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -48,6 +48,10 @@ TIR_DEFINE_BUILTIN_FUNC(ret)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
     .set_num_inputs(1);
 
+TIR_DEFINE_BUILTIN_FUNC(thread_return)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
+    .set_num_inputs(0);
+
 TIR_DEFINE_BUILTIN_FUNC(likely)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 41e090c58542..b7cf5e3f8a2f 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -247,6 +247,12 @@ PrimExpr ret(PrimExpr value, Span span) {
 
 TVM_FFI_REGISTER_GLOBAL("tir.ret").set_body_typed(ret);
 
+PrimExpr thread_return(Span span) {
+  return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span);
+}
+
+TVM_FFI_REGISTER_GLOBAL("tir.thread_return").set_body_typed(thread_return);
+
 // maximum and min limits
 PrimExpr max_value(const DataType& dtype, Span span) {
   using namespace tir;
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 2d00618eb06f..28dfb6b9d4cb 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -839,5 +839,22 @@ def main(
     tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np)
 
 
+@tvm.testing.requires_cuda
+def test_thread_return():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
+            for bx in T.thread_binding(32, "blockIdx.x"):
+                for tx in T.thread_binding(32, "threadIdx.x"):
+                    if bx >= 16 or tx >= 16:
+                        T.thread_return()
+                    B[bx, tx] = A[bx, tx]
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imported_modules[0].get_source()
+    assert "return;" in cuda_code
+
+
 if __name__ == "__main__":
     tvm.testing.main()
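For illustration, below is a rough hand-written sketch of the CUDA that the new codegen branch would emit for the test kernel above; the kernel name, signature, and indexing are assumed for the sketch, not copied from actual generated source. The point it shows is that T.thread_return() lowers to a bare "return;", so out-of-range threads exit before touching the buffers.

extern "C" __global__ void main_kernel(float* __restrict__ A, float* __restrict__ B) {
  // The `if bx >= 16 or tx >= 16: T.thread_return()` guard becomes an early return.
  if (blockIdx.x >= 16 || threadIdx.x >= 16) {
    return;
  }
  // Remaining threads each copy one element of the 16x16 tile.
  B[blockIdx.x * 16 + threadIdx.x] = A[blockIdx.x * 16 + threadIdx.x];
}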