From ba2008ba03fe0d8b89075871ec7f120b00059333 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 8 Sep 2025 07:55:11 -0400
Subject: [PATCH] [FFI] Relax default alignment and contiguous requirements

This PR relaxes the default alignment and contiguous requirements in
DLPack import. This allows the FFI to be useful in most settings. We
also provide utilities for users to check these requirements themselves.
---
 ffi/include/tvm/ffi/container/tensor.h   | 22 +++++++++---
 ffi/python/tvm_ffi/_convert.py           |  4 +--
 ffi/python/tvm_ffi/cython/function.pxi   |  5 ++-
 ffi/python/tvm_ffi/cython/tensor.pxi     | 43 ++++++++++++------------
 python/tvm/runtime/_tensor.py            | 10 +++---
 src/tir/ir/stmt.cc                       |  4 +--
 src/tir/transforms/arg_binder.cc         |  4 +--
 src/tir/transforms/lower_match_buffer.cc |  4 +--
 tests/python/relax/test_op_inspect.py    |  2 +-
 9 files changed, 54 insertions(+), 44 deletions(-)

diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h
index b5be116b491c..99fb29d10830 100644
--- a/ffi/include/tvm/ffi/container/tensor.h
+++ b/ffi/include/tvm/ffi/container/tensor.h
@@ -35,6 +35,16 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief Check if the device uses direct addressing, where the address of the data indicates its alignment.
+ * \param device The input device.
+ * \return True if the device uses direct addressing, false otherwise.
+ */
+inline bool IsDirectAddressDevice(const DLDevice& device) {
+  return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged ||
+         device.device_type == kDLROCM || device.device_type == kDLROCMHost;
+}
+
 /*!
  * \brief check if a DLTensor is contiguous.
  * \param arr The input DLTensor.
@@ -67,11 +77,7 @@ inline bool IsContiguous(const DLTensor& arr) {
  * \return True if the data is aligned to the given alignment, false otherwise.
  */
 inline bool IsAligned(const DLTensor& arr, size_t alignment) {
-  // whether the device uses direct address mapping instead of indirect buffer
-  bool direct_address = arr.device.device_type <= kDLCUDAHost ||
-                        arr.device.device_type == kDLCUDAManaged ||
-                        arr.device.device_type == kDLROCM || arr.device.device_type == kDLROCMHost;
-  if (direct_address) {
+  if (IsDirectAddressDevice(arr.device)) {
     return (reinterpret_cast<size_t>(static_cast<char*>(arr.data) + arr.byte_offset) % alignment ==
             0);
   } else {
@@ -278,6 +284,12 @@ class Tensor : public ObjectRef {
   * \return True if the Tensor is contiguous, false otherwise.
   */
  bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); }
+  /*!
+   * \brief Check if the Tensor data is aligned to the given alignment.
+   * \param alignment The alignment to check.
+   * \return True if the Tensor data is aligned to the given alignment, false otherwise.
+   */
+  bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); }
   /*!
    * \brief Create a Tensor from a NDAllocator.
    * \param alloc The NDAllocator.
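Usage sketch (not part of the diff; assumes NumPy >= 1.23 for `__dlpack__`
support and mirrors the updated test at the end of this patch): with the
relaxed defaults, DLPack import no longer rejects unaligned or non-contiguous
data, and callers opt back into the checks explicitly.

    import numpy as np
    import tvm_ffi

    # A transposed view is valid DLPack but not contiguous.
    x = np.zeros((4, 4), dtype="int32").T

    # With the relaxed defaults (require_alignment=0, require_contiguous=False)
    # the import now succeeds instead of raising.
    y = tvm_ffi.from_dlpack(x)

    # Opting back into the old guarantees: this would reject the transposed
    # view above, but accepts a freshly allocated contiguous array.
    z = tvm_ffi.from_dlpack(
        np.zeros((4, 4), dtype="int32"),
        require_alignment=8,
        require_contiguous=True,
    )
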
diff --git a/ffi/python/tvm_ffi/_convert.py b/ffi/python/tvm_ffi/_convert.py
index 168dd15b531b..b1b972633d86 100644
--- a/ffi/python/tvm_ffi/_convert.py
+++ b/ffi/python/tvm_ffi/_convert.py
@@ -61,9 +61,7 @@ def convert(value: Any) -> Any:
     elif value is None:
         return None
     elif hasattr(value, "__dlpack__"):
-        return core.from_dlpack(
-            value, required_alignment=core.__dlpack_auto_import_required_alignment__
-        )
+        return core.from_dlpack(value)
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
     else:
diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi
index 0161ec4292ab..28d4ba5a0094 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -109,8 +109,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
         out[i].v_ptr = (<Tensor>arg).chandle
     elif torch is not None and isinstance(arg, torch.Tensor):
         is_cuda = arg.is_cuda
-        arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
-                          required_alignment=__dlpack_auto_import_required_alignment__)
+        arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg))
         out[i].type_index = kTVMFFITensor
         out[i].v_ptr = (<Tensor>arg).chandle
         temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
@@ -123,7 +122,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
             ctx_stream[0] = temp_ptr
         temp_args.append(arg)
     elif hasattr(arg, "__dlpack__"):
-        arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
+        arg = from_dlpack(arg)
         out[i].type_index = kTVMFFITensor
         out[i].v_ptr = (<Tensor>arg).chandle
         temp_args.append(arg)
diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi
index b09ac42eb99c..4658422ca524 100644
--- a/ffi/python/tvm_ffi/cython/tensor.pxi
+++ b/ffi/python/tvm_ffi/cython/tensor.pxi
@@ -16,7 +16,6 @@
 # under the License.
 
 __dlpack_version__ = (1, 1)
-__dlpack_auto_import_required_alignment__ = 8
 
 _CLASS_TENSOR = None
 
@@ -45,13 +44,13 @@ cdef void _c_dlpack_versioned_deleter(object pycaps):
 
 
 cdef inline int _from_dlpack(
-    object dltensor, int required_alignment,
-    int required_contiguous, TVMFFIObjectHandle* out
+    object dltensor, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
 ) except -1:
     cdef DLManagedTensor* ptr
     cdef int c_api_ret_code
-    cdef int c_req_alignment = required_alignment
-    cdef int c_req_contiguous = required_contiguous
+    cdef int c_req_alignment = require_alignment
+    cdef int c_req_contiguous = require_contiguous
     if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
         ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
         with nogil:
@@ -66,13 +65,13 @@ cdef inline int _from_dlpack(
 
 
 cdef inline int _from_dlpack_versioned(
-    object dltensor, int required_alignment,
-    int required_contiguous, TVMFFIObjectHandle* out
+    object dltensor, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
 ) except -1:
     cdef DLManagedTensorVersioned* ptr
     cdef int c_api_ret_code
-    cdef int c_req_alignment = required_alignment
-    cdef int c_req_contiguous = required_contiguous
+    cdef int c_req_alignment = require_alignment
+    cdef int c_req_contiguous = require_contiguous
     if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor_versioned):
         ptr = <DLManagedTensorVersioned*>pycapsule.PyCapsule_GetPointer(
             dltensor, _c_str_dltensor_versioned)
@@ -87,7 +86,7 @@ cdef inline int _from_dlpack_versioned(
     raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once")
 
 
-def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
+def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False):
     """
     Convert an external tensor to an Tensor.
@@ -96,10 +95,10 @@
     ext_tensor : object
         The external tensor to convert.
 
-    required_alignment : int
+    require_alignment : int
         The minimum required alignment to check for the tensor.
 
-    required_contiguous : bool
+    require_contiguous : bool
         Whether to check for contiguous memory.
 
     Returns
@@ -116,38 +115,38 @@
         if favor_legacy_dlpack:
             _from_dlpack(
                 ext_tensor.__dlpack__(),
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
             )
         else:
             try:
                 _from_dlpack_versioned(
                     ext_tensor.__dlpack__(max_version=__dlpack_version__),
-                    required_alignment,
-                    required_contiguous,
+                    require_alignment,
+                    require_contiguous,
                     &chandle
                 )
             except TypeError:
                 _from_dlpack(
                     ext_tensor.__dlpack__(),
-                    required_alignment,
-                    required_contiguous,
+                    require_alignment,
+                    require_contiguous,
                     &chandle
                 )
     else:
         if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned):
             _from_dlpack_versioned(
                 ext_tensor,
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
            )
        elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor):
            _from_dlpack(
                ext_tensor,
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                &chandle
            )
        else:
diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py
index fc176bf60097..3affbf55d563 100644
--- a/python/tvm/runtime/_tensor.py
+++ b/python/tvm/runtime/_tensor.py
@@ -44,16 +44,18 @@ def from_dlpack(ext_tensor):
     ext_tensor : object
         The external tensor to convert.
 
-    required_alignment : int
+    require_alignment : int
         The minimum required alignment to check for the tensor.
 
-    required_contiguous : bool
+    require_contiguous : bool
         Whether to check for contiguous memory.
     """
+    # TODO(tvm-team): change to require_alignment=0 and require_contiguous=False
+    # once we update the compiler-generated code to guard against misaligned access.
     return tvm_ffi.from_dlpack(
         ext_tensor,
-        required_alignment=64,
-        required_contiguous=True,
+        require_alignment=64,
+        require_contiguous=True,
     )
 
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 305dd5ec9af6..6674de5260f5 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -603,8 +603,8 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
   // Check data_alignment
   CHECK(source_buffer->data_alignment % buffer->data_alignment == 0)
       << "Trying to match buffer to another one with lower alignment requirement "
-      << " required_alignment=" << buffer->data_alignment
-      << ", provided_alignment=" << source_buffer->data_alignment;
+      << " required alignment=" << buffer->data_alignment
+      << ", provided alignment=" << source_buffer->data_alignment;
 
   // Check BufferType. AutoBroadcast is not allowed for now.
   CHECK(buffer->buffer_type == BufferType::kDefault &&
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 5b9e005b7ea3..72dab7826c62 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -93,8 +93,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
       << "Argument " << arg_name << " Buffer bind data type mismatch";
   if (value->data_alignment % arg->data_alignment != 0) {
     LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
-                 << " required_alignment=" << arg->data_alignment
-                 << ", provided_alignment=" << value->data_alignment;
+                 << " required alignment=" << arg->data_alignment
+                 << ", provided alignment=" << value->data_alignment;
   }
 
   if (value->elem_offset.defined()) {
diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc
index d301e910f922..afbf57a0cd8a 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -152,8 +152,8 @@ class MatchBufferLower : public StmtExprMutator {
     // Step.1.2. Check data alignment
     if (source_buffer->data_alignment % buffer->data_alignment != 0) {
       LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
-                   << " required_alignment=" << buffer->data_alignment
-                   << ", provided_alignment=" << source_buffer->data_alignment;
+                   << " required alignment=" << buffer->data_alignment
+                   << ", provided alignment=" << source_buffer->data_alignment;
     }
     if (is_zero(buffer->elem_offset)) {
       ICHECK(is_zero(source_buffer->elem_offset))
diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py
index cb9b2ded972e..2e6d81c613d5 100644
--- a/tests/python/relax/test_op_inspect.py
+++ b/tests/python/relax/test_op_inspect.py
@@ -171,7 +171,7 @@ def main(A: R.Tensor, axis: R.Prim("int64")):
 
     expected_strides = [1, 4]
     # use transpose to make strides non-compact
     x = np.zeros([4, 4], "int32").T
-    y = tvm_ffi.from_dlpack(x, required_alignment=4, required_contiguous=False)
+    y = tvm_ffi.from_dlpack(x, require_alignment=4, require_contiguous=False)
 
     res = [vm["main"](y, i) for i, _ in enumerate(view_shape)]
     tvm.ir.assert_structural_equal(res, expected_strides)
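
Callers that depended on the previous strict defaults can keep them by
requesting the checks explicitly, following the same pattern as
python/tvm/runtime/_tensor.py above. A minimal sketch (the helper name
strict_from_dlpack is illustrative, not part of the patch):

    import tvm_ffi

    def strict_from_dlpack(ext_tensor):
        # Keep the pre-patch guarantees: compiler-generated code may assume
        # 64-byte aligned, contiguous inputs, so request the checks explicitly
        # instead of relying on the now-relaxed defaults.
        return tvm_ffi.from_dlpack(
            ext_tensor,
            require_alignment=64,
            require_contiguous=True,
        )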