diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 083a60fc3631..ab2a7f84dfc3 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a6" +version = "0.1.0a7" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index a223da90cb7e..30ac3d3be619 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -146,6 +146,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1: ctx_dev_type[0] = temp_dltensor.device.device_type ctx_dev_id[0] = temp_dltensor.device.device_id + global torch_get_current_cuda_stream if torch_get_current_cuda_stream is None: torch_get_current_cuda_stream = load_torch_get_current_cuda_stream() temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)