Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions ffi/python/tvm_ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,30 @@ import os
from numbers import Real, Integral


if os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "0":
def is_building_docs():
return os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "1"

def try_import_torch():
if is_building_docs():
return None
try:
# optionally import torch and setup torch related utils
import torch
return torch
except ImportError:
return None

def try_import_paddle():
if is_building_docs():
return None
try:
import paddle
return paddle
except ImportError:
torch = None
else:
torch = None
return None

# optionally import specific framework and setup framework related utils
torch = try_import_torch()
paddle = try_import_paddle()


cdef inline object make_ret_small_str(TVMFFIAny result):
Expand Down Expand Up @@ -122,6 +138,21 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(arg)
elif paddle is not None and isinstance(arg, paddle.Tensor):
is_cuda = arg.is_cuda
arg = from_dlpack(paddle.utils.dlpack.to_dlpack(arg),
required_alignment=__dlpack_auto_import_required_alignment__)
out[i].type_index = kTVMFFITensor
out[i].v_ptr = (<Tensor>arg).chandle
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
# record the stream and device for paddle context
if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
ctx_dev_type[0] = temp_dltensor.device.device_type
ctx_dev_id[0] = temp_dltensor.device.device_id
# PaddlePaddle provides a torch compatible API to get the current stream
temp_ptr = paddle._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(arg)
elif hasattr(arg, "__dlpack__"):
arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
out[i].type_index = kTVMFFITensor
Expand Down