diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h
index b5be116b491c..99fb29d10830 100644
--- a/ffi/include/tvm/ffi/container/tensor.h
+++ b/ffi/include/tvm/ffi/container/tensor.h
@@ -35,6 +35,16 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief Check if the device uses direct addressing, where the address of the data indicates its alignment.
+ * \param device The input device.
+ * \return True if the device uses direct addressing, false otherwise.
+ */
+inline bool IsDirectAddressDevice(const DLDevice& device) {
+  return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged ||
+         device.device_type == kDLROCM || device.device_type == kDLROCMHost;
+}
+
 /*!
  * \brief check if a DLTensor is contiguous.
  * \param arr The input DLTensor.
@@ -67,11 +77,7 @@ inline bool IsContiguous(const DLTensor& arr) {
  * \return True if the data is aligned to the given alignment, false otherwise.
  */
 inline bool IsAligned(const DLTensor& arr, size_t alignment) {
-  // whether the device uses direct address mapping instead of indirect buffer
-  bool direct_address = arr.device.device_type <= kDLCUDAHost ||
-                        arr.device.device_type == kDLCUDAManaged ||
-                        arr.device.device_type == kDLROCM || arr.device.device_type == kDLROCMHost;
-  if (direct_address) {
+  if (IsDirectAddressDevice(arr.device)) {
     return (reinterpret_cast<size_t>(static_cast<char*>(arr.data) + arr.byte_offset) % alignment ==
             0);
   } else {
@@ -278,6 +284,12 @@ class Tensor : public ObjectRef {
    * \return True if the Tensor is contiguous, false otherwise.
    */
   bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); }
+  /*!
+   * \brief Check if the Tensor data is aligned to the given alignment.
+   * \param alignment The alignment to check.
+   * \return True if the Tensor data is aligned to the given alignment, false otherwise.
+   */
+  bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); }
   /*!
    * \brief Create a Tensor from a NDAllocator.
    * \param alloc The NDAllocator.
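The distinction the new helper captures: on direct-address devices (CPU, CUDA, ROCm, and their host-pinned variants) the data pointer itself carries alignment information, while opaque-handle devices only have a meaningful `byte_offset`. Below is a rough Python sketch of the direct-address branch for a host NumPy array; `numpy_is_aligned` is an illustrative helper, not part of any API in this patch.

```python
import numpy as np

def numpy_is_aligned(arr: np.ndarray, alignment: int) -> bool:
    """Mirror the direct-address branch of tvm::ffi::IsAligned for host data."""
    # __array_interface__["data"][0] is the raw base address of the buffer;
    # for a direct-address device this pointer (plus byte_offset) is what is checked.
    data_ptr = arr.__array_interface__["data"][0]
    return data_ptr % alignment == 0

x = np.zeros((4, 4), dtype="float32")
print(numpy_is_aligned(x, 8))   # NumPy buffers are at least 16-byte aligned in practice
print(numpy_is_aligned(x, 64))  # may be False: 64-byte alignment is not guaranteed
```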
diff --git a/ffi/python/tvm_ffi/_convert.py b/ffi/python/tvm_ffi/_convert.py
index 168dd15b531b..b1b972633d86 100644
--- a/ffi/python/tvm_ffi/_convert.py
+++ b/ffi/python/tvm_ffi/_convert.py
@@ -61,9 +61,7 @@ def convert(value: Any) -> Any:
     elif value is None:
         return None
     elif hasattr(value, "__dlpack__"):
-        return core.from_dlpack(
-            value, required_alignment=core.__dlpack_auto_import_required_alignment__
-        )
+        return core.from_dlpack(value)
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
     else:
diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi
index 0161ec4292ab..28d4ba5a0094 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -109,8 +109,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
             out[i].v_ptr = (<Tensor>arg).chandle
         elif torch is not None and isinstance(arg, torch.Tensor):
             is_cuda = arg.is_cuda
-            arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
-                              required_alignment=__dlpack_auto_import_required_alignment__)
+            arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg))
             out[i].type_index = kTVMFFITensor
             out[i].v_ptr = (<Tensor>arg).chandle
             temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
@@ -123,7 +122,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
                 ctx_stream[0] = temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
-            arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
+            arg = from_dlpack(arg)
             out[i].type_index = kTVMFFITensor
             out[i].v_ptr = (<Tensor>arg).chandle
             temp_args.append(arg)
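The net effect of the two hunks above: tensors auto-imported at the FFI boundary (via `convert` or as call arguments) no longer have to satisfy the old fixed 8-byte alignment requirement. A minimal sketch, assuming a build of `tvm_ffi` with this patch applied:

```python
import numpy as np
import tvm_ffi

# Any object exposing __dlpack__ is now imported as-is; the previous code
# passed required_alignment=8 on this path and could reject unusual buffers.
x = np.arange(16, dtype="int32").reshape(4, 4)
t = tvm_ffi.convert(x)  # zero-copy DLPack import, no alignment check
print(type(t))
```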
diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi
index b09ac42eb99c..4658422ca524 100644
--- a/ffi/python/tvm_ffi/cython/tensor.pxi
+++ b/ffi/python/tvm_ffi/cython/tensor.pxi
@@ -16,7 +16,6 @@
 # under the License.
 
 __dlpack_version__ = (1, 1)
-__dlpack_auto_import_required_alignment__ = 8
 
 _CLASS_TENSOR = None
 
@@ -45,13 +44,13 @@ cdef void _c_dlpack_versioned_deleter(object pycaps):
 
 
 cdef inline int _from_dlpack(
-    object dltensor, int required_alignment,
-    int required_contiguous, TVMFFIObjectHandle* out
+    object dltensor, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
 ) except -1:
     cdef DLManagedTensor* ptr
     cdef int c_api_ret_code
-    cdef int c_req_alignment = required_alignment
-    cdef int c_req_contiguous = required_contiguous
+    cdef int c_req_alignment = require_alignment
+    cdef int c_req_contiguous = require_contiguous
     if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
         ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
         with nogil:
@@ -66,13 +65,13 @@ cdef inline int _from_dlpack(
 
 
 cdef inline int _from_dlpack_versioned(
-    object dltensor, int required_alignment,
-    int required_contiguous, TVMFFIObjectHandle* out
+    object dltensor, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
 ) except -1:
     cdef DLManagedTensorVersioned* ptr
     cdef int c_api_ret_code
-    cdef int c_req_alignment = required_alignment
-    cdef int c_req_contiguous = required_contiguous
+    cdef int c_req_alignment = require_alignment
+    cdef int c_req_contiguous = require_contiguous
     if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor_versioned):
         ptr = <DLManagedTensorVersioned*>pycapsule.PyCapsule_GetPointer(
             dltensor, _c_str_dltensor_versioned)
@@ -87,7 +86,7 @@ cdef inline int _from_dlpack_versioned(
     raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once")
 
 
-def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
+def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False):
     """
     Convert an external tensor to an Tensor.
@@ -96,10 +95,10 @@ def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
     ext_tensor : object
         The external tensor to convert.
 
-    required_alignment : int
+    require_alignment : int
         The minimum required alignment to check for the tensor.
 
-    required_contiguous : bool
+    require_contiguous : bool
         Whether to check for contiguous memory.
 
     Returns
@@ -116,38 +115,38 @@ def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
         if favor_legacy_dlpack:
             _from_dlpack(
                 ext_tensor.__dlpack__(),
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
             )
         else:
             try:
                 _from_dlpack_versioned(
                     ext_tensor.__dlpack__(max_version=__dlpack_version__),
-                    required_alignment,
-                    required_contiguous,
+                    require_alignment,
+                    require_contiguous,
                     &chandle
                 )
             except TypeError:
                 _from_dlpack(
                     ext_tensor.__dlpack__(),
-                    required_alignment,
-                    required_contiguous,
+                    require_alignment,
+                    require_contiguous,
                     &chandle
                 )
     else:
         if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned):
             _from_dlpack_versioned(
                 ext_tensor,
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
            )
         elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor):
             _from_dlpack(
                 ext_tensor,
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
             )
         else:
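With the rename and the new permissive defaults, checks on `tvm_ffi.from_dlpack` become opt-in. A usage sketch; the strided case mirrors the updated test later in this diff:

```python
import numpy as np
import tvm_ffi

x = np.zeros((4, 4), dtype="int32").T  # transpose => non-contiguous strides

# Defaults are now require_alignment=0 and require_contiguous=False,
# so a strided view imports without complaint.
y = tvm_ffi.from_dlpack(x)

# Checks can still be requested explicitly with the renamed keywords.
y = tvm_ffi.from_dlpack(x, require_alignment=4, require_contiguous=False)
```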
diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py
index fc176bf60097..3affbf55d563 100644
--- a/python/tvm/runtime/_tensor.py
+++ b/python/tvm/runtime/_tensor.py
@@ -44,16 +44,18 @@ def from_dlpack(ext_tensor):
     ext_tensor : object
         The external tensor to convert.
 
-    required_alignment : int
+    require_alignment : int
         The minimum required alignment to check for the tensor.
 
-    required_contiguous : bool
+    require_contiguous : bool
         Whether to check for contiguous memory.
     """
+    # TODO(tvm-team): change to require_alignment=0 and require_contiguous=False
+    # once we update the compiler-generated code to guard against misaligned access.
     return tvm_ffi.from_dlpack(
         ext_tensor,
-        required_alignment=64,
-        required_contiguous=True,
+        require_alignment=64,
+        require_contiguous=True,
     )
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 305dd5ec9af6..6674de5260f5 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -603,8 +603,8 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
   // Check data_alignment
   CHECK(source_buffer->data_alignment % buffer->data_alignment == 0)
       << "Trying to match buffer to another one with lower alignment requirement "
-      << " required_alignment=" << buffer->data_alignment
-      << ", provided_alignment=" << source_buffer->data_alignment;
+      << " required alignment=" << buffer->data_alignment
+      << ", provided alignment=" << source_buffer->data_alignment;
 
   // Check BufferType. AutoBroadcast is not allowed for now.
   CHECK(buffer->buffer_type == BufferType::kDefault &&
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 5b9e005b7ea3..72dab7826c62 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -93,8 +93,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
       << "Argument " << arg_name << " Buffer bind data type mismatch";
   if (value->data_alignment % arg->data_alignment != 0) {
     LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
-                 << " required_alignment=" << arg->data_alignment
-                 << ", provided_alignment=" << value->data_alignment;
+                 << " required alignment=" << arg->data_alignment
+                 << ", provided alignment=" << value->data_alignment;
   }
 
   if (value->elem_offset.defined()) {
diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc
index d301e910f922..afbf57a0cd8a 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -152,8 +152,8 @@ class MatchBufferLower : public StmtExprMutator {
     // Step.1.2. Check data alignment
     if (source_buffer->data_alignment % buffer->data_alignment != 0) {
       LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
-                   << " required_alignment=" << buffer->data_alignment
-                   << ", provided_alignment=" << source_buffer->data_alignment;
+                   << " required alignment=" << buffer->data_alignment
+                   << ", provided alignment=" << source_buffer->data_alignment;
     }
     if (is_zero(buffer->elem_offset)) {
       ICHECK(is_zero(source_buffer->elem_offset))
diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py
index cb9b2ded972e..2e6d81c613d5 100644
--- a/tests/python/relax/test_op_inspect.py
+++ b/tests/python/relax/test_op_inspect.py
@@ -171,7 +171,7 @@ def main(A: R.Tensor, axis: R.Prim("int64")):
     expected_strides = [1, 4]
     # use transpose to make strides non-compact
     x = np.zeros([4, 4], "int32").T
-    y = tvm_ffi.from_dlpack(x, required_alignment=4, required_contiguous=False)
+    y = tvm_ffi.from_dlpack(x, require_alignment=4, require_contiguous=False)
 
     res = [vm["main"](y, i) for i, _ in enumerate(view_shape)]
     tvm.ir.assert_structural_equal(res, expected_strides)
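Note that `python/tvm/runtime/_tensor.py` deliberately keeps the strict settings until the TODO above is resolved, so the two entry points now behave differently. A sketch of the contrast, assuming a TVM build containing this change and that `from_dlpack` is reachable as `tvm.runtime.from_dlpack`:

```python
import numpy as np
import tvm
import tvm_ffi

x = np.zeros((8, 8), dtype="float32")

t0 = tvm_ffi.from_dlpack(x.T)  # permissive low-level import: strided view is fine

# tvm.runtime.from_dlpack still pins require_alignment=64 and
# require_contiguous=True, so the same strided view should be rejected.
try:
    t1 = tvm.runtime.from_dlpack(x.T)
except Exception as err:
    print("strict import rejected the view:", err)
```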