From a8b33f92953074aae3d8fb71b2996de3d046fa42 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Mon, 8 Sep 2025 15:51:25 +0800
Subject: [PATCH] [FFI] Record stream and device type for Paddle context

---
 ffi/python/tvm_ffi/cython/function.pxi | 41 ++++++++++++++++++++++----
 1 file changed, 36 insertions(+), 5 deletions(-)

diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi
index 0161ec4292ab..b790dcd4e497 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -19,14 +19,30 @@ import os
 from numbers import Real, Integral
 
 
-if os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "0":
+def is_building_docs():
+    return os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "1"
+
+def try_import_torch():
+    if is_building_docs():
+        return None
     try:
-        # optionally import torch and setup torch related utils
         import torch
+        return torch
+    except ImportError:
+        return None
+
+def try_import_paddle():
+    if is_building_docs():
+        return None
+    try:
+        import paddle
+        return paddle
     except ImportError:
-        torch = None
-else:
-    torch = None
+        return None
+
+# optionally import specific framework and setup framework related utils
+torch = try_import_torch()
+paddle = try_import_paddle()
 
 
 cdef inline object make_ret_small_str(TVMFFIAny result):
@@ -122,6 +138,21 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
                 temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
                 ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
             temp_args.append(arg)
+        elif paddle is not None and isinstance(arg, paddle.Tensor):
+            is_cuda = arg.is_cuda
+            arg = from_dlpack(paddle.utils.dlpack.to_dlpack(arg),
+                              required_alignment=__dlpack_auto_import_required_alignment__)
+            out[i].type_index = kTVMFFITensor
+            out[i].v_ptr = (<Tensor>arg).chandle
+            temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
+            # record the stream and device for paddle context
+            if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
+                ctx_dev_type[0] = temp_dltensor.device.device_type
+                ctx_dev_id[0] = temp_dltensor.device.device_id
+                # PaddlePaddle provides a torch compatible API to get the current stream
+                temp_ptr = paddle._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
+                ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
+            temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
             arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
             out[i].type_index = kTVMFFITensor