diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py
index b19f566364e4..1453aa95a67c 100644
--- a/ffi/scripts/benchmark_dlpack.py
+++ b/ffi/scripts/benchmark_dlpack.py
@@ -36,6 +36,7 @@
 """
+import os
 import torch
 import numpy as np
 from tvm import ffi as tvm_ffi
@@ -244,7 +245,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
     print_speed(name, speed)
 
 
-def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
+def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
     """
     Measures overhead of running dlpack via auto convert by directly
     take torch.Tensor as inputs.
@@ -253,7 +254,13 @@
     x = torch.arange(1, device=device)
     y = torch.arange(1, device=device)
     z = torch.arange(1, device=device)
-    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
+    if stream:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            bench_tvm_ffi_nop_autodlpack(
+                f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
+            )
+    else:
+        bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
 
 
 def tvm_ffi_nop_autodlpack_from_numpy(repeat):
@@ -308,6 +315,50 @@ def bench_torch_utils_to_dlpack(repeat):
     print_speed("torch.utils.dlpack.to_dlpack", speed)
 
 
+def torch_get_cuda_stream_native(device_id):
+    return torch.cuda.current_stream(device_id).cuda_stream
+
+
+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through cpp extension."""
+    from torch.utils import cpp_extension
+
+    source = """
+    #include <c10/cuda/CUDAStream.h>
+
+    int64_t get_current_cuda_stream(int device_id) {
+        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+        // fast invariant, default stream is always 0
+        if (stream.id() == 0) return 0;
+        // convert to cudaStream_t
+        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+    }
+    """
+    result = cpp_extension.load_inline(
+        name="get_current_cuda_stream",
+        cpp_sources=[source],
+        cuda_sources=[],
+        extra_cflags=["-O3"],
+        extra_include_paths=cpp_extension.include_paths("cuda"),
+        functions=["get_current_cuda_stream"],
+    )
+    return result.get_current_cuda_stream
+
+
+def bench_torch_get_current_stream(repeat, name, func):
+    """
+    Measures overhead of running torch.cuda.current_stream
+    """
+    x = torch.arange(1, device="cuda")
+    func(0)
+    start = time.time()
+    for i in range(repeat):
+        func(0)
+    end = time.time()
+    speed = (end - start) / repeat
+    print_speed(f"torch.cuda.current_stream[{name}]", speed)
+
+
 def main():
     repeat = 10000
     print("-----------------------------")
@@ -323,6 +374,8 @@ def main():
     tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
+    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
+    tvm_ffi_nop_autodlpack_from_numpy(repeat)
 
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
@@ -339,6 +392,19 @@ def main():
     bench_to_dlpack_versioned(
         tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
     )
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[default stream]")
+    print("---------------------------------------------------")
+    bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream())
+    bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[non-default stream]")
+    print("---------------------------------------------------")
+    with torch.cuda.stream(torch.cuda.Stream()):
+        bench_torch_get_current_stream(
+            repeat, "cpp-extension", load_torch_get_current_cuda_stream()
+        )
+        bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
 
 
 if __name__ == "__main__":
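The benchmark above motivates the Cython changes that follow: querying torch's current CUDA stream through the Python API is measurably slower than through a compiled helper. For reference, the Python-API baseline can be timed standalone; a minimal sketch, assuming a CUDA-enabled torch build and mirroring the script's 10000-repeat default:

    import time
    import torch

    def time_stream_query(func, device_id=0, repeat=10000):
        func(device_id)  # warm up so CUDA context creation is not measured
        start = time.time()
        for _ in range(repeat):
            func(device_id)
        return (time.time() - start) / repeat

    # Python-API baseline: .cuda_stream is the raw cudaStream_t as an integer
    per_call = time_stream_query(lambda dev: torch.cuda.current_stream(dev).cuda_stream)
    print(f"python stream query: {per_call * 1e6:.3f} us/call")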
print("---------------------------------------------------") + print("Benchmark torch.get_cuda_stream[non-default stream]") + print("---------------------------------------------------") + with torch.cuda.stream(torch.cuda.Stream()): + bench_torch_get_current_stream( + repeat, "cpp-extension", load_torch_get_current_cuda_stream() + ) + bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) if __name__ == "__main__": diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 00b76e68f74d..24c729095989 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -205,6 +205,14 @@ cdef extern from "tvm/ffi/c_api.h": DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil +cdef extern from "tvm/ffi/extra/c_env_api.h": + ctypedef void* TVMFFIStreamHandle + + void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil + int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream) nogil + cdef class ByteArrayArg: cdef TVMFFIByteArray cdata diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 999c2e1338b5..3ab232e95997 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -18,11 +18,51 @@ import ctypes from numbers import Real, Integral try: + # optionally import torch and setup torch related utils import torch except ImportError: torch = None +def load_torch_get_current_cuda_stream(): + """Create a faster get_current_cuda_stream for torch through cpp extension. + """ + from torch.utils import cpp_extension + + source = """ + #include + + int64_t get_current_cuda_stream(int device_id) { + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id); + // fast invariant, default stream is always 0 + if (stream.id() == 0) return 0; + // convert to cudaStream_t + return reinterpret_cast(static_cast(stream)); + } + """ + def fallback_get_current_cuda_stream(device_id): + """Fallback with python api""" + return torch.cuda.current_stream(device_id).cuda_stream + return fallback_get_current_cuda_stream + try: + result = cpp_extension.load_inline( + name="get_current_cuda_stream", + cpp_sources=[source], + cuda_sources=[], + extra_cflags=["-O3"], + extra_include_paths=cpp_extension.include_paths("cuda"), + functions=["get_current_cuda_stream"], + ) + return result.get_current_cuda_stream + except Exception: + return fallback_get_current_cuda_stream + +if torch is not None: + # when torch is available, jit compile the get_current_cuda_stream function + # the torch caches the extension so second loading is faster + torch_get_current_cuda_stream = load_torch_get_current_cuda_stream() + + cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" cdef TVMFFIByteArray bytes @@ -76,9 +116,13 @@ cdef inline object make_ret(TVMFFIAny result): raise ValueError("Unhandled type index %d" % type_index) -cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except -1: +cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, + int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1: """Pack arguments into c args tvm call accept""" - cdef unsigned long long ptr + cdef unsigned long long temp_ptr + cdef DLTensor* temp_dltensor + cdef int is_cuda = 0 + for i, arg in 
@@ -76,9 +116,13 @@ cdef inline object make_ret(TVMFFIAny result):
         raise ValueError("Unhandled type index %d" % type_index)
 
 
-cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except -1:
+cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
+                          int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1:
     """Pack arguments into c args tvm call accept"""
-    cdef unsigned long long ptr
+    cdef unsigned long long temp_ptr
+    cdef DLTensor* temp_dltensor
+    cdef int is_cuda = 0
+
     for i, arg in enumerate(py_args):
         # clear the value to ensure zero padding on 32bit platforms
         if sizeof(void*) != 8:
@@ -96,10 +140,18 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
             out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
             out[i].v_ptr = (<Object>arg).chandle
         elif torch is not None and isinstance(arg, torch.Tensor):
+            is_cuda = arg.is_cuda
             arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
                               required_alignment=__dlpack_auto_import_required_alignment__)
             out[i].type_index = kTVMFFINDArray
             out[i].v_ptr = (<Object>arg).chandle
+            temp_dltensor = TVMFFINDArrayGetDLTensorPtr((<Object>arg).chandle)
+            # record the stream and device for torch context
+            if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
+                ctx_dev_type[0] = temp_dltensor.device.device_type
+                ctx_dev_id[0] = temp_dltensor.device.device_id
+                temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)
+                ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
             arg = from_dlpack(arg,
                               required_alignment=__dlpack_auto_import_required_alignment__)
@@ -177,12 +229,27 @@ cdef inline int FuncCall3(void* chandle,
     # fast path with stack alloca for less than 3 args
     cdef TVMFFIAny[3] packed_args
     cdef int nargs = len(args)
+    cdef int ctx_dev_type = -1
+    cdef int ctx_dev_id = 0
+    cdef TVMFFIStreamHandle ctx_stream = NULL
+    cdef TVMFFIStreamHandle prev_stream = NULL
     temp_args = []
-    make_args(args, &packed_args[0], temp_args)
+    make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream)
     with nogil:
+        if ctx_dev_type != -1:
+            # set the stream based on ctx stream
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
+            if c_api_ret_code[0] != 0:
+                return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(
             chandle, &packed_args[0], nargs, result
         )
+        # restore the original stream if it is not the same as the context stream
+        if ctx_dev_type != -1 and prev_stream != ctx_stream:
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
+            if c_api_ret_code[0] != 0:
+                return 0
     return 0
 
 
@@ -191,6 +258,10 @@ cdef inline int FuncCall(void* chandle,
                          TVMFFIAny* result,
                          int* c_api_ret_code) except -1:
     cdef int nargs = len(args)
+    cdef int ctx_dev_type = -1
+    cdef int ctx_dev_id = 0
+    cdef TVMFFIStreamHandle ctx_stream = NULL
+    cdef TVMFFIStreamHandle prev_stream = NULL
 
     if nargs <= 3:
         FuncCall3(chandle, args, result, c_api_ret_code)
@@ -200,10 +271,19 @@ cdef inline int FuncCall(void* chandle,
         packed_args.resize(nargs)
 
     temp_args = []
-    make_args(args, &packed_args[0], temp_args)
+    make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream)
     with nogil:
+        if ctx_dev_type != -1:
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
+            if c_api_ret_code[0] != 0:
+                return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result)
+        # restore the original stream if it is not the same as the context stream
+        if ctx_dev_type != -1 and prev_stream != ctx_stream:
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
+            if c_api_ret_code[0] != 0:
+                return 0
     return 0
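In plain Python terms, the set/call/restore protocol that both FuncCall3 and FuncCall now implement looks roughly like the sketch below; env_set_stream is a hypothetical stand-in for TVMFFIEnvSetStream, which installs a stream for a device and hands back the previously installed one:

    _current = {}

    def env_set_stream(dev_type, dev_id, stream):
        """Hypothetical stand-in for TVMFFIEnvSetStream: install `stream` as the
        environment stream for (dev_type, dev_id) and return the previous one."""
        prev = _current.get((dev_type, dev_id), 0)
        _current[(dev_type, dev_id)] = stream
        return prev

    def call_on_torch_stream(func, args, ctx):
        """ctx is the (dev_type, dev_id, stream) triple that make_args captured,
        or None when no CUDA torch.Tensor appeared among the arguments."""
        if ctx is None:
            return func(*args)
        dev_type, dev_id, stream = ctx
        prev_stream = env_set_stream(dev_type, dev_id, stream)
        try:
            return func(*args)
        finally:
            # put the original stream back unless it is already the active one
            if prev_stream != stream:
                env_set_stream(dev_type, dev_id, prev_stream)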
@@ -274,7 +354,7 @@ cdef class FieldSetter:
         cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
         cdef int nargs = 1
         temp_args = []
-        make_args((value,), &packed_args[0], temp_args)
+        make_args((value,), &packed_args[0], temp_args, NULL, NULL, NULL)
         c_api_ret_code = self.setter(field_ptr, &packed_args[0])
         # NOTE: logic is same as check_call
         # directly inline here to simplify traceback
@@ -412,7 +492,7 @@ cdef int tvm_ffi_callback(void* context,
         return -1
 
     temp_args = []
-    make_args((rv,), &temp_result, temp_args)
+    make_args((rv,), &temp_result, temp_args, NULL, NULL, NULL)
     CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp_result, result))
     return 0
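Taken together, the patch makes an FFI call that receives CUDA tensors run on whatever stream torch currently has active, instead of silently using the default stream. A hypothetical end-to-end usage; testing.nop stands in for any function registered on the C++ side:

    import torch
    from tvm import ffi as tvm_ffi

    nop = tvm_ffi.get_global_func("testing.nop")  # hypothetical registered func

    x = torch.arange(16, device="cuda")
    with torch.cuda.stream(torch.cuda.Stream()):
        # make_args records x's device and torch's current stream; the call then
        # runs with that stream installed, and the previous environment stream
        # is restored once the call returns
        nop(x, x, x)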