From 3848d3cdb5e16b4e32f7166ee396e66973bbcd92 Mon Sep 17 00:00:00 2001 From: Siyuan Feng <25500082+Hzfengsy@users.noreply.github.com> Date: Fri, 11 Jul 2025 19:38:19 +0800 Subject: [PATCH] [TIR] Add `T.thread_return()` for early thread exit in CUDA kernels This commit implements T.thread_return() functionality that allows threads to exit early from CUDA kernels. The feature is useful for cases where threads need to conditionally return based on thread indices or other conditions. Key changes: - Add thread_return builtin in TIR - Implement CUDA codegen for thread_return - Add Python bindings for T.thread_return() - Update TIR IR builder to support thread_return - Add tests demonstrating thread_return usage Example usage: ```python @T.prim_func def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): for i in T.thread_binding(16, thread="blockIdx.x"): for j in T.thread_binding(32, thread="threadIdx.x"): if j >= 16: T.thread_return() # Early exit for threads with j >= 16 B[i, j] = A[i, j] ``` and generate code is: ```cuda extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) { if (16 <= ((int)threadIdx.x)) { return; } B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))]; } ``` --- include/tvm/tir/builtin.h | 4 ++++ include/tvm/tir/op.h | 8 ++++++++ python/tvm/script/ir_builder/tir/ir.py | 2 ++ python/tvm/tir/op.py | 17 +++++++++++++++++ src/target/source/codegen_cuda.cc | 2 ++ src/tir/op/builtin.cc | 4 ++++ src/tir/op/op.cc | 6 ++++++ .../python/codegen/test_target_codegen_cuda.py | 17 +++++++++++++++++ 8 files changed, 60 insertions(+) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index c057422a0266..6b31324fa596 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -45,6 +45,10 @@ namespace builtin { * \brief Return value. */ TVM_DLL const Op& ret(); +/*! + * \brief Return from a GPU thread. + */ +TVM_DLL const Op& thread_return(); /*! * \brief Reinterpret the value using the target type. */ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 99139f83b297..3dda3f7c63c5 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -91,6 +91,14 @@ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); */ TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); +/*! + * \brief Return from a thread. + * + * \param span The location of this operation in the source. + * \return The return expression. + */ +TVM_DLL PrimExpr thread_return(Span span = Span()); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5864de2cac77..c6549ad104c3 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1927,6 +1927,7 @@ def wrapped(*args, **kwargs): sqrt = _op_wrapper(_tir_op.sqrt) tan = _op_wrapper(_tir_op.tan) tanh = _op_wrapper(_tir_op.tanh) +thread_return = _op_wrapper(_tir_op.thread_return) trunc = _op_wrapper(_tir_op.trunc) truncdiv = _op_wrapper(_tir_op.truncdiv) truncmod = _op_wrapper(_tir_op.truncmod) @@ -2205,6 +2206,7 @@ def wrapped(*args, **kwargs): "sqrt", "tan", "tanh", + "thread_return", "trunc", "truncdiv", "truncmod", diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 155c7e10de60..54c70ede7a9b 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1882,6 +1882,23 @@ def ret(val, span=None): return _ffi_api.ret(val, span) +def thread_return(span=None): + """Return from a GPU thread. + + Parameters + ---------- + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + ret : PrimExpr + The return expression + """ + + return _ffi_api.thread_return(span) + + def any(*args, span=None): """Create a new experssion of the union of all conditions in the arguments diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 35e3d3cb8da7..951415c3b353 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1334,6 +1334,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; } EndScope(ssa_scope); + } else if (op->op.same_as(builtin::thread_return())) { + os << "return"; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 70614dfeebd7..12c7c8d33c7f 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -48,6 +48,10 @@ TIR_DEFINE_BUILTIN_FUNC(ret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) .set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(thread_return) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) + .set_num_inputs(0); + TIR_DEFINE_BUILTIN_FUNC(likely) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation)) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 41e090c58542..b7cf5e3f8a2f 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -247,6 +247,12 @@ PrimExpr ret(PrimExpr value, Span span) { TVM_FFI_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); +PrimExpr thread_return(Span span) { + return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span); +} + +TVM_FFI_REGISTER_GLOBAL("tir.thread_return").set_body_typed(thread_return); + // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 2d00618eb06f..28dfb6b9d4cb 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -839,5 +839,22 @@ def main( tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np) +@tvm.testing.requires_cuda +def test_thread_return(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): + for bx in T.thread_binding(32, "blockIdx.x"): + for tx in T.thread_binding(32, "threadIdx.x"): + if bx >= 16 or tx >= 16: + T.thread_return() + B[bx, tx] = A[bx, tx] + + lib = tvm.compile(Module, target="cuda") + cuda_code = lib.mod.imported_modules[0].get_source() + assert "return;" in cuda_code + + if __name__ == "__main__": tvm.testing.main()