From 1c51bc3db8be5a8270aa0c955fe5b0c9d63288fa Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 11 Sep 2025 11:59:18 -0400 Subject: [PATCH] [FFI][ABI][REFACTOR] Enhance DLPack Exchange Speed and Behavior This PR enhances DLPack exchange by introducing DLPackPyObjectExporter, DLPackPyObjectImporter and DLPackTensorAllocator. These three function pointers will help us to speedup import/export with DLPack and also streamline the rare(but still useful sometimes) allocation inside the FFI. They can help to significantly speedup autodlpack import. They will also enable us to be able to query the allocator from env and return ffi::Tensor back to the caller environment(experimental), when a function takes torch.Tensor as argument, returned Tensor values will be converted to torch.Tensor. Also renames SetCurrentStream => SetStream to align with styles in CUDA API. Finally, we add option to select whether we release GIL, we release gil by default like ctypes, however, for short running functions it may be helpful to set func.release_gil = False --- ffi/CMakeLists.txt | 3 +- ffi/docs/get_started/quick_start.md | 4 +- ffi/examples/inline_module/main.py | 2 +- ffi/examples/quick_start/run_example.py | 2 +- ffi/examples/quick_start/src/add_one_cuda.cu | 4 +- ffi/include/tvm/ffi/c_api.h | 15 + ffi/include/tvm/ffi/container/tensor.h | 56 ++- ffi/include/tvm/ffi/extra/c_env_api.h | 31 +- ffi/licenses/LICENSE.pytorch.txt | 84 ++++ ffi/licenses/NOTICE.pytorch.txt | 456 ++++++++++++++++++ ffi/pyproject.toml | 2 +- ffi/python/tvm_ffi/__init__.py | 2 + .../tvm_ffi/_optional_torch_c_dlpack.py | 403 ++++++++++++++++ ffi/python/tvm_ffi/cython/base.pxi | 32 +- ffi/python/tvm_ffi/cython/function.pxi | 91 +++- ffi/python/tvm_ffi/cython/tensor.pxi | 70 +-- .../tvm_ffi/cython/tvm_ffi_python_helpers.h | 95 +++- ffi/python/tvm_ffi/libinfo.py | 23 + ffi/scripts/benchmark_dlpack.py | 5 +- ffi/src/ffi/extra/env_context.cc | 120 +++++ ffi/src/ffi/extra/stream_context.cc | 81 ---- ffi/tests/cpp/test_tensor.cc | 45 ++ ffi/tests/python/test_load_inline.py | 64 ++- ffi/tests/python/test_tensor.py | 22 +- .../contrib/cutlass/attention_operation.py | 8 +- .../tvm/contrib/cutlass/conv2d_operation.py | 2 +- python/tvm/contrib/cutlass/gemm_operation.py | 4 +- .../contrib/cutlass/layer_norm_operation.py | 2 +- .../tvm/contrib/cutlass/rms_norm_operation.py | 2 +- src/contrib/msc/plugin/tvm_codegen.cc | 2 +- src/runtime/contrib/cublas/cublas.cc | 2 +- .../contrib/cublas/cublas_json_runtime.cc | 2 +- src/runtime/contrib/cublas/cublas_utils.cc | 4 +- .../contrib/cudnn/cudnn_json_runtime.cc | 3 +- src/runtime/contrib/cudnn/cudnn_utils.cc | 4 +- .../contrib/cutlass/fp16_group_gemm.cuh | 2 +- src/runtime/contrib/cutlass/fp8_gemm.cu | 3 +- .../contrib/cutlass/fp8_group_gemm_sm90.cu | 3 +- .../cutlass/fp8_groupwise_scaled_gemm.cuh | 4 +- .../fp8_groupwise_scaled_group_gemm_sm100.cu | 3 +- .../contrib/hipblas/hipblas_json_runtime.cc | 2 +- src/runtime/contrib/hipblas/hipblas_utils.cc | 3 +- src/runtime/contrib/miopen/miopen_utils.cc | 3 +- src/runtime/contrib/msc/tensorrt_runtime.cc | 2 +- src/runtime/contrib/thrust/thrust.cu | 2 +- src/runtime/cuda/cuda_device_api.cc | 6 +- src/runtime/cuda/cuda_module.cc | 2 +- src/runtime/cuda/l2_cache_flush.cc | 2 +- src/runtime/device_api.cc | 5 +- src/runtime/rocm/rocm_device_api.cc | 4 +- src/runtime/rocm/rocm_module.cc | 2 +- src/runtime/vm/cuda/cuda_graph_builtin.cc | 11 +- 52 files changed, 1556 insertions(+), 250 deletions(-) create mode 100644 ffi/licenses/LICENSE.pytorch.txt create mode 100644 ffi/licenses/NOTICE.pytorch.txt create mode 100644 ffi/python/tvm_ffi/_optional_torch_c_dlpack.py create mode 100644 ffi/src/ffi/extra/env_context.cc delete mode 100644 ffi/src/ffi/extra/stream_context.cc diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index f927403cbde9..2767669bce24 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -73,7 +73,7 @@ set(tvm_ffi_extra_objs_sources "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_context.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_c_api.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/testing.cc" ) @@ -249,6 +249,7 @@ endif() install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/ffi/ DESTINATION include/tvm/ffi/) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include/ DESTINATION include/) +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h DESTINATION include/) install(TARGETS tvm_ffi_shared DESTINATION lib) # ship additional dSYM files for debugging symbols on if available if (APPLE) diff --git a/ffi/docs/get_started/quick_start.md b/ffi/docs/get_started/quick_start.md index c7cb007c7815..4861aa87b253 100644 --- a/ffi/docs/get_started/quick_start.md +++ b/ffi/docs/get_started/quick_start.md @@ -125,7 +125,7 @@ void AddOneCUDA(DLTensor* x, DLTensor* y) { // Get current CUDA stream from environment cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // Launch kernel AddOneKernel<<>>( @@ -136,7 +136,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA); ``` **Key Points:** -- We use `TVMFFIEnvGetCurrentStream` to obtain the current stream from the environement +- We use `TVMFFIEnvGetStream` to obtain the current stream from the environement - When invoking ffi Function from python end with PyTorch tensor as argument, the stream will be populated with torch's current stream. diff --git a/ffi/examples/inline_module/main.py b/ffi/examples/inline_module/main.py index b55574ae7bab..5cfcd41bec12 100644 --- a/ffi/examples/inline_module/main.py +++ b/ffi/examples/inline_module/main.py @@ -63,7 +63,7 @@ def main(): // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); diff --git a/ffi/examples/quick_start/run_example.py b/ffi/examples/quick_start/run_example.py index 456e58ce91b9..a8f4fc00a600 100644 --- a/ffi/examples/quick_start/run_example.py +++ b/ffi/examples/quick_start/run_example.py @@ -64,7 +64,7 @@ def run_add_one_cuda(): with torch.cuda.stream(stream): # tvm-ffi automatically handles DLPack compatible tensors # it also handles interactions with torch runtime - # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetCurrentStream + # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetStream # when calling the function mod.add_one_cuda(x, y) stream.synchronize() diff --git a/ffi/examples/quick_start/src/add_one_cuda.cu b/ffi/examples/quick_start/src/add_one_cuda.cu index ead2ec89a95c..52f1e7482505 100644 --- a/ffi/examples/quick_start/src/add_one_cuda.cu +++ b/ffi/examples/quick_start/src/add_one_cuda.cu @@ -46,8 +46,8 @@ void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { // Obtain the current stream from the environment // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 5d67fcd22128..a53dac4d00af 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -27,6 +27,21 @@ #include #include +/* + * \brief C-style Allocator that allocates memory for a DLPack tensor. + * \param prototype The prototype DLTensor to offer details about device and shape. + * \param out The output DLManagedTensorVersioned. + * \param error_ctx The context to set the error. + * \param SetError The function to set the error. + * \return 0 on success, -1 on failure. + * call SetError(error_ctx, kind, message) to set the error kind and message. + * \note Error propagation via SetError. + */ +typedef int (*DLPackTensorAllocator)( // + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // + void (*SetError)(void* error_ctx, const char* kind, const char* message) // +); + // Macros to do weak linking #ifdef _MSC_VER #define TVM_FFI_WEAK __declspec(selectany) diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h index 5e20b7b51df2..59dc7739ea63 100644 --- a/ffi/include/tvm/ffi/container/tensor.h +++ b/ffi/include/tvm/ffi/container/tensor.h @@ -32,6 +32,7 @@ #include #include +#include #include namespace tvm { @@ -341,7 +342,60 @@ class Tensor : public ObjectRef { return Tensor(make_object>( alloc, shape, dtype, device, std::forward(extra_args)...)); } - + /*! + * \brief Create a Tensor from a DLPackTensorAllocator + * + * This function can be used together with TVMFFIEnvSetTensorAllocator + * in the extra/c_env_api.h to create Tensor from the thread-local + * environment allocator. + * + * \code + * + * ffi::Tensor tensor = ffi::Tensor::FromDLPackAlloc( + * TVMFFIEnvGetTensorAllocator(), shape, dtype, device + * ); + * \endcode + * + * \param allocator The DLPack allocator. + * \param shape The shape of the Tensor. + * \param dtype The data type of the Tensor. + * \param device The device of the Tensor. + * \return The created Tensor. + */ + static Tensor FromDLPackAlloc(DLPackTensorAllocator allocator, ffi::Shape shape, DLDataType dtype, + DLDevice device) { + if (allocator == nullptr) { + TVM_FFI_THROW(RuntimeError) + << "FromDLPackAlloc: allocator is nullptr, " + << "likely because TVMFFIEnvSetTensorAllocator has not been called."; + } + DLTensor prototype; + prototype.device = device; + prototype.dtype = dtype; + prototype.shape = const_cast(shape.data()); + prototype.ndim = static_cast(shape.size()); + prototype.strides = nullptr; + prototype.byte_offset = 0; + prototype.data = nullptr; + DLManagedTensorVersioned* tensor = nullptr; + // error context to be used to propagate error + struct ErrorContext { + std::string kind; + std::string message; + static void SetError(void* error_ctx, const char* kind, const char* message) { + ErrorContext* error_context = static_cast(error_ctx); + error_context->kind = kind; + error_context->message = message; + } + }; + ErrorContext error_context; + int ret = (*allocator)(&prototype, &tensor, &error_context, ErrorContext::SetError); + if (ret != 0) { + throw ffi::Error(error_context.kind, error_context.message, + TVMFFITraceback(__FILE__, __LINE__, __func__, 0)); + } + return Tensor(make_object>(tensor)); + } /*! * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API. * \param tensor The input DLPack managed tensor. diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h index bd0d188155fe..3c49d79d3071 100644 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -46,12 +46,11 @@ typedef void* TVMFFIStreamHandle; * \param device_id The id of the device. * \param stream The stream to set. * \param opt_out_original_stream Output original stream if the address is not nullptr. - * \note The stream is a weak reference that is cached/owned by the module. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream); +TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream); /*! * \brief FFI function to get the current stream for a device @@ -60,7 +59,29 @@ TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id * \param device_id The id of the device. * \return The current stream of the device. */ -TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id); +TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id); + +/*! + * \brief FFI function to set the current DLPack allocator in thread-local(TLS) context + * + * \param allocator The allocator to set. + * \param write_to_global_context Whether to also set the allocator to the global context. + * \param opt_out_original_allocator Output original TLS allocator if the address is not nullptr. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator, + int write_to_global_context, + DLPackTensorAllocator* opt_out_original_allocator); + +/*! + * \brief FFI function get the current DLPack allocator stored in context. + * + * This function first queries the global context, and if not found, + * queries the thread-local context. + * + * \return The current DLPack allocator. + */ +TVM_FFI_DLL DLPackTensorAllocator TVMFFIEnvGetTensorAllocator(); /*! * \brief Check if there are any signals raised in the surrounding env. diff --git a/ffi/licenses/LICENSE.pytorch.txt b/ffi/licenses/LICENSE.pytorch.txt new file mode 100644 index 000000000000..966a609b61e5 --- /dev/null +++ b/ffi/licenses/LICENSE.pytorch.txt @@ -0,0 +1,84 @@ +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions by Tri Dao: +Copyright (c) 2024 Tri Dao. +All rights reserved. + +All contributions by Arm: +Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/ffi/licenses/NOTICE.pytorch.txt b/ffi/licenses/NOTICE.pytorch.txt new file mode 100644 index 000000000000..6effb8b5d707 --- /dev/null +++ b/ffi/licenses/NOTICE.pytorch.txt @@ -0,0 +1,456 @@ +======================================================================= +Software under third_party +======================================================================= +Software libraries under third_party are provided as github submodule +links, and their content is not part of the Caffe2 codebase. Their +licences can be found under the respective software repositories. + +======================================================================= +Earlier BSD License +======================================================================= +Early development of Caffe2 in 2015 and early 2016 is licensed under the +BSD license. The license is attached below: + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +Caffe's BSD License +======================================================================= +Some parts of the caffe2 code is derived from the original Caffe code, which is +created by Yangqing Jia and is now a BSD-licensed open-source project. The Caffe +license is as follows: + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +======================================================================= +Caffe2's Apache License +======================================================================= + +This repo contains Caffe2 code, which was previously licensed under +Apache License Version 2.0: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +======================================================================= +Cephes's 3-Clause BSD License +======================================================================= + +Code derived from implementations in the Cephes Math Library should mention +its derivation and reference the following license: + + 3-Clause BSD License for the Cephes Math Library + Copyright (c) 2018, Steven Moshier + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +SciPy's 3-Clause BSD License +======================================================================= + +Code derived from implementations in SciPy should mention its derivation +and reference the following license: + + Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================================================================= +Boost's 1.0 Software License +======================================================================= + +Code derived from implementations in Boost 1.0 should mention its +derivation and reference the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + + Permission is hereby granted, free of charge, to any person or organization + obtaining a copy of the software and accompanying documentation covered by + this license (the "Software") to use, reproduce, display, distribute, + execute, and transmit the Software, and to prepare derivative works of the + Software, and to permit third-parties to whom the Software is furnished to + do so, all subject to the following: + + The copyright notices in the Software and this entire statement, including + the above license grant, this restriction and the following disclaimer, + must be included in all copies of the Software, in whole or in part, and + all derivative works of the Software, unless such copies or derivative + works are solely in the form of machine-executable object code generated by + a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +======================================================================= +PILLOW-SIMD Software License +======================================================================= + +Code derived from implementations in PILLOW-SIMD should mention its derivation +and reference the following license: + + The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + + Pillow is the friendly PIL fork. It is + + Copyright © 2010-2022 by Alex Clark and contributors + + Like PIL, Pillow is licensed under the open source HPND License: + + By obtaining, using, and/or copying this software and/or its associated + documentation, you agree that you have read, understood, and will comply + with the following terms and conditions: + + Permission to use, copy, modify, and distribute this software and its + associated documentation for any purpose and without fee is hereby granted, + provided that the above copyright notice appears in all copies, and that + both that copyright notice and this permission notice appear in supporting + documentation, and that the name of Secret Labs AB or the author not be + used in advertising or publicity pertaining to distribution of the software + without specific, written prior permission. + + SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS + SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. + IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, + INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + PERFORMANCE OF THIS SOFTWARE. diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 0988a78d6308..11e65a9065d2 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a9" +version = "0.1.0a11" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/__init__.py b/ffi/python/tvm_ffi/__init__.py index b0ff88c6c8e1..c23e8b59fee7 100644 --- a/ffi/python/tvm_ffi/__init__.py +++ b/ffi/python/tvm_ffi/__init__.py @@ -39,6 +39,8 @@ from . import access_path from . import testing +# optional module to speedup dlpack conversion +from . import _optional_torch_c_dlpack __all__ = [ "dtype", diff --git a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py new file mode 100644 index 000000000000..f4af39302521 --- /dev/null +++ b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py @@ -0,0 +1,403 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Optional module to support faster DLPack conversion. + +This is an optional module to support faster DLPack conversion for torch. +Some of the changes are merged but not yet released, so it is used +as a stop gap to support faster DLPack conversion. + +This file contains source code from PyTorch: +License: licenses/LICENSE.pytorch.txt + +This module only serves as temp measure and will +likely be phased away and deleted after changes landed and released in pytorch. + +This module will load slowly at first time due to JITing, +subsequent calls will be much faster. +""" +import warnings +from . import libinfo + + +def load_torch_c_dlpack_extension(): + """Load the torch c dlpack extension.""" + cpp_source = """ +#include +#include +#include +#include + +using namespace std; +namespace at { +namespace { + +DLDataType getDLDataTypeForDLPackv1(const Tensor& t) { + DLDataType dtype; + dtype.lanes = 1; + dtype.bits = t.element_size() * 8; + switch (t.scalar_type()) { + case ScalarType::UInt1: + case ScalarType::UInt2: + case ScalarType::UInt3: + case ScalarType::UInt4: + case ScalarType::UInt5: + case ScalarType::UInt6: + case ScalarType::UInt7: + case ScalarType::Byte: + case ScalarType::UInt16: + case ScalarType::UInt32: + case ScalarType::UInt64: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: + case ScalarType::Char: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Double: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case ScalarType::Float: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case ScalarType::Int: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Long: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Short: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Half: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case ScalarType::Bool: + dtype.code = DLDataTypeCode::kDLBool; + break; + case ScalarType::ComplexHalf: + case ScalarType::ComplexFloat: + case ScalarType::ComplexDouble: + dtype.code = DLDataTypeCode::kDLComplex; + break; + case ScalarType::BFloat16: + dtype.code = DLDataTypeCode::kDLBfloat; + break; + case ScalarType::Float8_e5m2: + dtype.code = DLDataTypeCode::kDLFloat8_e5m2; + break; + case ScalarType::Float8_e5m2fnuz: + dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz; + break; + case ScalarType::Float8_e4m3fn: + dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn; + break; + case ScalarType::Float8_e4m3fnuz: + dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz; + break; + case ScalarType::Float8_e8m0fnu: + dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu; + break; + case ScalarType::Float4_e2m1fn_x2: + dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn; + break; + default: + TORCH_CHECK(false, "Unsupported scalar type: "); + } + return dtype; +} + +DLDevice torchDeviceToDLDeviceForDLPackv1(at::Device device) { + DLDevice ctx; + + ctx.device_id = (device.is_cuda() || device.is_privateuseone()) + ? static_cast(static_cast(device.index())) + : 0; + + switch (device.type()) { + case DeviceType::CPU: + ctx.device_type = DLDeviceType::kDLCPU; + break; + case DeviceType::CUDA: +#ifdef USE_ROCM + ctx.device_type = DLDeviceType::kDLROCM; +#else + ctx.device_type = DLDeviceType::kDLCUDA; +#endif + break; + case DeviceType::OPENCL: + ctx.device_type = DLDeviceType::kDLOpenCL; + break; + case DeviceType::HIP: + ctx.device_type = DLDeviceType::kDLROCM; + break; + case DeviceType::XPU: + ctx.device_type = DLDeviceType::kDLOneAPI; + ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device); + break; + case DeviceType::MAIA: + ctx.device_type = DLDeviceType::kDLMAIA; + break; + case DeviceType::PrivateUse1: + ctx.device_type = DLDeviceType::kDLExtDev; + break; + case DeviceType::MPS: + ctx.device_type = DLDeviceType::kDLMetal; + break; + default: + TORCH_CHECK(false, "Cannot pack tensors on " + device.str()); + } + + return ctx; +} + +template +struct ATenDLMTensor { + Tensor handle; + T tensor{}; +}; + +template +void deleter(T* arg) { + delete static_cast*>(arg->manager_ctx); +} + +// Adds version information for DLManagedTensorVersioned. +// This is a no-op for the other types. +template +void fillVersion(T* tensor) {} + +template <> +void fillVersion( + DLManagedTensorVersioned* tensor) { + tensor->flags = 0; + tensor->version.major = DLPACK_MAJOR_VERSION; + tensor->version.minor = DLPACK_MINOR_VERSION; +} + +// This function returns a shared_ptr to memory managed DLpack tensor +// constructed out of ATen tensor +template +T* toDLPackImpl(const Tensor& src) { + auto view = src; + + bool need_normalize_strides = false; + int64_t expected_stride = 1; + for (int i = src.dim() - 1; i >= 0; i--) { + // detect if we do not meet continuous pattern + // and the size is 1, so there is opportunity to normalize + if (src.stride(i) != expected_stride && src.size(i) == 1) { + need_normalize_strides = true; + break; + } + expected_stride *= src.size(i); + } + + // less common case, try normalizing the strides + if (need_normalize_strides) { + // create a new tensor with possibly normalized strides + // gh-83069 + auto shape = src.sizes(); + auto strides = src.strides().vec(); + for (int i = 0; i < src.dim(); i++) { + if (shape[i] < 2) { + strides[i] = 1; + } + } + view = src.as_strided(shape, strides, src.storage_offset()); + } + + ATenDLMTensor* atDLMTensor(new ATenDLMTensor); + atDLMTensor->handle = view; + atDLMTensor->tensor.manager_ctx = atDLMTensor; + atDLMTensor->tensor.deleter = &deleter; + atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); + atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDeviceForDLPackv1(src.device()); + atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); + atDLMTensor->tensor.dl_tensor.dtype = getDLDataTypeForDLPackv1(src); + atDLMTensor->tensor.dl_tensor.shape = const_cast(view.sizes().data()); + atDLMTensor->tensor.dl_tensor.strides = const_cast(view.strides().data()); + atDLMTensor->tensor.dl_tensor.byte_offset = 0; + fillVersion(&atDLMTensor->tensor); + return &(atDLMTensor->tensor); +} + +static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { + switch (type) { + case DLDeviceType::kDLCPU: + return at::Device(DeviceType::CPU); +#ifndef USE_ROCM + // if we are compiled under HIP, we cannot do cuda + case DLDeviceType::kDLCUDA: + return at::Device(DeviceType::CUDA, index); +#endif + case DLDeviceType::kDLOpenCL: + return at::Device(DeviceType::OPENCL, index); + case DLDeviceType::kDLROCM: +#ifdef USE_ROCM + // this looks funny, we need to return CUDA here to masquerade + return at::Device(DeviceType::CUDA, index); +#else + return at::Device(DeviceType::HIP, index); +#endif + case DLDeviceType::kDLOneAPI: + TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data."); + return at::detail::getXPUHooks().getDeviceFromPtr(data); + case DLDeviceType::kDLMAIA: + return at::Device(DeviceType::MAIA, index); + case DLDeviceType::kDLExtDev: + return at::Device(DeviceType::PrivateUse1, index); + case DLDeviceType::kDLMetal: + return at::Device(DeviceType::MPS, index); + default: + TORCH_CHECK( + false, "Unsupported device_type: ", std::to_string(type)); + } +} + + +// This function constructs a Tensor from a memory managed DLPack which +// may be represented as either: DLManagedTensor and DLManagedTensorVersioned. +template +at::Tensor fromDLPackImpl(T* src, std::function deleter) { + if (!deleter) { + deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + } + + DLTensor& dl_tensor = src->dl_tensor; + Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); + ScalarType stype = toScalarType(dl_tensor.dtype); + + if (!dl_tensor.strides) { + return at::from_blob( + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), + std::move(deleter), + at::device(device).dtype(stype), + {device}); + } + return at::from_blob( + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), + IntArrayRef(dl_tensor.strides, dl_tensor.ndim), + deleter, + at::device(device).dtype(stype), + {device}); +} + +} // namespace +} // namespace at + +int TorchDLPackPyObjectExporter(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) { + try { + py::handle handle(static_cast(py_obj)); + at::Tensor tensor = handle.cast(); + if (env_stream != nullptr && tensor.is_cuda()) { + *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream(); + } + *out = at::toDLPackImpl(tensor); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } +} + +int TorchDLPackPyObjectImporter(DLManagedTensorVersioned* src, void** py_obj_out) { + try { + at::Tensor tensor = at::fromDLPackImpl(src, nullptr); + *py_obj_out = THPVariable_Wrap(tensor); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } +} + +int TorchDLPackTensorAllocator( + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, const char* message) +) { + try { + at::IntArrayRef shape(prototype->shape, prototype->shape + prototype->ndim); + at::TensorOptions options = at::TensorOptions() + .dtype(at::toScalarType(prototype->dtype)) + .device(at::getATenDeviceForDLPackv1(prototype->device.device_type, prototype->device.device_id)); + at::Tensor tensor = at::empty(shape, options); + *out = at::toDLPackImpl(tensor); + return 0; + } catch (const std::exception& e) { + SetError(error_ctx, "TorchDLPackTensorAllocator", e.what()); + return -1; + } +} + +int64_t TorchDLPackPyObjectExporterPtr() { + return reinterpret_cast(TorchDLPackPyObjectExporter); +} + +int64_t TorchDLPackPyObjectImporterPtr() { + return reinterpret_cast(TorchDLPackPyObjectImporter); +} + +int64_t TorchDLPackTensorAllocatorPtr() { + return reinterpret_cast(TorchDLPackTensorAllocator); +} + """ + try: + # optionally import torch + import torch + from torch.utils import cpp_extension + + mod = cpp_extension.load_inline( + name="to_dlpack", + cpp_sources=cpp_source, + functions=[ + "TorchDLPackPyObjectExporterPtr", + "TorchDLPackPyObjectImporterPtr", + "TorchDLPackTensorAllocatorPtr", + ], + extra_cflags=["-O3"], + extra_include_paths=libinfo.include_paths() + cpp_extension.include_paths("cuda"), + verbose=True, + ) + # set the dlpack related flags + torch.Tensor.__c_dlpack_exporter__ = mod.TorchDLPackPyObjectExporterPtr() + torch.Tensor.__c_dlpack_importer__ = mod.TorchDLPackPyObjectImporterPtr() + torch.Tensor.__c_dlpack_tensor_allocator__ = mod.TorchDLPackTensorAllocatorPtr() + return mod + except ImportError: + pass + except Exception as e: + warnings.warn( + f"Failed to load torch c dlpack extension: {e}," + "EnvTensorAllocator will not be enabled." + ) + return None + + +# keep alive +_mod = load_torch_c_dlpack_extension() diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index 08b01d424f1f..a1de1de1cd89 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -238,27 +238,39 @@ cdef extern from "tvm/ffi/extra/c_env_api.h": ctypedef void* TVMFFIStreamHandle int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil - void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil - int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, + void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) nogil + int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, TVMFFIStreamHandle* opt_out_original_stream) nogil cdef extern from "tvm_ffi_python_helpers.h": # no need to expose fields of the call context + # setter data structure + ctypedef int (*DLPackPyObjectExporter)( + void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream + ) except -1 + + ctypedef int (*DLPackPyObjectImporter)( + DLManagedTensorVersioned* tensor, void** py_obj_out + ) except -1 + ctypedef int (*DLPackTensorAllocator)( + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, const char* message) + ) except -1 + ctypedef struct TVMFFIPyCallContext: int device_type int device_id TVMFFIStreamHandle stream - - # setter data structure - ctypedef int (*DLPackPyObjectCExporter)( - void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream - ) except -1 + DLPackPyObjectImporter c_dlpack_importer + DLPackTensorAllocator c_dlpack_tensor_allocator ctypedef struct TVMFFIPyArgSetter: int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1 - DLPackPyObjectCExporter dlpack_c_exporter + DLPackPyObjectExporter c_dlpack_exporter + DLPackPyObjectImporter c_dlpack_importer + DLPackTensorAllocator c_dlpack_tensor_allocator ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1 # The main call function @@ -267,7 +279,9 @@ cdef extern from "tvm_ffi_python_helpers.h": void* chandle, PyObject* py_arg_tuple, TVMFFIAny* result, - int* c_api_ret_code + int* c_api_ret_code, + int release_gil, + DLPackPyObjectImporter* out_dlpack_importer ) except -1 int TVMFFIPyCallFieldSetter( diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index b77b19a2eabb..bd486c5f77f5 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -29,8 +29,9 @@ else: torch = None -_torch_dlpack_c_exporter_ptr = None - +cdef int _RELEASE_GIL_BY_DEFAULT = int( + os.environ.get("TVM_FFI_RELEASE_GIL_BY_DEFAULT", "1") +) cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" @@ -46,13 +47,13 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result): return PyBytes_FromStringAndSize(bytes.data, bytes.size) -cdef inline object make_ret(TVMFFIAny result): +cdef inline object make_ret(TVMFFIAny result, DLPackPyObjectImporter c_dlpack_importer = NULL): """convert result to return value.""" cdef int32_t type_index type_index = result.type_index if type_index == kTVMFFITensor: # specially handle Tensor as it needs a special dltensor field - return make_tensor_from_any(result) + return make_tensor_from_any(result, c_dlpack_importer) elif type_index == kTVMFFIOpaquePyObject: return make_ret_opaque_object(result) elif type_index >= kTVMFFIStaticObjectBegin: @@ -120,13 +121,18 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_( cdef TVMFFIObjectHandle temp_chandle cdef TVMFFIStreamHandle env_stream = NULL + if this.c_dlpack_importer != NULL: + ctx.c_dlpack_importer = this.c_dlpack_importer + if this.c_dlpack_tensor_allocator != NULL: + ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator + if ctx.device_id != -1: # already queried device, do not do it again, pass NULL to stream - if (this.dlpack_c_exporter)(arg, &temp_managed_tensor, NULL) != 0: + if (this.c_dlpack_exporter)(arg, &temp_managed_tensor, NULL) != 0: return -1 else: # query string on the envrionment stream - if (this.dlpack_c_exporter)(arg, &temp_managed_tensor, &env_stream) != 0: + if (this.c_dlpack_exporter)(arg, &temp_managed_tensor, &env_stream) != 0: return -1 # If device is not CPU, we should set the device type and id if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU: @@ -142,17 +148,32 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_( return 0 -cdef int TVMFFIPyArgSetterTorch_( +cdef int TorchDLPackPyObjectImporterFallback_( + DLManagedTensorVersioned* dltensor, void** py_obj_out +) except -1: + # a bit convoluted but ok as a fallback + cdef TVMFFIObjectHandle temp_chandle + TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) + tensor = make_tensor_from_chandle(temp_chandle) + torch_tensor = torch.from_dlpack(tensor) + Py_INCREF(torch_tensor) + py_obj_out[0] = (torch_tensor) + return 0 + + +cdef int TVMFFIPyArgSetterTorchFallback_( TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out ) except -1: """Current setter for torch.Tensor, go through python and not as fast as c exporter""" + # TODO(tqchen): remove this once torch always support fast DLPack importer cdef object arg = py_arg is_cuda = arg.is_cuda arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg)) out.type_index = kTVMFFITensor out.v_ptr = (arg).chandle temp_dltensor = TVMFFITensorGetDLTensorPtr((arg).chandle) + ctx.c_dlpack_importer = TorchDLPackPyObjectImporterFallback_ # record the stream and device for torch context if is_cuda and ctx.device_type != -1: ctx.device_type = temp_dltensor.device.device_type @@ -180,10 +201,10 @@ cdef int TVMFFIPyArgSetterDLPack_( if (temp_dltensor.device.device_type != kDLCPU and ctx.device_type != -1): # __tvm_ffi_env_stream__ returns the expected stream that should be set - # through TVMFFIEnvSetCurrentStream when calling a TVM FFI function + # through TVMFFIEnvSetStream when calling a TVM FFI function if hasattr(arg, "__tvm_ffi_env_stream__"): # Ideally projects should directly setup their stream context API - # write through by also calling TVMFFIEnvSetCurrentStream + # write through by also calling TVMFFIEnvSetStream # so we do not need this protocol to do exchange ctx.device_type = temp_dltensor.device.device_type ctx.device_id = temp_dltensor.device.device_id @@ -349,19 +370,21 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce if isinstance(arg, ObjectRValueRef): out.func = TVMFFIPyArgSetterObjectRValueRef_ return 0 - # external tensors - if hasattr(arg, "__dlpack_c_exporter__"): - out.func = TVMFFIPyArgSetterDLPackCExporter_ - temp_ptr = arg.__dlpack_c_exporter__ - out.dlpack_c_exporter = temp_ptr - return 0 - if torch is not None and isinstance(arg, torch.Tensor): - if _torch_dlpack_c_exporter_ptr is not None: - temp_ptr = _torch_dlpack_c_exporter_ptr + if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXPORTER", "0") != "1": + # external tensors + if hasattr(arg, "__c_dlpack_exporter__"): out.func = TVMFFIPyArgSetterDLPackCExporter_ - out.dlpack_c_exporter = temp_ptr - else: - out.func = TVMFFIPyArgSetterTorch_ + temp_ptr = arg.__c_dlpack_exporter__ + out.c_dlpack_exporter = temp_ptr + if hasattr(arg, "__c_dlpack_importer__"): + temp_ptr = arg.__c_dlpack_importer__ + out.c_dlpack_importer = temp_ptr + if hasattr(arg, "__c_dlpack_tensor_allocator__"): + temp_ptr = arg.__c_dlpack_tensor_allocator__ + out.c_dlpack_tensor_allocator = temp_ptr + return 0 + if torch is not None and isinstance(arg, torch.Tensor): + out.func = TVMFFIPyArgSetterTorchFallback_ return 0 if hasattr(arg, "__dlpack__"): out.func = TVMFFIPyArgSetterDLPack_ @@ -415,13 +438,16 @@ cdef inline int ConstructorCall(void* constructor_handle, # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone result.type_index = kTVMFFINone result.v_int64 = 0 - TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory_, constructor_handle, args, &result, &c_api_ret_code) + TVMFFIPyFuncCall( + TVMFFIPyArgSetterFactory_, constructor_handle, args, &result, &c_api_ret_code, + False, NULL + ) CHECK_CALL(c_api_ret_code) handle[0] = result.v_ptr return 0 -class Function(Object): +cdef class Function(Object): """Python class that wraps a function with tvm-ffi ABI. See Also @@ -429,9 +455,22 @@ class Function(Object): tvm_ffi.register_global_func: How to register global function. tvm_ffi.get_global_func: How to get global function. """ + cdef int c_release_gil + cdef dict __dict__ + + def __cinit__(self): + self.c_release_gil = _RELEASE_GIL_BY_DEFAULT + + property release_gil: + def __get__(self): + return self.c_release_gil != 0 + def __set__(self, value): + self.c_release_gil = value + def __call__(self, *args): cdef TVMFFIAny result cdef int c_api_ret_code + cdef DLPackPyObjectImporter c_dlpack_importer = NULL # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone result.type_index = kTVMFFINone result.v_int64 = 0 @@ -439,12 +478,14 @@ class Function(Object): TVMFFIPyArgSetterFactory_, (self).chandle, args, &result, - &c_api_ret_code + &c_api_ret_code, + self.release_gil, + &c_dlpack_importer ) # NOTE: logic is same as check_call # directly inline here to simplify traceback if c_api_ret_code == 0: - return make_ret(result) + return make_ret(result, c_dlpack_importer) elif c_api_ret_code == -2: raise_existing_error() raise move_from_last_error().py_error() diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi index fca6cc0bbc08..2fd80bc1a6c8 100644 --- a/ffi/python/tvm_ffi/cython/tensor.pxi +++ b/ffi/python/tvm_ffi/cython/tensor.pxi @@ -51,9 +51,8 @@ cdef inline object _from_dlpack_intptr( cdef int c_api_ret_code cdef int c_req_alignment = 0 cdef int c_req_contiguous = 0 - with nogil: - c_api_ret_code = TVMFFITensorFromDLPack( - ptr, c_req_alignment, c_req_contiguous, &chandle) + c_api_ret_code = TVMFFITensorFromDLPack( + ptr, c_req_alignment, c_req_contiguous, &chandle) CHECK_CALL(c_api_ret_code) return make_tensor_from_chandle(chandle) @@ -68,9 +67,8 @@ cdef inline int _from_dlpack( cdef int c_req_contiguous = require_contiguous if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): ptr = pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) - with nogil: - c_api_ret_code = TVMFFITensorFromDLPack( - ptr, c_req_alignment, c_req_contiguous, out) + c_api_ret_code = TVMFFITensorFromDLPack( + ptr, c_req_alignment, c_req_contiguous, out) CHECK_CALL(c_api_ret_code) # set name and destructor to be empty pycapsule.PyCapsule_SetDestructor(dltensor, NULL) @@ -90,9 +88,8 @@ cdef inline int _from_dlpack_versioned( if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor_versioned): ptr = pycapsule.PyCapsule_GetPointer( dltensor, _c_str_dltensor_versioned) - with nogil: - c_api_ret_code = TVMFFITensorFromDLPackVersioned( - ptr, c_req_alignment, c_req_contiguous, out) + c_api_ret_code = TVMFFITensorFromDLPackVersioned( + ptr, c_req_alignment, c_req_contiguous, out) CHECK_CALL(c_api_ret_code) # set name and destructor to be empty pycapsule.PyCapsule_SetDestructor(dltensor, NULL) @@ -209,18 +206,14 @@ cdef class Tensor(Object): def _to_dlpack(self): cdef DLManagedTensor* dltensor cdef int c_api_ret_code - - with nogil: - c_api_ret_code = TVMFFITensorToDLPack(self.chandle, &dltensor) + c_api_ret_code = TVMFFITensorToDLPack(self.chandle, &dltensor) CHECK_CALL(c_api_ret_code) return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) def _to_dlpack_versioned(self): cdef DLManagedTensorVersioned* dltensor cdef int c_api_ret_code - - with nogil: - c_api_ret_code = TVMFFITensorToDLPackVersioned(self.chandle, &dltensor) + c_api_ret_code = TVMFFITensorToDLPackVersioned(self.chandle, &dltensor) CHECK_CALL(c_api_ret_code) return pycapsule.PyCapsule_New( dltensor, _c_str_dltensor_versioned, _c_dlpack_versioned_deleter) @@ -282,24 +275,24 @@ _set_class_tensor(Tensor) _register_object_by_index(kTVMFFITensor, Tensor) - -cdef int _dltensor_test_wrapper_dlpack_c_exporter( +cdef int _dltensor_test_wrapper_c_dlpack_exporter( void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream ) except -1: - cdef object ref_obj = (obj) - cdef DLTensorTestWrapper wrapper = ref_obj + cdef PyObject* py_obj = obj + cdef DLTensorTestWrapper wrapper = py_obj cdef TVMFFIStreamHandle current_stream - + cdef DLManagedTensorVersioned* temp_managed_tensor if env_stream != NULL: - env_stream[0] = TVMFFIEnvGetCurrentStream( + env_stream[0] = TVMFFIEnvGetStream( wrapper.tensor.cdltensor.device.device_type, wrapper.tensor.cdltensor.device.device_id ) + return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out) -def _dltensor_test_wrapper_dlpack_c_exporter_as_intptr(): - cdef DLPackPyObjectCExporter converter_func = _dltensor_test_wrapper_dlpack_c_exporter +def _dltensor_test_wrapper_c_dlpack_exporter_as_intptr(): + cdef DLPackPyObjectExporter converter_func = _dltensor_test_wrapper_c_dlpack_exporter cdef void* temp_ptr = converter_func cdef long long temp_int_ptr = temp_ptr return temp_int_ptr @@ -308,8 +301,10 @@ def _dltensor_test_wrapper_dlpack_c_exporter_as_intptr(): cdef class DLTensorTestWrapper: """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose. """ - __dlpack_c_exporter__ = _dltensor_test_wrapper_dlpack_c_exporter_as_intptr() + __c_dlpack_exporter__ = _dltensor_test_wrapper_c_dlpack_exporter_as_intptr() + cdef Tensor tensor + cdef dict __dict__ def __init__(self, tensor): self.tensor = tensor @@ -317,9 +312,8 @@ cdef class DLTensorTestWrapper: cdef TVMFFIStreamHandle stream cdef long long stream_as_int cdef int c_api_ret_code - with nogil: - stream = TVMFFIEnvGetCurrentStream( - self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id) + stream = TVMFFIEnvGetStream( + self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id) stream_as_int = stream return stream_as_int @@ -339,14 +333,30 @@ cdef inline object make_ret_dltensor(TVMFFIAny result): return tensor -cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle): +cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackPyObjectImporter c_dlpack_importer = NULL): # TODO: Implement cdef Tensor tensor + cdef void* py_obj + cdef DLManagedTensorVersioned* dlpack + + if c_dlpack_importer != NULL: + # try convert and import into the environment array if possible + if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0: + try: + # note that py_obj already holds an extra reference to the tensor + # so we need to decref it after the conversion + c_dlpack_importer(dlpack, &py_obj) + tensor = (py_obj) + Py_DECREF(tensor) + return tensor + except Exception: + pass + # default return the tensor tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR) (tensor).chandle = chandle (tensor).cdltensor = TVMFFITensorGetDLTensorPtr(chandle) return tensor -cdef inline object make_tensor_from_any(TVMFFIAny any): - return make_tensor_from_chandle(any.v_ptr) +cdef inline object make_tensor_from_any(TVMFFIAny any, DLPackPyObjectImporter c_dlpack_importer): + return make_tensor_from_chandle(any.v_ptr, c_dlpack_importer) diff --git a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h index 32ded385bae8..c7d847b85780 100644 --- a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h +++ b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h @@ -27,13 +27,40 @@ #include #include +#include #include +#include #include +//---------------------------------------------------------- +// Extra support for DLPack +//---------------------------------------------------------- +/*! + * \brief C-style function pointer to speed convert a PyObject Tensor to a DLManagedTensorVersioned. + * \param py_obj The Python object to convert, this should be PyObject* + * \param out The output DLManagedTensorVersioned. + * \param env_stream Outputs the current context stream of the device provided by the tensor. + * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. + * \note We use void* to avoid dependency on Python.h so this specific type is + * not dependent on Python.h and can be copied to dlpack.h + */ +typedef int (*DLPackPyObjectExporter)(void* py_obj, DLManagedTensorVersioned** out, + void** env_stream); +/*! + * \brief C-style function pointer to speed convert a DLManagedTensorVersioned to a PyObject Tensor. + * \param tensor The DLManagedTensorVersioned to convert. + * \param py_obj_out The output Python object. + * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. + * \note We use void* to avoid dependency on Python.h so this specific type is + * not dependent on Python.h and can be copied to dlpack.h + */ +typedef int (*DLPackPyObjectImporter)(DLManagedTensorVersioned* tensor, void** py_obj_out); + ///-------------------------------------------------------------------------------- /// We deliberately designed the data structure and function to be C-style // prefixed with TVMFFIPy so they can be easily invoked through Cython. ///-------------------------------------------------------------------------------- + /*! * \brief Context for each ffi call to track the stream, device and temporary arguments. */ @@ -54,20 +81,12 @@ struct TVMFFIPyCallContext { void** temp_py_objects = nullptr; /*! \brief the number of temporary arguments */ int num_temp_py_objects = 0; + /*! \brief the DLPack exporter, if any */ + DLPackPyObjectImporter c_dlpack_importer{nullptr}; + /*! \brief the DLPack allocator, if any */ + DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr}; }; -/*! - * \brief C-style function pointer to speed convert a Tensor to a DLManagedTensorVersioned. - * \param py_obj The Python object to convert, this should be PyObject* - * \param out The output DLManagedTensorVersioned. - * \param env_stream Outputs the current context stream of the device provided by the tensor. - * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. - * \note We use void* to avoid dependency on Python.h so this specific type is - * not dependent on Python.h and can be copied to dlpack.h - */ -typedef int (*DLPackPyObjectCExporter)(void* py_obj, DLManagedTensorVersioned** out, - void** env_stream); - /*! \brief Argument setter for a given python argument. */ struct TVMFFIPyArgSetter { /*! @@ -83,7 +102,15 @@ struct TVMFFIPyArgSetter { /*! * \brief Optional DLPack exporter for for setters that leverages DLPack protocol. */ - DLPackPyObjectCExporter dlpack_c_exporter{nullptr}; + DLPackPyObjectExporter c_dlpack_exporter{nullptr}; + /*! + * \brief Optional DLPack importer for for setters that leverages DLPack protocol. + */ + DLPackPyObjectImporter c_dlpack_importer{nullptr}; + /*! + * \brief Optional DLPack allocator for for setters that leverages DLPack protocol. + */ + DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr}; /*! * \brief Invoke the setter. * \param call_ctx The call context. @@ -239,11 +266,14 @@ class TVMFFIPyCallManager { * \param py_arg_tuple The arguments to the function * \param result The result of the function * \param c_api_ret_code The return code of the C-call + * \param release_gil Whether to release the GIL + * \param optional_out_dlpack_importer The DLPack importer to be used for the result * \return 0 on when there is no python error, -1 on python error * \note When an error happens on FFI side, we should return 0 and set c_api_ret_code */ int Call(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, PyObject* py_arg_tuple, - TVMFFIAny* result, int* c_api_ret_code) { + TVMFFIAny* result, int* c_api_ret_code, bool release_gil, + DLPackPyObjectImporter* optional_out_dlpack_importer) { int64_t num_args = PyTuple_Size(py_arg_tuple); if (num_args == -1) return -1; try { @@ -256,27 +286,44 @@ class TVMFFIPyCallManager { if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; } TVMFFIStreamHandle prev_stream = nullptr; + DLPackTensorAllocator prev_tensor_allocator = nullptr; // setup stream context if needed if (ctx.device_type != -1) { c_api_ret_code[0] = - TVMFFIEnvSetCurrentStream(ctx.device_type, ctx.device_id, ctx.stream, &prev_stream); + TVMFFIEnvSetStream(ctx.device_type, ctx.device_id, ctx.stream, &prev_stream); // setting failed, directly return if (c_api_ret_code[0] != 0) return 0; } + if (ctx.c_dlpack_tensor_allocator != nullptr) { + c_api_ret_code[0] = + TVMFFIEnvSetTensorAllocator(ctx.c_dlpack_tensor_allocator, 0, &prev_tensor_allocator); + if (c_api_ret_code[0] != 0) return 0; + } // call the function - // release the GIL - Py_BEGIN_ALLOW_THREADS; - c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); - Py_END_ALLOW_THREADS; + if (release_gil) { + // release the GIL + Py_BEGIN_ALLOW_THREADS; + c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); + Py_END_ALLOW_THREADS; + } else { + c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); + } // restore the original stream if (ctx.device_type != -1 && prev_stream != ctx.stream) { // always try recover first, even if error happens - if (TVMFFIEnvSetCurrentStream(ctx.device_type, ctx.device_id, prev_stream, nullptr) != 0) { + if (TVMFFIEnvSetStream(ctx.device_type, ctx.device_id, prev_stream, nullptr) != 0) { // recover failed, set python error PyErr_SetString(PyExc_RuntimeError, "Failed to recover stream"); return -1; } } + if (prev_tensor_allocator != ctx.c_dlpack_tensor_allocator) { + c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator, 0, nullptr); + if (c_api_ret_code[0] != 0) return 0; + } + if (optional_out_dlpack_importer != nullptr && ctx.c_dlpack_importer != nullptr) { + *optional_out_dlpack_importer = ctx.c_dlpack_importer; + } return 0; } catch (const std::exception& ex) { // very rare, catch c++ exception and set python error @@ -376,12 +423,16 @@ class TVMFFIPyCallManager { * \param py_arg_tuple The arguments to the function * \param result The result of the function * \param c_api_ret_code The return code of the function + * \param release_gil Whether to release the GIL + * \param out_dlpack_exporter The DLPack exporter to be used for the result * \return 0 on success, nonzero on failure */ inline int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, - PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code) { + PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, + bool release_gil = true, + DLPackPyObjectImporter* out_dlpack_importer = nullptr) { return TVMFFIPyCallManager::ThreadLocal()->Call(setter_factory, func_handle, py_arg_tuple, result, - c_api_ret_code); + c_api_ret_code, release_gil, out_dlpack_importer); } /*! diff --git a/ffi/python/tvm_ffi/libinfo.py b/ffi/python/tvm_ffi/libinfo.py index b449bc1abcf5..b02897f27917 100644 --- a/ffi/python/tvm_ffi/libinfo.py +++ b/ffi/python/tvm_ffi/libinfo.py @@ -116,6 +116,18 @@ def find_include_path(): raise RuntimeError("Cannot find include path.") +def find_python_helper_include_path(): + """Find header files for C compilation.""" + candidates = [ + os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), + os.path.join(os.path.dirname(os.path.realpath(__file__)), "cython"), + ] + for candidate in candidates: + if os.path.isfile(os.path.join(candidate, "tvm_ffi_python_helpers.h")): + return candidate + raise RuntimeError("Cannot find python helper include path.") + + def find_dlpack_include_path(): """Find dlpack header files for C compilation.""" install_include_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "include") @@ -142,3 +154,14 @@ def find_cython_lib(): for path in glob.glob(os.path.join(candidate, f"core*.{suffixes}")): return os.path.abspath(path) raise RuntimeError("Cannot find tvm cython path.") + + +def include_paths(): + """Find all include paths needed for FFI related compilation.""" + include_path = find_include_path() + python_helper_include_path = find_python_helper_include_path() + dlpack_include_path = find_dlpack_include_path() + result = [include_path, dlpack_include_path] + if python_helper_include_path != include_path: + result.append(python_helper_include_path) + return result diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py index 364afa1b5fdf..2ab85bf03559 100644 --- a/ffi/scripts/benchmark_dlpack.py +++ b/ffi/scripts/benchmark_dlpack.py @@ -436,9 +436,12 @@ def main(): ) bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) print("---------------------------------------------------") - print("Benchmark tvm_ffi.print_helper_info") + print("Debug information") print("---------------------------------------------------") tvm_ffi.core._print_debug_info() + release_gil = tvm_ffi.get_global_func("testing.nop").release_gil + print(f"TVM_FFI_RELEASE_GIL_BY_DEFAULT={int(release_gil)}") + print("---------------------------------------------------") if __name__ == "__main__": diff --git a/ffi/src/ffi/extra/env_context.cc b/ffi/src/ffi/extra/env_context.cc new file mode 100644 index 000000000000..30f9270dabc7 --- /dev/null +++ b/ffi/src/ffi/extra/env_context.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/env_context.cc + * + * \brief A minimalistic env context based on ffi values. + */ + +#include +#include + +#include + +namespace tvm { +namespace ffi { + +class EnvContext { + public: + void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, + TVMFFIStreamHandle* out_original_stream) { + if (static_cast(device_type) >= stream_table_.size()) { + stream_table_.resize(device_type + 1); + } + if (static_cast(device_id) >= stream_table_[device_type].size()) { + stream_table_[device_type].resize(device_id + 1, nullptr); + } + if (out_original_stream != nullptr) { + *out_original_stream = stream_table_[device_type][device_id]; + } + stream_table_[device_type][device_id] = stream; + } + + TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) { + if (static_cast(device_type) < stream_table_.size() && + static_cast(device_id) < stream_table_[device_type].size()) { + return stream_table_[device_type][device_id]; + } + return nullptr; + } + + DLPackTensorAllocator GetDLPackTensorAllocator() { + if (dlpack_allocator_ != nullptr) { + return dlpack_allocator_; + } + return GlobalTensorAllocator(); + } + + void SetDLPackTensorAllocator(DLPackTensorAllocator allocator, int write_to_global_context, + DLPackTensorAllocator* opt_out_original_allocator) { + dlpack_allocator_ = allocator; + if (write_to_global_context != 0) { + GlobalTensorAllocator() = allocator; + } + if (opt_out_original_allocator != nullptr) { + *opt_out_original_allocator = dlpack_allocator_; + } + dlpack_allocator_ = allocator; + } + + static EnvContext* ThreadLocal() { + static thread_local EnvContext inst; + return &inst; + } + + private: + // use static function to avoid static initialization order issue + static DLPackTensorAllocator& GlobalTensorAllocator() { // NOLINT(*) + static DLPackTensorAllocator allocator = nullptr; + return allocator; + } + std::vector> stream_table_; + DLPackTensorAllocator dlpack_allocator_ = nullptr; +}; + +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, + TVMFFIStreamHandle* out_original_stream) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::EnvContext::ThreadLocal()->SetStream(device_type, device_id, stream, + out_original_stream); + TVM_FFI_SAFE_CALL_END(); +} + +TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) { + TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); + return tvm::ffi::EnvContext::ThreadLocal()->GetStream(device_type, device_id); + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetStream); +} + +int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator, int write_to_global_context, + DLPackTensorAllocator* opt_out_original_allocator) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::EnvContext::ThreadLocal()->SetDLPackTensorAllocator(allocator, write_to_global_context, + opt_out_original_allocator); + TVM_FFI_SAFE_CALL_END(); +} + +DLPackTensorAllocator TVMFFIEnvGetTensorAllocator() { + TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); + return tvm::ffi::EnvContext::ThreadLocal()->GetDLPackTensorAllocator(); + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetTensorAllocator); +} diff --git a/ffi/src/ffi/extra/stream_context.cc b/ffi/src/ffi/extra/stream_context.cc deleted file mode 100644 index 5a6afad4c1d8..000000000000 --- a/ffi/src/ffi/extra/stream_context.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/stream_context.cc - * - * \brief A minimalistic stream context based on ffi values. - */ - -#include -#include - -#include - -namespace tvm { -namespace ffi { - -class StreamContext { - public: - void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - if (static_cast(device_type) >= stream_table_.size()) { - stream_table_.resize(device_type + 1); - } - if (static_cast(device_id) >= stream_table_[device_type].size()) { - stream_table_[device_type].resize(device_id + 1, nullptr); - } - if (out_original_stream != nullptr) { - *out_original_stream = stream_table_[device_type][device_id]; - } - stream_table_[device_type][device_id] = stream; - } - - TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) { - if (static_cast(device_type) < stream_table_.size() && - static_cast(device_id) < stream_table_[device_type].size()) { - return stream_table_[device_type][device_id]; - } - return nullptr; - } - - static StreamContext* ThreadLocal() { - static thread_local StreamContext inst; - return &inst; - } - - private: - std::vector> stream_table_; -}; - -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream, - out_original_stream); - TVM_FFI_SAFE_CALL_END(); -} - -TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::StreamContext::ThreadLocal()->GetStream(device_type, device_id); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream); -} diff --git a/ffi/tests/cpp/test_tensor.cc b/ffi/tests/cpp/test_tensor.cc index 3ad182d844f0..7c696a3429c1 100644 --- a/ffi/tests/cpp/test_tensor.cc +++ b/ffi/tests/cpp/test_tensor.cc @@ -32,6 +32,23 @@ inline Tensor Empty(Shape shape, DLDataType dtype, DLDevice device) { return Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); } +int TestDLPackTensorAllocator(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, + const char* message)) { + Shape shape(prototype->shape, prototype->shape + prototype->ndim); + Tensor nd = Empty(shape, prototype->dtype, prototype->device); + *out = nd.ToDLPackVersioned(); + return 0; +} + +int TestDLPackTensorAllocatorError(DLTensor* prototype, DLManagedTensorVersioned** out, + void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, + const char* message)) { + SetError(error_ctx, "RuntimeError", "TestDLPackTensorAllocatorError"); + return -1; +} + TEST(Tensor, Basic) { Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); Shape shape = nd.shape(); @@ -116,4 +133,32 @@ TEST(Tensor, DLPackVersioned) { } EXPECT_EQ(tensor.use_count(), 1); } + +TEST(Tensor, DLPackAlloc) { + // Test successful allocation + Tensor tensor = Tensor::FromDLPackAlloc(TestDLPackTensorAllocator, {1, 2, 3}, + DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + EXPECT_EQ(tensor.use_count(), 1); + EXPECT_EQ(tensor.shape().size(), 3); + EXPECT_EQ(tensor.shape()[0], 1); + EXPECT_EQ(tensor.shape()[1], 2); + EXPECT_EQ(tensor.shape()[2], 3); + EXPECT_EQ(tensor.dtype().code, kDLFloat); + EXPECT_EQ(tensor.dtype().bits, 32); + EXPECT_EQ(tensor.dtype().lanes, 1); + EXPECT_EQ(tensor->device.device_type, kDLCPU); + EXPECT_EQ(tensor->device.device_id, 0); + EXPECT_NE(tensor->data, nullptr); +} + +TEST(Tensor, DLPackAllocError) { + // Test error handling in DLPackAlloc + EXPECT_THROW( + { + Tensor::FromDLPackAlloc(TestDLPackTensorAllocatorError, {1, 2, 3}, + DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + }, + tvm::ffi::Error); +} + } // namespace diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py index 9a10476d8eff..89f00b1f36fd 100644 --- a/ffi/tests/python/test_load_inline.py +++ b/ffi/tests/python/test_load_inline.py @@ -186,7 +186,7 @@ def test_load_inline_cuda(): // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); @@ -202,6 +202,66 @@ def test_load_inline_cuda(): torch.testing.assert_close(x_cuda + 1, y_cuda) +@pytest.mark.skipif( + torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" +) +def test_load_inline_cuda_with_env_tensor_allocator(): + if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"): + pytest.skip("Torch does not support __c_dlpack_tensor_allocator__") + mod: Module = tvm_ffi.cpp.load_inline( + name="hello", + cpp_sources=r""" + #include + + tvm::ffi::Tensor return_add_one(DLTensor* x); + """, + cuda_sources=r""" + #include + + __global__ void AddOneKernel(float* x, float* y, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] + 1; + } + } + namespace ffi = tvm::ffi; + + ffi::Tensor return_add_one(DLTensor* x) { + // implementation of a library function + TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; + DLDataType f32_dtype{kDLFloat, 32, 1}; + TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; + // allocate a new tensor with the env tensor allocator + // it will be redirected to torch.empty when calling the function + ffi::Tensor y = ffi::Tensor::FromDLPackAlloc( + TVMFFIEnvGetTensorAllocator(), ffi::Shape({x->shape[0]}), f32_dtype, x->device); + int64_t n = x->shape[0]; + int64_t nthread_per_block = 256; + int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; + // Obtain the current stream from the environment + // it will be set to torch.cuda.current_stream() when calling the function + // with torch.Tensors + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); + // launch the kernel + AddOneKernel<<>>(static_cast(x->data), + static_cast(y->data), n); + return y; + } + """, + functions=["return_add_one"], + ) + + if torch is not None: + x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") + y_cuda = mod.return_add_one(x_cuda) + assert isinstance(y_cuda, torch.Tensor) + assert y_cuda.shape == (5,) + assert y_cuda.dtype == torch.float32 + torch.testing.assert_close(x_cuda + 1, y_cuda) + assert y_cuda.is_cuda + + @pytest.mark.skipif( torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" ) @@ -248,7 +308,7 @@ def test_load_inline_both(): // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); diff --git a/ffi/tests/python/test_tensor.py b/ffi/tests/python/test_tensor.py index aa2482f88852..5c7051279815 100644 --- a/ffi/tests/python/test_tensor.py +++ b/ffi/tests/python/test_tensor.py @@ -55,22 +55,14 @@ def test_shape_object(): assert isinstance(shape3, tvm_ffi.Shape) -@pytest.mark.skipif(torch is None, reason="Torch is not installed") +@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled") def test_tensor_auto_dlpack(): - def check(x, y): - assert isinstance(y, tvm_ffi.Tensor) - assert y.shape == (128,) - assert y.dtype == tvm_ffi.dtype("int64") - assert y.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU - assert y.device.index == 0 - x2 = torch.from_dlpack(y) - np.testing.assert_equal(x2.numpy(), x.numpy()) - x = torch.arange(128) fecho = tvm_ffi.get_global_func("testing.echo") y = fecho(x) - check(x, y) - - # pass in list of tensors - y = fecho([x]) - check(x, y[0]) + assert isinstance(y, torch.Tensor) + assert y.data_ptr() == x.data_ptr() + assert y.dtype == x.dtype + assert y.shape == x.shape + assert y.device == x.device + np.testing.assert_equal(y.numpy(), x.numpy()) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index fe29cd59459b..ff804e83460c 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -147,7 +147,7 @@ def instantiate_attention_template(attrs): } CHECK(Attention::check_supported(p)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); kernel_fn<<>>(p); @@ -185,7 +185,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${query}->data), @@ -235,7 +235,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${qkv}->data), @@ -291,7 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs): int v_row_stride = v_head_stride * ${num_kv_heads}; int o_row_stride = o_head_stride * ${num_q_heads}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_var_len_forward( static_cast(${query}->data), diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index b0afdcdd6e84..e323e2a14937 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -424,7 +424,7 @@ def instantiate_conv2d_template(attrs): TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); ${split_k_update} - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${data_arg}->device.device_id)); status = conv2d_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 453839cc8130..d8940230e0e3 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -345,7 +345,7 @@ def instantiate_gemm_template(attrs): status = gemm_op.initialize(arguments, workspace.get()); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${A_arg}->device.device_id)); status = gemm_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); @@ -428,7 +428,7 @@ def emit_fp16A_intB_matmul(attrs): int k = ${B_arg}->shape[0]; cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); + TVMFFIEnvGetStream(kDLCUDA, ${A_arg}->device.device_id)); """, attrs, ) diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py index d2a031024475..b0f7dc7c14f7 100644 --- a/python/tvm/contrib/cutlass/layer_norm_operation.py +++ b/python/tvm/contrib/cutlass/layer_norm_operation.py @@ -39,7 +39,7 @@ def instantiate_layer_norm_template(attrs): cutlass::TensorRef _beta((data_type*)${beta}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${input}->device.device_id)); cutlass::layernorm(size, _output, _input, _gamma, _beta, stream); """ diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py index 51c18d4ae47b..3d038ab21011 100644 --- a/python/tvm/contrib/cutlass/rms_norm_operation.py +++ b/python/tvm/contrib/cutlass/rms_norm_operation.py @@ -38,7 +38,7 @@ def instantiate_rms_norm_template(attrs): cutlass::TensorRef _weight((data_type*)${weight}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${input}->device.device_id)); cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps}); """ diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 373e9aaac294..ae107c06773f 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -385,7 +385,7 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& d compute_args.push_back("meta_attr"); if (device == "cuda") { // TODO(tvm-team): update to support get stream from device id - stack_.assign("stream", "TVMFFIEnvGetCurrentStream(kDLCUDA, 0)", "auto"); + stack_.assign("stream", "TVMFFIEnvGetStream(kDLCUDA, 0)", "auto"); compute_args.push_back("stream"); } CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 13f958744e61..88a0dc128df2 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -558,7 +558,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, A->device.device_id)); + static_cast(TVMFFIEnvGetStream(kDLCUDA, A->device.device_id)); CallLtIgemm(args, ret, ltHandle, stream); CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 98b05ba31995..33bdaaf0f7c0 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -91,7 +91,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { CUDA_CALL(cudaGetDevice(&device_id)); } auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 0ba654c9ebc8..f5248fde7e00 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -44,8 +44,8 @@ typedef dmlc::ThreadLocalStore CuBlasThreadStore; CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice curr_device) { CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id)); CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, stream)); return retval; } diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index fa046980e39a..48560f4306a6 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -164,8 +164,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::function op_exec = [=]() { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream)); auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index acedf7a9e2dd..f36a50a80a35 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -129,8 +129,8 @@ CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(Device curr_device, bool check_e ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; } - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id)); CUDNN_CALL(cudnnSetStream(res->handle, stream)); return res; } diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh index ffc05893cad6..0527829c528d 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh @@ -38,7 +38,7 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index 2be8c09da2dc..5c73c0cb74bd 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -42,8 +42,7 @@ template void tvm_cutlass_fp8_gemm(Tensor x, Tensor weight, Tensor workspace, Tensor alpha, Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_GE(x->ndim, 2); CHECK_EQ(weight->ndim, 2); diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index 48e68cb804f6..97f3e80e5bf0 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -46,8 +46,7 @@ void tvm_cutlass_fp8_group_gemm(Tensor x, Tensor weight, Tensor indptr, Tensor w Tensor alpha, Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh index e03366a03860..35f08efbc57c 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -40,7 +40,7 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_GE(a->ndim, 2); CHECK_EQ(scales_a->ndim, a->ndim); @@ -106,7 +106,7 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 3); CHECK_EQ(scales_a->ndim, 3); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index 420f93d4f2f3..8ac0e0452d57 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -38,8 +38,7 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommended size is 4MB. - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 2); CHECK_EQ(b->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 6e760b7f0625..f53f8f7c6a51 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -89,7 +89,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { ROCM_CALL(hipGetDevice(&device_id)); } auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(DLDevice{kDLROCM, device_id}); - hipStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + hipStream_t stream = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 1b61cbd38219..17ed9a0d936d 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -44,8 +44,7 @@ typedef dmlc::ThreadLocalStore HipBlasThreadStore; HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(DLDevice curr_device) { HipBlasThreadEntry* retval = HipBlasThreadStore::Get(); - TVMFFIStreamHandle stream = - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + TVMFFIStreamHandle stream = TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id); CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast(stream))); return retval; } diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index e860ba8ea7f2..617ea5aaf027 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -56,8 +56,7 @@ typedef dmlc::ThreadLocalStore MIOpenThreadStore; MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal(Device curr_device) { // Need to update stream per fetch to avoid stream switching MIOpenThreadEntry* res = MIOpenThreadStore::Get(); - TVMFFIStreamHandle stream = - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + TVMFFIStreamHandle stream = TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id); MIOPEN_CALL(miopenSetStream(res->handle, stream)); return res; } diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 8a837370fa34..07b190a2c0be 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -133,7 +133,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { context.Set("datas", input_datas); (*pf)(context, "before_forward", graph_name_, tool_tag_); } - auto tvm_stream = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + auto tvm_stream = TVMFFIEnvGetStream(kDLCUDA, device_id); #if TRT_VERSION_GE(6, 0, 1) ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) << "Running TensorRT failed."; diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 1adf95f69320..7eede1b65485 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -94,7 +94,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); return thrust::cuda::par_nosync(memory_resouce).on(stream); } diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 623968fedeab..f8ec539cc0dc 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -301,7 +301,7 @@ class CUDATimerNode : public TimerNode { // cudaEventRecord do some stream synchronization? int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - stream_ = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + stream_ = TVMFFIEnvGetStream(kDLCUDA, device_id); CUDA_CALL(cudaEventRecord(start_, static_cast(stream_))); } virtual void Stop() { @@ -352,10 +352,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.GetCudaFreeMemory", GetCudaFreeMemory) .def("runtime.get_cuda_stream", []() { // TODO(tvm-team): remove once confirms all dep such as flashinfer - // migrated to TVMFFIEnvGetCurrentStream + // migrated to TVMFFIEnvGetStream int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - return static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + return static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); }); }); diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 9086903d0141..9673dfa169fd 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -199,7 +199,7 @@ class CUDAWrappedFunc { } } } - CUstream strm = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 0c7f939181a2..d02f4efdb900 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); L2Flush::ThreadLocal()->Flush(stream); }); }); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index e574ce14b004..96d370dfe2e5 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -165,12 +165,11 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream, nullptr)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr)); } TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { - return TVMFFIEnvGetCurrentStream(dev.device_type, dev.device_id); + return TVMFFIEnvGetStream(dev.device_type, dev.device_id); } void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 4b042d8d491d..2ea9727b8b53 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -264,7 +264,7 @@ class ROCMTimerNode : public TimerNode { virtual void Start() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); - stream_ = TVMFFIEnvGetCurrentStream(kDLROCM, device_id); + stream_ = TVMFFIEnvGetStream(kDLROCM, device_id); ROCM_CALL(hipEventRecord(start_, static_cast(stream_))); } virtual void Stop() { @@ -302,7 +302,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.get_rocm_stream", []() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); - return static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + return static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); }); }); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 3ef9bf47a9b1..f8f7ed673f07 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -172,7 +172,7 @@ class ROCMWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } - hipStream_t strm = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + hipStream_t strm = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index 252841528152..0e8cc2090784 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -118,14 +118,13 @@ class CUDACaptureStream { explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) { CUDA_CALL(cudaGetDevice(&device_id_)); TVM_FFI_CHECK_SAFE_CALL( - TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_, - reinterpret_cast(&prev_default_stream_))); + TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_, + reinterpret_cast(&prev_default_stream_))); CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); } ~CUDACaptureStream() noexcept(false) { cudaStreamEndCapture(capture_stream_, output_graph_); - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_, nullptr)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr)); } private: @@ -159,8 +158,8 @@ class CUDAGraphExtensionNode : public VMExtensionNode { const auto& [states, exec] = it->second; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - CUDA_CALL(cudaGraphLaunch( - exec, static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)))); + CUDA_CALL( + cudaGraphLaunch(exec, static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)))); return states; }