From 94dd5c5bd9aa682fe8976e2f4a36d682ee3e910c Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 18 Aug 2025 16:28:11 -0400
Subject: [PATCH] [FFI] AutoDLPack compatible with torch stream context

This PR updates the AutoDLPack path to automatically update the env stream
so it is consistent with the torch stream context. This change makes FFI
functions compatible with stream-based execution.

We leverage torch's cpp_extension.load_inline to create an efficient stream
query function. The first load may take longer because the JIT module has to
be built; subsequent loads are fast since torch caches the compiled module.
---
 ffi/scripts/benchmark_dlpack.py    | 70 ++++++++++++++++++++++-
 python/tvm/ffi/cython/base.pxi     |  8 +++
 python/tvm/ffi/cython/function.pxi | 92 ++++++++++++++++++++++++++++--
 3 files changed, 162 insertions(+), 8 deletions(-)

diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py
index b19f566364e4..1453aa95a67c 100644
--- a/ffi/scripts/benchmark_dlpack.py
+++ b/ffi/scripts/benchmark_dlpack.py
@@ -36,6 +36,7 @@
 -
 """
+import os
 import torch
 import numpy as np
 from tvm import ffi as tvm_ffi
@@ -244,7 +245,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
     print_speed(name, speed)


-def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
+def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
     """
     Measures overhead of running dlpack
     via auto convert by directly taking torch.Tensor as inputs.
@@ -253,7 +254,13 @@
     x = torch.arange(1, device=device)
     y = torch.arange(1, device=device)
     z = torch.arange(1, device=device)
-    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
+    if stream:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            bench_tvm_ffi_nop_autodlpack(
+                f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
+            )
+    else:
+        bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)


 def tvm_ffi_nop_autodlpack_from_numpy(repeat):
@@ -308,6 +315,50 @@ def bench_torch_utils_to_dlpack(repeat):
     print_speed("torch.utils.dlpack.to_dlpack", speed)


+def torch_get_cuda_stream_native(device_id):
+    return torch.cuda.current_stream(device_id).cuda_stream
+
+
+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through cpp extension."""
+    from torch.utils import cpp_extension
+
+    source = """
+    #include <ATen/cuda/CUDAContext.h>
+
+    int64_t get_current_cuda_stream(int device_id) {
+      at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+      // fast invariant: the default stream is always 0
+      if (stream.id() == 0) return 0;
+      // convert to cudaStream_t
+      return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+    }
+    """
+    result = cpp_extension.load_inline(
+        name="get_current_cuda_stream",
+        cpp_sources=[source],
+        cuda_sources=[],
+        extra_cflags=["-O3"],
+        extra_include_paths=cpp_extension.include_paths("cuda"),
+        functions=["get_current_cuda_stream"],
+    )
+    return result.get_current_cuda_stream
+
+
+def bench_torch_get_current_stream(repeat, name, func):
+    """
+    Measures overhead of querying the current CUDA stream via func.
+    """
+    # allocate a small tensor to ensure the CUDA context is initialized
+    x = torch.arange(1, device="cuda")
+    # warm up the query function once before timing
+    func(0)
+    start = time.time()
+    for i in range(repeat):
+        func(0)
+    end = time.time()
+    speed = (end - start) / repeat
+    print_speed(f"torch.cuda.current_stream[{name}]", speed)
+
+
 def main():
     repeat = 10000
     print("-----------------------------")
@@ -323,6 +374,8 @@ def main():
     tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
+    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
+
     tvm_ffi_nop_autodlpack_from_numpy(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
@@ -339,6 +392,19 @@ def main():
     bench_to_dlpack_versioned(
         tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
     )
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[default stream]")
+    print("---------------------------------------------------")
+    bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream())
+    bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[non-default stream]")
+    print("---------------------------------------------------")
+    with torch.cuda.stream(torch.cuda.Stream()):
+        bench_torch_get_current_stream(
+            repeat, "cpp-extension", load_torch_get_current_cuda_stream()
+        )
+        bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)


 if __name__ == "__main__":

diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index 00b76e68f74d..24c729095989 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -205,6 +205,14 @@ cdef extern from "tvm/ffi/c_api.h":
     DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil
     DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil

+cdef extern from "tvm/ffi/extra/c_env_api.h":
+    ctypedef void* TVMFFIStreamHandle
+
+    void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil
+    int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
+                           TVMFFIStreamHandle stream,
+                           TVMFFIStreamHandle* opt_out_original_stream) nogil
+

 cdef class ByteArrayArg:
     cdef TVMFFIByteArray cdata

diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi
index 999c2e1338b5..3ab232e95997 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -18,11 +18,51 @@ import ctypes
 from numbers import Real, Integral

 try:
+    # optionally import torch and set up torch-related utils
     import torch
 except ImportError:
     torch = None


+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through cpp extension.
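+
+    The built extension is cached by torch, so only the first call pays the
+    JIT compilation cost. If building fails for any reason, a pure-python
+    fallback based on torch.cuda.current_stream is returned instead.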
+ """ + from torch.utils import cpp_extension + + source = """ + #include + + int64_t get_current_cuda_stream(int device_id) { + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id); + // fast invariant, default stream is always 0 + if (stream.id() == 0) return 0; + // convert to cudaStream_t + return reinterpret_cast(static_cast(stream)); + } + """ + def fallback_get_current_cuda_stream(device_id): + """Fallback with python api""" + return torch.cuda.current_stream(device_id).cuda_stream + return fallback_get_current_cuda_stream + try: + result = cpp_extension.load_inline( + name="get_current_cuda_stream", + cpp_sources=[source], + cuda_sources=[], + extra_cflags=["-O3"], + extra_include_paths=cpp_extension.include_paths("cuda"), + functions=["get_current_cuda_stream"], + ) + return result.get_current_cuda_stream + except Exception: + return fallback_get_current_cuda_stream + +if torch is not None: + # when torch is available, jit compile the get_current_cuda_stream function + # the torch caches the extension so second loading is faster + torch_get_current_cuda_stream = load_torch_get_current_cuda_stream() + + cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" cdef TVMFFIByteArray bytes @@ -76,9 +116,13 @@ cdef inline object make_ret(TVMFFIAny result): raise ValueError("Unhandled type index %d" % type_index) -cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except -1: +cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, + int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1: """Pack arguments into c args tvm call accept""" - cdef unsigned long long ptr + cdef unsigned long long temp_ptr + cdef DLTensor* temp_dltensor + cdef int is_cuda = 0 + for i, arg in enumerate(py_args): # clear the value to ensure zero padding on 32bit platforms if sizeof(void*) != 8: @@ -96,10 +140,18 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) out[i].v_ptr = (arg).chandle elif torch is not None and isinstance(arg, torch.Tensor): + is_cuda = arg.is_cuda arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg), required_alignment=__dlpack_auto_import_required_alignment__) out[i].type_index = kTVMFFINDArray out[i].v_ptr = (arg).chandle + temp_dltensor = TVMFFINDArrayGetDLTensorPtr((arg).chandle) + # record the stream and device for torch context + if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1: + ctx_dev_type[0] = temp_dltensor.device.device_type + ctx_dev_id[0] = temp_dltensor.device.device_id + temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id) + ctx_stream[0] = temp_ptr temp_args.append(arg) elif hasattr(arg, "__dlpack__"): arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__) @@ -177,12 +229,27 @@ cdef inline int FuncCall3(void* chandle, # fast path with stack alloca for less than 3 args cdef TVMFFIAny[3] packed_args cdef int nargs = len(args) + cdef int ctx_dev_type = -1 + cdef int ctx_dev_id = 0 + cdef TVMFFIStreamHandle ctx_stream = NULL + cdef TVMFFIStreamHandle prev_stream = NULL temp_args = [] - make_args(args, &packed_args[0], temp_args) + make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream) with nogil: + if ctx_dev_type != -1: + # set the stream based on ctx stream + c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream) + if c_api_ret_code[0] 
!= 0: + return 0 c_api_ret_code[0] = TVMFFIFunctionCall( chandle, &packed_args[0], nargs, result ) + # restore the original stream if it is not the same as the context stream + if ctx_dev_type != -1 and prev_stream != ctx_stream: + # restore the original stream + c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL) + if c_api_ret_code[0] != 0: + return 0 return 0 @@ -191,6 +258,10 @@ cdef inline int FuncCall(void* chandle, TVMFFIAny* result, int* c_api_ret_code) except -1: cdef int nargs = len(args) + cdef int ctx_dev_type = -1 + cdef int ctx_dev_id = 0 + cdef TVMFFIStreamHandle ctx_stream = NULL + cdef TVMFFIStreamHandle prev_stream = NULL if nargs <= 3: FuncCall3(chandle, args, result, c_api_ret_code) @@ -200,10 +271,19 @@ cdef inline int FuncCall(void* chandle, packed_args.resize(nargs) temp_args = [] - make_args(args, &packed_args[0], temp_args) + make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream) with nogil: + if ctx_dev_type != -1: + c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream) + if c_api_ret_code[0] != 0: + return 0 c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result) + # restore the original stream if it is not the same as the context stream + if ctx_dev_type != -1 and prev_stream != ctx_stream: + c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL) + if c_api_ret_code[0] != 0: + return 0 return 0 @@ -274,7 +354,7 @@ cdef class FieldSetter: cdef void* field_ptr = ((obj).chandle) + self.offset cdef int nargs = 1 temp_args = [] - make_args((value,), &packed_args[0], temp_args) + make_args((value,), &packed_args[0], temp_args, NULL, NULL, NULL) c_api_ret_code = self.setter(field_ptr, &packed_args[0]) # NOTE: logic is same as check_call # directly inline here to simplify traceback @@ -412,7 +492,7 @@ cdef int tvm_ffi_callback(void* context, return -1 temp_args = [] - make_args((rv,), &temp_result, temp_args) + make_args((rv,), &temp_result, temp_args, NULL, NULL, NULL) CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp_result, result)) return 0
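
Illustrative usage (editor's sketch, not part of the patch): with this change,
an FFI function invoked on CUDA tensors inside a torch.cuda.stream(...) context
runs on that stream instead of the default one. The function name
"example.add_one" below is hypothetical; any FFI function that launches work on
the env stream behaves the same way.

    import torch
    from tvm import ffi as tvm_ffi

    # hypothetical FFI function that launches a kernel on the env stream
    add_one = tvm_ffi.get_global_func("example.add_one")

    x = torch.arange(1024, device="cuda")
    y = torch.empty_like(x)

    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        # make_args sees a CUDA tensor, records torch's current stream for
        # that device, and FuncCall brackets the call with TVMFFIEnvSetStream,
        # so the kernel is enqueued on s rather than on the default stream
        add_one(x, y)
    s.synchronize()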