diff --git a/.spdx-ignore b/.spdx-ignore index 7263b5414f..866b2274e0 100644 --- a/.spdx-ignore +++ b/.spdx-ignore @@ -10,5 +10,7 @@ cuda_bindings/examples/* # Vendored cuda_core/cuda/core/_include/dlpack.h +cuda_core/cuda/core/_include/aoti_shim.h +cuda_core/cuda/core/_include/aoti_shim.def qa/ctk-next.drawio.svg diff --git a/cuda_core/build_hooks.py b/cuda_core/build_hooks.py index 16d393344b..444da18eb1 100644 --- a/cuda_core/build_hooks.py +++ b/cuda_core/build_hooks.py @@ -11,6 +11,7 @@ import glob import os import re +import subprocess import sys import tempfile import zipfile @@ -182,6 +183,25 @@ def get_sources(mod_name): # related to free-threading builds. extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"] + # On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC + # linker can resolve the AOTI symbols (they live in torch_cpu.dll at + # runtime). We generate the .lib from a .def file at build time. + _aoti_extra_link_args = [] + if sys.platform == "win32": + _def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def") + _lib_file = os.path.join("build", "aoti_shim.lib") + os.makedirs("build", exist_ok=True) + subprocess.check_call( # noqa: S603 + ["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"], # noqa: S607 + stdout=subprocess.DEVNULL, + ) + _aoti_extra_link_args = [_lib_file] + + def get_extra_link_args(mod_name): + if mod_name == "_tensor_bridge" and _aoti_extra_link_args: + return extra_link_args + _aoti_extra_link_args + return extra_link_args + ext_modules = tuple( Extension( f"cuda.core.{mod.replace(os.path.sep, '.')}", @@ -193,7 +213,7 @@ def get_sources(mod_name): + all_include_dirs, language="c++", extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, + extra_link_args=get_extra_link_args(mod), ) for mod in module_names() ) diff --git a/cuda_core/cuda/core/_include/aoti_shim.def b/cuda_core/cuda/core/_include/aoti_shim.def new file mode 100644 index 0000000000..7fecf15276 --- /dev/null +++ b/cuda_core/cuda/core/_include/aoti_shim.def @@ -0,0 +1,29 @@ +; Stub import library definition for PyTorch's AOTI stable C ABI symbols. +; Used on Windows only: 'lib /DEF:aoti_shim.def /OUT:aoti_shim.lib /MACHINE:X64' +; generates a minimal import library that satisfies the MSVC linker. +; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch'). +LIBRARY torch_cpu.dll +EXPORTS + aoti_torch_get_data_ptr + aoti_torch_get_dim + aoti_torch_get_sizes + aoti_torch_get_strides + aoti_torch_get_dtype + aoti_torch_dtype_float16 + aoti_torch_dtype_float32 + aoti_torch_dtype_float64 + aoti_torch_dtype_bfloat16 + aoti_torch_dtype_uint8 + aoti_torch_dtype_int8 + aoti_torch_dtype_int16 + aoti_torch_dtype_int32 + aoti_torch_dtype_int64 + aoti_torch_dtype_bool + aoti_torch_dtype_complex32 + aoti_torch_dtype_complex64 + aoti_torch_dtype_complex128 + aoti_torch_get_device_type + aoti_torch_get_device_index + aoti_torch_device_type_cpu + aoti_torch_device_type_cuda + aoti_torch_get_current_cuda_stream diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h new file mode 100644 index 0000000000..2b3aa2b53d --- /dev/null +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -0,0 +1,106 @@ +/* + * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI. + * Original: torch/csrc/inductor/aoti_torch/c/shim.h + * + * These are declarations only -- no definitions are provided. 
The actual
+ * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL)
+ * and resolved at runtime by the dynamic linker. This means PyTorch is
+ * NOT required at compile time.
+ *
+ * From PyTorch:
+ *
+ * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ * Copyright (c) 2011-2013 NYU (Clement Farabet)
+ * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ * See https://github.com/pytorch/pytorch/blob/main/LICENSE
+ */
+
+#ifndef CUDA_CORE_AOTI_SHIM_H
+#define CUDA_CORE_AOTI_SHIM_H
+
+#include <stdint.h>
+
+/*
+ * On Windows the AOTI symbols live in torch_cpu.dll. We consume them
+ * via __declspec(dllimport) and a stub import library generated from
+ * aoti_shim.def at build time. On Linux/macOS the symbols are made
+ * visible at runtime through ctypes.CDLL(torch._C, RTLD_GLOBAL).
+ */
+#ifdef _WIN32
+#  define AOTI_SHIM_API __declspec(dllimport)
+#else
+#  define AOTI_SHIM_API
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef int32_t AOTITorchError;
+
+/* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */
+struct AtenTensorOpaque;
+typedef struct AtenTensorOpaque* AtenTensorHandle;
+
+/* ---- tensor metadata --------------------------------------------------- */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_data_ptr(
+    AtenTensorHandle tensor, void** ret_data_ptr);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_dim(
+    AtenTensorHandle tensor, int64_t* ret_dim);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_sizes(
+    AtenTensorHandle tensor, int64_t** ret_sizes);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_strides(
+    AtenTensorHandle tensor, int64_t** ret_strides);
+
+/* ---- dtype ------------------------------------------------------------- */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_dtype(
+    AtenTensorHandle tensor, int32_t* ret_dtype);
+
+AOTI_SHIM_API int32_t aoti_torch_dtype_float16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_float32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_float64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_bfloat16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_uint8(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int8(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int16(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_int64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_bool(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_complex32(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_complex64(void);
+AOTI_SHIM_API int32_t aoti_torch_dtype_complex128(void);
+
+/* ---- device ------------------------------------------------------------ */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_device_type(
+    AtenTensorHandle tensor, int32_t* ret_device_type);
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
+    AtenTensorHandle tensor, int32_t* ret_device_index);
+
+AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
+AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);
+
+/* ---- stream ------------------------------------------------------------ */
+
+AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(
+    int32_t device_index, void** ret_stream);
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
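+
+/*
+ * Illustrative usage (editor's sketch, not from the upstream header).
+ * Every AOTITorchError-returning call yields 0 on success:
+ *
+ *   AtenTensorHandle h = ...;  // see _tensor_bridge.pyx for how this is
+ *   void* data;                // obtained from a Python torch.Tensor
+ *   int64_t ndim;
+ *   int64_t* sizes;
+ *   if (aoti_torch_get_data_ptr(h, &data) == 0 &&
+ *       aoti_torch_get_dim(h, &ndim) == 0 &&
+ *       aoti_torch_get_sizes(h, &sizes) == 0) {
+ *       // sizes points at ndim int64_t extents owned by the tensor
+ *   }
+ */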
+
+#endif /* CUDA_CORE_AOTI_SHIM_H */
diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx
index e0439ef23c..3ebde8dcff 100644
--- a/cuda_core/cuda/core/_memoryview.pyx
+++ b/cuda_core/cuda/core/_memoryview.pyx
@@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t
 from cuda.core._layout cimport _StridedLayout, get_strides_ptr
 from cuda.core._stream import Stream
 
+import ctypes
 import functools
+import sys
 import warnings
 
 import numpy
@@ -29,6 +31,73 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
 
 from cuda.core._memory import Buffer
 
+
+# ---------------------------------------------------------------------------
+# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used)
+# ---------------------------------------------------------------------------
+
+cdef object _tensor_bridge = None
+# Cache: type(obj) -> True/False for the torch tensor check.
+# Once a type is seen, we never re-check.
+cdef dict _torch_type_cache = {}
+# Tri-state: None = not checked, True/False = result of version check
+cdef object _torch_version_ok = None
+
+cdef inline bint _torch_version_check():
+    """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized.
+
+    Lower bound: AOTI functions we use were introduced in PyTorch 2.3.
+    Upper bound: the ``pyobj_to_aten_handle`` trick relies on the
+    THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata)
+    and the identity ``AtenTensorHandle == at::Tensor*``. Both are
+    undocumented internals that could change in a future PyTorch version.
+    We cap at the latest version we have tested against; unknown versions
+    fall back to the standard DLPack/CAI paths. Bump the upper bound
+    after verifying a new PyTorch release.
+    """
+    global _torch_version_ok
+    if _torch_version_ok is not None:
+        return _torch_version_ok
+    torch = sys.modules.get("torch")
+    if torch is None:
+        _torch_version_ok = False
+        return False
+    try:
+        parts = torch.__version__.split(".")
+        major, minor = int(parts[0]), int(parts[1])
+        _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11)
+    except (ValueError, IndexError):
+        _torch_version_ok = False
+    return _torch_version_ok
+
+
+cdef inline bint _is_torch_tensor(object obj):
+    cdef type tp = type(obj)
+    cdef object cached = _torch_type_cache.get(tp)
+    if cached is not None:
+        return cached
+    cdef str mod = tp.__module__ or ""
+    cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \
+        and _torch_version_check()
+    _torch_type_cache[tp] = result
+    return result
+
+
+cdef object _get_tensor_bridge():
+    """Bootstrap AOTI symbols, then import _tensor_bridge on first use."""
+    global _tensor_bridge
+    if _tensor_bridge is not None:
+        return _tensor_bridge
+    torch_C = sys.modules.get("torch._C")
+    if torch_C is None:
+        raise RuntimeError(
+            "torch._C is not loaded; cannot initialise the tensor bridge. "
+            "Make sure PyTorch is imported before passing a torch.Tensor.")
+    ctypes.CDLL(torch_C.__file__, mode=ctypes.RTLD_GLOBAL)
+    from cuda.core import _tensor_bridge as tb
+    _tensor_bridge = tb
+    return _tensor_bridge
+
+
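+# Illustrative call flow (editor's sketch; names as defined above):
+#
+#     import torch                    # user code; loads torch._C
+#     t = torch.randn(3, device="cuda")
+#     smv = StridedMemoryView.from_any_interface(t, stream_ptr=-1)
+#       -> _is_torch_tensor(t)        # True, memoized per type
+#       -> _get_tensor_bridge()       # re-opens torch._C with RTLD_GLOBAL,
+#                                     # then imports cuda.core._tensor_bridge
+#       -> view_as_torch_tensor(...)  # fills smv via the AOTI stable C ABI
+
+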
""" cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf view_as_dlpack(obj, stream_ptr, buf) return buf @@ -165,6 +237,9 @@ cdef class StridedMemoryView: Stream pointer for synchronization. If ``None``, no synchronization is performed. """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf view_as_cai(obj, stream_ptr, buf) return buf @@ -178,6 +253,9 @@ cdef class StridedMemoryView: An object implementing the `__array_interface__ `_ protocol (e.g., a numpy array). """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, None, buf) + return buf view_as_array_interface(obj, buf) return buf @@ -187,6 +265,8 @@ cdef class StridedMemoryView: Tries `DLPack `_ first, then falls back to `__cuda_array_interface__ `_. + ``torch.Tensor`` objects are transparently handled via a fast AOTI path + regardless of which protocol is selected. Parameters ---------- @@ -480,6 +560,10 @@ cdef class StridedMemoryView: if self._dtype is None: if self.dl_tensor != NULL: self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) + elif isinstance(self.metadata, int): + # AOTI dtype code stored by the torch tensor bridge + self._dtype = _get_tensor_bridge().resolve_aoti_dtype( + self.metadata) elif self.metadata is not None: self._dtype = _typestr2dtype(self.metadata["typestr"]) return self._dtype @@ -1122,6 +1206,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): as_cu(h_event), producer_s)) HANDLE_RETURN(cydriver.cuStreamWaitEvent( consumer_s, as_cu(h_event), 0)) + elif _is_torch_tensor(obj): + # PyTorch's __cuda_array_interface__ reports version 2 and + # omits the "stream" field, so the standard CAI sync path + # above is a no-op for torch tensors. This is unsafe: the + # consumer has no guarantee that the producer's work is + # visible. We fix this by querying PyTorch's current CUDA + # stream via the AOTI stable C ABI and performing the same + # event-based stream ordering. + _get_tensor_bridge().sync_torch_stream( + buf.device_id, (stream_ptr)) return buf diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx new file mode 100644 index 0000000000..3030c45004 --- /dev/null +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -0,0 +1,385 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tensor bridge: extract PyTorch tensor metadata via the AOTI stable C ABI. + +PyTorch is NOT required at build time. At runtime the AOTI symbols are +resolved from ``torch._C`` (which is loaded with ``RTLD_GLOBAL``). + +The ``pyobj_to_aten_handle`` trick exploits the internal layout of +``THPVariable`` (PyTorch's Python tensor wrapper). + +In PyTorch 2.10+ ``cdata`` is ``at::Tensor`` directly:: + + struct THPVariable { + PyObject_HEAD + at::Tensor cdata; // <-- &cdata is usable as AtenTensorHandle + ... + }; + +In PyTorch 2.3–2.9 ``cdata`` was ``c10::MaybeOwned``, +whose first member is ``bool isBorrowed_`` (padded to 8 bytes), +followed by the ``at::Tensor`` union member:: + + struct THPVariable { + PyObject_HEAD + c10::MaybeOwned cdata; + // MaybeOwned layout: { bool isBorrowed_ (8 bytes); at::Tensor own_; } + ... 
+ }; + +In both cases the address of the ``at::Tensor`` inside ``cdata`` is +accepted by the AOTI stable C ABI functions as an ``AtenTensorHandle``. +The extra 8-byte skip for the ``isBorrowed_`` member is determined +at runtime from the PyTorch version (see ``_get_cdata_extra_offset``). + +Offsetting past ``PyObject_HEAD`` gives us the handle +without any Python attribute access or method calls (~14 ns for all +7 metadata queries). + +Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. + +.. note:: + + This module must NOT be imported at ``cuda.core`` load time. It is + loaded lazily (by ``_memoryview.pyx``) only when the user actually + passes a ``torch.Tensor``. The caller must ensure that + ``torch._C`` has been re-opened with ``RTLD_GLOBAL`` *before* + importing this module so that the AOTI symbols are visible. +""" + +from libc.stdint cimport intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t + +from cuda.core._memoryview cimport StridedMemoryView +from cuda.core._layout cimport _StridedLayout +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport ( + EventHandle, + create_event_handle_noctx, + as_cu, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN + +cdef extern from "Python.h": + ctypedef struct PyObject: + pass + +cdef extern from "_include/aoti_shim.h": + ctypedef int32_t AOTITorchError + + ctypedef struct AtenTensorOpaque: + pass + ctypedef AtenTensorOpaque* AtenTensorHandle + + # tensor metadata + AOTITorchError aoti_torch_get_data_ptr(AtenTensorHandle, void**) + AOTITorchError aoti_torch_get_dim(AtenTensorHandle, int64_t*) + AOTITorchError aoti_torch_get_sizes(AtenTensorHandle, int64_t**) + AOTITorchError aoti_torch_get_strides(AtenTensorHandle, int64_t**) + + # dtype + AOTITorchError aoti_torch_get_dtype(AtenTensorHandle, int32_t*) + int32_t aoti_torch_dtype_float16() + int32_t aoti_torch_dtype_float32() + int32_t aoti_torch_dtype_float64() + int32_t aoti_torch_dtype_bfloat16() + int32_t aoti_torch_dtype_uint8() + int32_t aoti_torch_dtype_int8() + int32_t aoti_torch_dtype_int16() + int32_t aoti_torch_dtype_int32() + int32_t aoti_torch_dtype_int64() + int32_t aoti_torch_dtype_bool() + int32_t aoti_torch_dtype_complex32() + int32_t aoti_torch_dtype_complex64() + int32_t aoti_torch_dtype_complex128() + + # device + AOTITorchError aoti_torch_get_device_type(AtenTensorHandle, int32_t*) + AOTITorchError aoti_torch_get_device_index(AtenTensorHandle, int32_t*) + int32_t aoti_torch_device_type_cpu() + int32_t aoti_torch_device_type_cuda() + + # stream + AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**) + +import numpy +import sys + + +# --------------------------------------------------------------------------- +# Module-level state (initialised at import time — AOTI symbols are +# guaranteed visible because _memoryview bootstraps RTLD_GLOBAL before +# importing us) +# --------------------------------------------------------------------------- + +cdef int32_t _DEVICE_TYPE_CPU = aoti_torch_device_type_cpu() +cdef int32_t _DEVICE_TYPE_CUDA = aoti_torch_device_type_cuda() +cdef dict _aoti_dtype_map = None +cdef dict _aoti_itemsize_map = None + +# Extra byte offset to skip before reaching the at::Tensor inside +# THPVariable::cdata. See _get_cdata_extra_offset() for details. +# Tri-state: -1 = not yet probed, 0 or 8 = cached result. +cdef Py_ssize_t _cdata_extra_offset = -1 + + +cdef Py_ssize_t _get_cdata_extra_offset(): + """Return the extra byte offset caused by ``MaybeOwned``'s bool member. 
+
+    In PyTorch 2.3–2.9 ``THPVariable::cdata`` is
+    ``c10::MaybeOwned<at::Tensor>``, whose first member is
+    ``bool isBorrowed_`` (padded to pointer alignment = 8 bytes).
+    The actual ``at::Tensor`` sits *after* that bool, so we must
+    skip 8 extra bytes.
+
+    From PyTorch 2.10 onward ``cdata`` is ``at::Tensor`` directly,
+    so no extra offset is needed.
+    """
+    global _cdata_extra_offset
+    if _cdata_extra_offset >= 0:
+        return _cdata_extra_offset
+    torch = sys.modules.get("torch")
+    if torch is None:
+        raise RuntimeError("torch must be imported before _tensor_bridge")
+    try:
+        major = int(torch.__version__.split(".")[0])
+        minor = int(torch.__version__.split(".")[1])
+    except (ValueError, IndexError):
+        raise RuntimeError(
+            f"Cannot parse torch version: {torch.__version__!r}")
+    if (major, minor) < (2, 10):
+        _cdata_extra_offset = 8  # skip MaybeOwned::isBorrowed_ padding
+    else:
+        _cdata_extra_offset = 0  # at::Tensor directly at base
+    return _cdata_extra_offset
+
+
+# ---------------------------------------------------------------------------
+# pointer extraction
+# ---------------------------------------------------------------------------
+
+cdef inline AtenTensorHandle pyobj_to_aten_handle(object obj):
+    """Extract AtenTensorHandle by offsetting past PyObject_HEAD.
+
+    In PyTorch 2.3–2.9 the first field after PyObject_HEAD is
+    ``c10::MaybeOwned<at::Tensor> cdata``, whose ``isBorrowed_``
+    bool member (padded to 8 bytes) precedes the actual
+    ``at::Tensor``. From 2.10 onward ``cdata`` is ``at::Tensor``
+    directly. The extra offset is determined once by
+    :func:`_get_cdata_extra_offset` and cached.
+    """
+    return <AtenTensorHandle>(
+        <char*><PyObject*>obj + sizeof(PyObject) + _get_cdata_extra_offset())
+
+
+cdef inline int check_aoti(AOTITorchError err, const char* name) except? -1:
+    """Raise RuntimeError if an AOTI call returned a non-zero error code."""
+    if err != 0:
+        raise RuntimeError(f"{name.decode()} failed")
+    return 0
+
+
+# ---------------------------------------------------------------------------
+# dtype mapping (AOTI int32 -> numpy dtype)
+# ---------------------------------------------------------------------------
+
+cdef dict _build_dtype_map():
+    try:
+        from ml_dtypes import bfloat16 as _bf16  # noqa: F811
+        has_bfloat16 = True
+    except ImportError:
+        has_bfloat16 = False
+
+    cdef dict m = {
+        aoti_torch_dtype_float16(): numpy.dtype(numpy.float16),
+        aoti_torch_dtype_float32(): numpy.dtype(numpy.float32),
+        aoti_torch_dtype_float64(): numpy.dtype(numpy.float64),
+        aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8),
+        aoti_torch_dtype_int8(): numpy.dtype(numpy.int8),
+        aoti_torch_dtype_int16(): numpy.dtype(numpy.int16),
+        aoti_torch_dtype_int32(): numpy.dtype(numpy.int32),
+        aoti_torch_dtype_int64(): numpy.dtype(numpy.int64),
+        aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_),
+        aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64),
+        aoti_torch_dtype_complex128(): numpy.dtype(numpy.complex128),
+    }
+    if has_bfloat16:
+        m[aoti_torch_dtype_bfloat16()] = numpy.dtype(_bf16)
+    return m
+
+
+cdef object _get_aoti_dtype(int32_t dtype_code):
+    global _aoti_dtype_map
+    if _aoti_dtype_map is None:
+        _aoti_dtype_map = _build_dtype_map()
+    result = _aoti_dtype_map.get(dtype_code)
+    if result is None:
+        raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}")
+    return result
+
+
+def resolve_aoti_dtype(int32_t dtype_code):
+    """Python-callable wrapper around _get_aoti_dtype (for lazy resolution)."""
+    return _get_aoti_dtype(dtype_code)
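+
+
+# Illustrative use of the lazy dtype resolution (editor's sketch): a view
+# built from a torch.Tensor stores the raw AOTI dtype code in `metadata`,
+# and StridedMemoryView.dtype calls resolve_aoti_dtype() on first access:
+#
+#     smv = StridedMemoryView.from_any_interface(t, stream_ptr=-1)
+#     isinstance(smv.metadata, int)  # True: raw AOTI dtype code
+#     smv.dtype                      # e.g. dtype('float32'), resolved lazily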
+
+
+cdef dict _build_itemsize_map():
+    return {
+        aoti_torch_dtype_bool(): sizeof(uint8_t),
+        aoti_torch_dtype_uint8(): sizeof(uint8_t),
+        aoti_torch_dtype_int8(): sizeof(int8_t),
+        aoti_torch_dtype_float16(): sizeof(int16_t),  # no C float16
+        aoti_torch_dtype_bfloat16(): sizeof(int16_t),  # no C bfloat16
+        aoti_torch_dtype_int16(): sizeof(int16_t),
+        aoti_torch_dtype_complex32(): 2 * sizeof(int16_t),  # no C complex32
+        aoti_torch_dtype_float32(): sizeof(float),
+        aoti_torch_dtype_int32(): sizeof(int32_t),
+        aoti_torch_dtype_complex64(): 2 * sizeof(float),
+        aoti_torch_dtype_float64(): sizeof(double),
+        aoti_torch_dtype_int64(): sizeof(int64_t),
+        aoti_torch_dtype_complex128(): 2 * sizeof(double),
+    }
+
+
+cdef int _get_aoti_itemsize(int32_t dtype_code) except -1:
+    global _aoti_itemsize_map
+    if _aoti_itemsize_map is None:
+        _aoti_itemsize_map = _build_itemsize_map()
+    result = _aoti_itemsize_map.get(dtype_code)
+    if result is None:
+        raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}")
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Stream ordering helper
+# ---------------------------------------------------------------------------
+
+cpdef int sync_torch_stream(int32_t device_index,
+                            intptr_t consumer_s) except? -1:
+    """Establish stream ordering between PyTorch's current CUDA stream
+    and the given consumer stream.
+
+    Records an event on PyTorch's current stream (the producer) and makes
+    the consumer stream wait on it. This is a no-op if both streams are
+    the same.
+    """
+    cdef void* producer_s
+    cdef EventHandle h_event
+
+    check_aoti(aoti_torch_get_current_cuda_stream(device_index, &producer_s),
+               b"aoti_torch_get_current_cuda_stream")
+    if <intptr_t>producer_s != consumer_s:
+        with nogil:
+            h_event = create_event_handle_noctx(
+                cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING)
+            HANDLE_RETURN(cydriver.cuEventRecord(
+                as_cu(h_event), <cydriver.CUstream>producer_s))
+            HANDLE_RETURN(cydriver.cuStreamWaitEvent(
+                <cydriver.CUstream>consumer_s, as_cu(h_event), 0))
+    return 0
+
+
+# ---------------------------------------------------------------------------
+# Public API: construct StridedMemoryView from a torch.Tensor
+# ---------------------------------------------------------------------------
+
+def view_as_torch_tensor(object obj, object stream_ptr, view=None):
+    """Create/populate a :class:`StridedMemoryView` from a ``torch.Tensor``.
+
+    This is a fast path that avoids DLPack/CAI protocol overhead by
+    reading tensor metadata directly through the AOTI stable C ABI.
+
+    Parameters
+    ----------
+    obj : torch.Tensor
+        The source tensor.
+    stream_ptr : int or None
+        Consumer stream pointer. When neither ``None`` nor ``-1``, stream
+        ordering is established between PyTorch's current CUDA stream (the
+        producer) and the consumer stream, matching the DLPack contract.
+    view : StridedMemoryView, optional
+        If provided, populate this existing view in-place. Otherwise a
+        new instance is created.
+ """ + cdef AtenTensorHandle handle = pyobj_to_aten_handle(obj) + cdef void* data_ptr + cdef int64_t ndim + cdef int64_t* sizes_ptr + cdef int64_t* strides_ptr + cdef int32_t dtype_code + cdef int32_t device_type, device_index + cdef StridedMemoryView buf + cdef int itemsize + cdef intptr_t _stream_ptr_int + cdef _StridedLayout layout + + check_aoti(aoti_torch_get_data_ptr(handle, &data_ptr), + b"aoti_torch_get_data_ptr") + check_aoti(aoti_torch_get_dim(handle, &ndim), + b"aoti_torch_get_dim") + check_aoti(aoti_torch_get_sizes(handle, &sizes_ptr), + b"aoti_torch_get_sizes") + check_aoti(aoti_torch_get_strides(handle, &strides_ptr), + b"aoti_torch_get_strides") + check_aoti(aoti_torch_get_dtype(handle, &dtype_code), + b"aoti_torch_get_dtype") + check_aoti(aoti_torch_get_device_type(handle, &device_type), + b"aoti_torch_get_device_type") + check_aoti(aoti_torch_get_device_index(handle, &device_index), + b"aoti_torch_get_device_index") + + # -- populate StridedMemoryView -- + if view is not None: + buf = view + else: + buf = StridedMemoryView.__new__(StridedMemoryView) + + buf.ptr = data_ptr + # PyTorch always reports tensors as writable via both DLPack + # (flags=0, no DLPACK_FLAG_BITMASK_READ_ONLY) and CAI + # (__cuda_array_interface__["data"] = (ptr, False)). Tensors that + # cannot be safely exported (requires_grad, conjugate, non-strided) + # are rejected with BufferError rather than marked read-only. + # The AOTI C ABI has no readonly query either, so False is correct. + buf.readonly = False + buf.exporting_obj = obj + buf.dl_tensor = NULL + buf.metadata = None + buf._buffer = None + + if device_type == _DEVICE_TYPE_CPU: + buf.device_id = -1 + buf.is_device_accessible = False + elif device_type == _DEVICE_TYPE_CUDA: + buf.device_id = device_index + buf.is_device_accessible = True + + # -- stream ordering (matches the DLPack contract) -- + if stream_ptr is not None: + _stream_ptr_int = int(stream_ptr) + if _stream_ptr_int != -1: + sync_torch_stream(device_index, _stream_ptr_int) + else: + raise BufferError( + f"Unsupported device type from torch tensor " + f"(AOTI device type id: {device_type})") + + # Defer full numpy dtype resolution until first .dtype access. + # Store the raw AOTI dtype code in metadata for lazy lookup. + buf.metadata = dtype_code + + # Build _StridedLayout. init_from_ptr copies shape/strides so we are + # safe even though they are borrowed pointers. + itemsize = _get_aoti_itemsize(dtype_code) + layout = _StridedLayout.__new__(_StridedLayout) + layout.init_from_ptr( + ndim, + sizes_ptr, + strides_ptr, + itemsize, + ) + buf._layout = layout + + return buf diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst new file mode 100644 index 0000000000..34eff57100 --- /dev/null +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -0,0 +1,35 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. currentmodule:: cuda.core + +``cuda.core`` 1.0.0 Release Notes +================================= + + +Highlights +---------- + +- TBD + + +New features +------------ + +- TBD + + +Fixes and enhancements +----------------------- + +- :class:`~utils.StridedMemoryView` now provides a fast path for ``torch.Tensor`` + objects via PyTorch's AOT Inductor (AOTI) stable C ABI. 
diff --git a/cuda_core/pyproject.toml b/cuda_core/pyproject.toml
index 660c2a577f..ed276a5ad5 100644
--- a/cuda_core/pyproject.toml
+++ b/cuda_core/pyproject.toml
@@ -118,4 +118,4 @@ archs = "native"
 [tool.cibuildwheel.windows]
 archs = "AMD64"
 before-build = "pip install delvewheel"
-repair-wheel-command = "delvewheel repair --namespace-pkg cuda -w {dest_dir} {wheel}"
+repair-wheel-command = "delvewheel repair --namespace-pkg cuda --exclude \"torch_cpu.dll;torch_python.dll\" -w {dest_dir} {wheel}"
diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py
index 4bdebcbde3..0f11b5629c 100644
--- a/cuda_core/tests/test_utils.py
+++ b/cuda_core/tests/test_utils.py
@@ -76,9 +76,43 @@ def convert_strides_to_counts(strides, itemsize):
     return tuple(s // itemsize for s in strides)
 
 
-@pytest.mark.parametrize(
-    "in_arr,",
-    (
+def _arr_ptr(arr):
+    """Return the data pointer of *arr* regardless of its type."""
+    if torch is not None and isinstance(arr, torch.Tensor):
+        return arr.data_ptr()
+    if isinstance(arr, np.ndarray):
+        return arr.ctypes.data
+    return gpu_array_ptr(arr)
+
+
+def _arr_strides_in_counts(arr):
+    """Return strides in element counts for *arr* regardless of its type."""
+    if torch is not None and isinstance(arr, torch.Tensor):
+        return tuple(arr.stride())
+    return convert_strides_to_counts(arr.strides, arr.dtype.itemsize)
+
+
+def _arr_size(arr):
+    """Return the number of elements in *arr*."""
+    if torch is not None and isinstance(arr, torch.Tensor):
+        return arr.numel()
+    return arr.size
+
+
+def _arr_is_c_contiguous(arr):
+    if torch is not None and isinstance(arr, torch.Tensor):
+        return arr.is_contiguous()
+    flags = arr.flags
+    return flags.c_contiguous if hasattr(flags, "c_contiguous") else flags["C_CONTIGUOUS"]
+
+
+def _arr_is_writeable(arr):
+    if torch is not None and isinstance(arr, torch.Tensor):
+        return True  # torch tensors are writable by default
+    return arr.flags.writeable if hasattr(arr.flags, "writeable") else True
+
+
+def _cpu_array_samples():
+    samples = [
         np.empty(3, dtype=np.int32),
         np.empty((6, 6), dtype=np.float64)[::2, ::2],
         np.empty((3, 4), order="F"),
@@ -88,8 +122,27 @@ def convert_strides_to_counts(strides, itemsize):
             np.frombuffer(b""),
             marks=requires_module(np, "2.1"),
         ),
-    ),
-)
+    ]
+    if torch is not None:
+        samples += [
+            pytest.param(torch.arange(12, dtype=torch.float32), id="torch-1d"),
+            pytest.param(torch.arange(24, dtype=torch.float32).reshape(2, 3, 4), id="torch-nd"),
+            pytest.param(torch.tensor(42.0), id="torch-scalar"),
+            pytest.param(torch.empty(0, dtype=torch.float32), id="torch-empty"),
+            pytest.param(
+                torch.arange(12, dtype=torch.float32).reshape(3, 4).t(),
+                id="torch-non-contiguous",
+            ),
+            pytest.param(torch.arange(100, dtype=torch.int64)[10:20], id="torch-sliced"),
+            pytest.param(
+                torch.arange(60, dtype=torch.float32).reshape(6, 10)[1:4, 2:7],
+                id="torch-sliced-2d",
+            ),
+        ]
+    return samples
+
+
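+# The _arr_* helpers above normalise the metadata queries (pointer, strides,
+# element count, contiguity, writability) across numpy arrays and torch
+# tensors, so the parametrised assertions below can cover both producers.
+
+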
+@pytest.mark.parametrize("in_arr,", _cpu_array_samples()) class TestViewCPU: def test_args_viewable_as_strided_memory_cpu(self, in_arr): @args_viewable_as_strided_memory((0,)) @@ -113,16 +166,16 @@ def test_strided_memory_view_cpu_init(self, in_arr): def _check_view(self, view, in_arr): assert isinstance(view, StridedMemoryView) - assert view.ptr == in_arr.ctypes.data - assert view.shape == in_arr.shape - assert view.size == in_arr.size - strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize) - assert (in_arr.flags.c_contiguous and view.strides is None) or view.strides == strides_in_counts - assert view.dtype == in_arr.dtype + assert view.ptr == _arr_ptr(in_arr) + expected_shape = tuple(in_arr.shape) + assert view.shape == expected_shape + assert view.size == _arr_size(in_arr) + strides_in_counts = _arr_strides_in_counts(in_arr) + assert (_arr_is_c_contiguous(in_arr) and view.strides is None) or view.strides == strides_in_counts assert view.device_id == -1 assert view.is_device_accessible is False assert view.exporting_obj is in_arr - assert view.readonly is not in_arr.flags.writeable + assert view.readonly is not _arr_is_writeable(in_arr) def gpu_array_samples(): @@ -141,10 +194,38 @@ def gpu_array_samples(): pytest.param(numba_cuda.device_array((2,), dtype=np.int8), False, id="numba-cuda-int8"), pytest.param(numba_cuda.device_array((4, 2), dtype=np.float32), True, id="numba-cuda-float32"), ] + if torch is not None: + samples += [ + pytest.param(torch.arange(12, dtype=torch.float32, device="cuda"), True, id="torch-1d"), + pytest.param( + torch.arange(24, dtype=torch.float32, device="cuda").reshape(2, 3, 4), + True, + id="torch-nd", + ), + pytest.param(torch.tensor(42.0, dtype=torch.float32, device="cuda"), False, id="torch-scalar"), + pytest.param(torch.empty(0, dtype=torch.float32, device="cuda"), False, id="torch-empty"), + pytest.param( + torch.arange(12, dtype=torch.float32, device="cuda").reshape(3, 4).t(), + True, + id="torch-non-contiguous", + ), + pytest.param( + torch.arange(100, dtype=torch.int64, device="cuda")[10:20], + True, + id="torch-sliced", + ), + pytest.param( + torch.arange(60, dtype=torch.float32, device="cuda").reshape(6, 10)[1:4, 2:7], + True, + id="torch-sliced-2d", + ), + ] return samples def gpu_array_ptr(arr): + if torch is not None and isinstance(arr, torch.Tensor): + return arr.data_ptr() if cp is not None and isinstance(arr, cp.ndarray): return arr.data.ptr if numba_cuda is not None and isinstance(arr, numba_cuda.cudadrv.devicearray.DeviceNDArray): @@ -192,18 +273,18 @@ def test_strided_memory_view_init(self, in_arr, use_stream): def _check_view(self, view, in_arr, dev): assert isinstance(view, StridedMemoryView) assert view.ptr == gpu_array_ptr(in_arr) - assert view.shape == in_arr.shape - assert view.size == in_arr.size - strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize) - if in_arr.flags["C_CONTIGUOUS"]: + expected_shape = tuple(in_arr.shape) + assert view.shape == expected_shape + assert view.size == _arr_size(in_arr) + strides_in_counts = _arr_strides_in_counts(in_arr) + if _arr_is_c_contiguous(in_arr): assert view.strides in (None, strides_in_counts) else: assert view.strides == strides_in_counts - assert view.dtype == in_arr.dtype assert view.device_id == dev.device_id assert view.is_device_accessible is True assert view.exporting_obj is in_arr - # can't test view.readonly with CuPy or Numba... + # can't test view.readonly with CuPy, Numba, or torch... 
def test_strided_memory_view_dlpack_export_numpy_roundtrip(): @@ -712,3 +793,41 @@ def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, a smv = api(a, stream_ptr=0) with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"): smv.dtype # noqa: B018 + + +# =================================================================== +# Tensor bridge (torch.Tensor fast path via AOTI stable C ABI) +# =================================================================== + +_torch_skip = pytest.mark.skipif(torch is None, reason="PyTorch is not installed") + + +@_torch_skip +@pytest.mark.parametrize( + "dtype", + [ + pytest.param("float16", id="float16"), + pytest.param("float32", id="float32"), + pytest.param("float64", id="float64"), + pytest.param("int8", id="int8"), + pytest.param("int16", id="int16"), + pytest.param("int32", id="int32"), + pytest.param("int64", id="int64"), + pytest.param("uint8", id="uint8"), + pytest.param("bool", id="bool"), + pytest.param("complex64", id="complex64"), + pytest.param("complex128", id="complex128"), + pytest.param( + "bfloat16", + id="bfloat16", + marks=pytest.mark.skipif(ml_dtypes is None, reason="ml_dtypes is not installed"), + ), + ], +) +def test_torch_tensor_bridge_dtypes(init_cuda, dtype): + """Verify that dtype mapping via the tensor bridge matches torch's own dtype.""" + torch_dtype = getattr(torch, dtype) + a = torch.tensor([1, 0, 1], dtype=torch_dtype, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.dtype.itemsize == a.element_size() + assert smv.ptr == a.data_ptr()
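+
+
+@_torch_skip
+def test_torch_tensor_bridge_cpu_tensor():
+    # Editor's sketch of an additional check (hypothetical test, not part
+    # of the original change): a CPU torch.Tensor should take the same AOTI
+    # fast path and be reported as host memory. stream_ptr=-1 skips stream
+    # ordering, per the DLPack convention.
+    a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
+    smv = StridedMemoryView.from_any_interface(a, stream_ptr=-1)
+    assert smv.ptr == a.data_ptr()
+    assert smv.device_id == -1
+    assert smv.is_device_accessible is False
+    assert smv.shape == (2, 3)
+    assert smv.dtype == np.float32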