diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 7a791bd2f4..a1472078ce 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -63,6 +63,7 @@ from quadrants.types.enums import SNodeGradType from quadrants.types.ndarray_type import NdarrayType from quadrants.types.primitive_types import ( + PrimitiveBase, all_types, f16, f32, @@ -111,6 +112,8 @@ def expr_init(rhs): return dict((key, expr_init(val)) for key, val in rhs.items()) if isinstance(rhs, _qd_core.DataTypeCxx): return rhs + if isinstance(rhs, type) and issubclass(rhs, PrimitiveBase): + return rhs.cxx if isinstance(rhs, _qd_core.Arch): return rhs if isinstance(rhs, _Ndrange): diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index 38b29a0408..9a645014ec 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -13,7 +13,7 @@ from quadrants.lang.expr import Expr from quadrants.lang.impl import axes, get_runtime from quadrants.profiler.kernel_profiler import get_default_kernel_profiler -from quadrants.types.primitive_types import f32, f64, i32, i64 +from quadrants.types.primitive_types import PrimitiveMeta, f32, f64, i32, i64 warnings.filterwarnings("once", category=DeprecationWarning, module="quadrants") @@ -314,15 +314,14 @@ def _install_python_backend_dtype_call(): return _dtype_call_installed = True - DataTypeCxx = type(f32) - _original = DataTypeCxx.__call__ + _original = PrimitiveMeta.__call__ def _dtype_call(self, value): if impl.is_python_backend(): return float(value) if self in _FLOAT_DTYPES else int(value) return _original(self, value) - DataTypeCxx.__call__ = _dtype_call # type: ignore[assignment] + PrimitiveMeta.__call__ = _dtype_call # type: ignore[assignment] def init( diff --git a/python/quadrants/lang/util.py b/python/quadrants/lang/util.py index a744954222..c358db6d95 100644 --- a/python/quadrants/lang/util.py +++ b/python/quadrants/lang/util.py @@ -13,22 +13,59 @@ from quadrants.lang import impl from quadrants.types import Template from quadrants.types.primitive_types import ( + PrimitiveBase, all_types, f16, + f16_cxx, f32, + f32_cxx, f64, + f64_cxx, i8, + i8_cxx, i16, + i16_cxx, i32, + i32_cxx, i64, + i64_cxx, u1, + u1_cxx, u8, + u8_cxx, u16, + u16_cxx, u32, + u32_cxx, u64, + u64_cxx, ) -MAP_TYPE_IDS = {id(dtype): dtype for dtype in all_types} +MAP_TYPE_IDS: dict[int, Any] = {id(dtype): dtype for dtype in all_types} +_all_cxx_objs = ( + f16_cxx, + f32_cxx, + f64_cxx, + i8_cxx, + i16_cxx, + i32_cxx, + i64_cxx, + u1_cxx, + u8_cxx, + u16_cxx, + u32_cxx, + u64_cxx, +) +for _cxx in _all_cxx_objs: + MAP_TYPE_IDS[id(_cxx)] = _cxx + +# Pre-computed id-based cache for cook_dtype hot path. +# Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result. +_cook_cache: dict[int, _qd_core.DataTypeCxx] = {} +for _cls in (f16, f32, f64, i8, i16, i32, i64, u1, u8, u16, u32, u64): + _cook_cache[id(_cls)] = _cls.cxx +for _cxx in _all_cxx_objs: + _cook_cache[id(_cxx)] = _cxx def has_pytorch(): @@ -178,71 +215,74 @@ def to_quadrants_type(dt): dt (DataType): The desired data type to convert. Returns: - DataType: The counterpart data type in quadrants. + DataTypeCxx: The counterpart data type in quadrants (always returns DataTypeCxx). """ _type = type(dt) if _type is int: - return MAP_TYPE_IDS[dt] + return cook_dtype(MAP_TYPE_IDS[dt]) + + if isinstance(dt, type) and issubclass(dt, PrimitiveBase): + return dt.cxx if issubclass(_type, _qd_core.DataTypeCxx): return dt if dt == np.float32: - return f32 + return f32.cxx if dt == np.float64: - return f64 + return f64.cxx if dt == np.int32: - return i32 + return i32.cxx if dt == np.int64: - return i64 + return i64.cxx if dt == np.int8: - return i8 + return i8.cxx if dt == np.int16: - return i16 + return i16.cxx if dt == np.bool_: - return u1 + return u1.cxx if dt == np.uint8: - return u8 + return u8.cxx if dt == np.uint16: - return u16 + return u16.cxx if dt == np.uint32: - return u32 + return u32.cxx if dt == np.uint64: - return u64 + return u64.cxx if dt == np.half: - return f16 + return f16.cxx if has_pytorch(): import torch # pylint: disable=C0415 # pylint: disable=E1101 if dt == torch.float32: - return f32 + return f32.cxx if dt == torch.float64: - return f64 + return f64.cxx if dt == torch.int32: - return i32 + return i32.cxx if dt == torch.int64: - return i64 + return i64.cxx if dt == torch.int8: - return i8 + return i8.cxx if dt == torch.int16: - return i16 + return i16.cxx if dt == torch.bool: - return u1 + return u1.cxx if dt == torch.uint8: - return u8 + return u8.cxx if dt == torch.float16: - return f16 + return f16.cxx if hasattr(torch, "uint16"): if dt == torch.uint16: - return u16 + return u16.cxx if dt == torch.uint32: - return u32 + return u32.cxx if dt == torch.uint64: - return u64 + return u64.cxx raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") @@ -265,8 +305,17 @@ def __hash__(self): def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx: - # Convert Python dtype to CPP dtype + """Convert Python dtype to C++ DataTypeCxx. + + Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances, + and Python builtins (float, int, bool). Uses id-based cache for hot paths. + """ + cached = _cook_cache.get(id(dtype)) + if cached is not None: + return cached _type = type(dtype) + if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase): + return dtype.cxx if issubclass(_type, _qd_core.DataTypeCxx): return dtype if issubclass(_type, _qd_core.Type): @@ -276,7 +325,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx: if dtype is int: return impl.get_runtime().default_ip if dtype is bool: - return u1 + return u1.cxx raise ValueError(f"Invalid data type {dtype}") diff --git a/python/quadrants/types/primitive_types.py b/python/quadrants/types/primitive_types.py index 04b8fc4cb5..d9f62d5ff6 100644 --- a/python/quadrants/types/primitive_types.py +++ b/python/quadrants/types/primitive_types.py @@ -1,159 +1,218 @@ -from typing import Union +from typing import Any, ClassVar, Union from quadrants._lib import core as qd_python_core +from quadrants._lib.core.quadrants_python import DataTypeCxx # ======================================== -# real types +# Raw C++ DataType instances (internal use) +# ======================================== + +f16_cxx = qd_python_core.DataType_f16 +f32_cxx = qd_python_core.DataType_f32 +f64_cxx = qd_python_core.DataType_f64 + +i8_cxx = qd_python_core.DataType_i8 +i16_cxx = qd_python_core.DataType_i16 +i32_cxx = qd_python_core.DataType_i32 +i64_cxx = qd_python_core.DataType_i64 + +u1_cxx = qd_python_core.DataType_u1 +u8_cxx = qd_python_core.DataType_u8 +u16_cxx = qd_python_core.DataType_u16 +u32_cxx = qd_python_core.DataType_u32 +u64_cxx = qd_python_core.DataType_u64 + + +# ======================================== +# Metaclass and base class for Python dtype wrappers +# ======================================== + + +class PrimitiveMeta(type): + """Metaclass that makes dtype classes behave like DataTypeCxx objects. + + Delegates attribute access and comparisons to the underlying .cxx object, + allowing existing code that does e.g. dtype.to_string() to keep working. + """ + + def __eq__(cls, other): + if isinstance(other, PrimitiveMeta): + return cls is other + if isinstance(other, DataTypeCxx): + return cls.cxx == other + return NotImplemented + + def __ne__(cls, other): + if isinstance(other, PrimitiveMeta): + return cls is not other + if isinstance(other, DataTypeCxx): + return cls.cxx != other + return NotImplemented + + def __hash__(cls): + return hash(cls.cxx) + + def __repr__(cls): + return cls.cxx.to_string() + + def __getattr__(cls, name): + try: + return getattr(cls.cxx, name) + except AttributeError: + raise AttributeError(f"type object '{cls.__name__}' has no attribute '{name}'") from None + + def __call__(cls, value): + return cls.cxx(value) + + +class PrimitiveBase(metaclass=PrimitiveMeta): + """Base class for all primitive dtype classes. + + Each subclass has a `cxx` class variable holding the corresponding DataTypeCxx instance. + Subclasses auto-register themselves in the _registry for reverse lookup (DataTypeCxx -> Python class). + """ -# ---------------------------------------- + cxx: ClassVar[DataTypeCxx] + _registry: ClassVar[dict[DataTypeCxx, "type[PrimitiveBase]"]] = {} -float16 = qd_python_core.DataType_f16 -"""16-bit precision floating point data type. -""" + # NOTE: __init__ is never executed at runtime because PrimitiveMeta.__call__ + # short-circuits class instantiation and returns cls.cxx(value) directly. + # This stub exists purely so pyright recognises ``f32(value)`` etc. as a + # legal call site (pyright uses __init__/__new__ signatures, not metaclass + # __call__, when type-checking class instantiation). + def __init__(self, value: Any = None) -> None: ... -# ---------------------------------------- + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if hasattr(cls, "cxx"): + PrimitiveBase._registry[cls.cxx] = cls -f16 = float16 -"""Alias for :const:`~quadrants.types.primitive_types.float16` -""" -# ---------------------------------------- +def cxx_to_py(dtype_cxx: DataTypeCxx) -> "type[PrimitiveBase]": + """Convert a DataTypeCxx to its corresponding Python dtype class.""" + return PrimitiveBase._registry[dtype_cxx] -float32 = qd_python_core.DataType_f32 -"""32-bit single precision floating point data type. -""" -# ---------------------------------------- +# Wrap C++ helpers that take a DataTypeCxx so they also accept PrimitiveBase +# Python wrapper classes (e.g. qd.i8). The C++ binding only accepts DataTypeCxx, +# so we transparently unwrap PrimitiveBase subclasses to their .cxx attribute. +_orig_data_type_size = qd_python_core.data_type_size -f32 = float32 -"""Alias for :const:`~quadrants.types.primitive_types.float32` -""" -# ---------------------------------------- +def _data_type_size(dtype: Any) -> int: + if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase): + dtype = dtype.cxx + return _orig_data_type_size(dtype) -float64 = qd_python_core.DataType_f64 -"""64-bit double precision floating point data type. -""" -# ---------------------------------------- +qd_python_core.data_type_size = _data_type_size -f64 = float64 -"""Alias for :const:`~quadrants.types.primitive_types.float64` -""" -# ---------------------------------------- # ======================================== -# Integer types +# Floating point types +# ======================================== + + +class f16(PrimitiveBase): + """16-bit precision floating point data type.""" + + cxx = f16_cxx + -# ---------------------------------------- +class f32(PrimitiveBase): + """32-bit single precision floating point data type.""" -int8 = qd_python_core.DataType_i8 -"""8-bit signed integer data type. -""" + cxx = f32_cxx -# ---------------------------------------- -i8 = int8 -"""Alias for :const:`~quadrants.types.primitive_types.int8` -""" +class f64(PrimitiveBase): + """64-bit double precision floating point data type.""" + + cxx = f64_cxx + + +float16 = f16 +float32 = f32 +float64 = f64 + +# ======================================== +# Signed integer types +# ======================================== -# ---------------------------------------- -int16 = qd_python_core.DataType_i16 -"""16-bit signed integer data type. -""" +class i8(PrimitiveBase): + """8-bit signed integer data type.""" -# ---------------------------------------- + cxx = i8_cxx -i16 = int16 -"""Alias for :const:`~quadrants.types.primitive_types.int16` -""" -# ---------------------------------------- +class i16(PrimitiveBase): + """16-bit signed integer data type.""" -int32 = qd_python_core.DataType_i32 -"""32-bit signed integer data type. -""" + cxx = i16_cxx -# ---------------------------------------- -i32 = int32 -"""Alias for :const:`~quadrants.types.primitive_types.int32` -""" +class i32(PrimitiveBase): + """32-bit signed integer data type.""" -# ---------------------------------------- + cxx = i32_cxx -int64 = qd_python_core.DataType_i64 -"""64-bit signed integer data type. -""" -# ---------------------------------------- +class i64(PrimitiveBase): + """64-bit signed integer data type.""" -i64 = int64 -"""Alias for :const:`~quadrants.types.primitive_types.int64` -""" + cxx = i64_cxx -# ---------------------------------------- -uint8 = qd_python_core.DataType_u8 -"""8-bit unsigned integer data type. -""" +int8 = i8 +int16 = i16 +int32 = i32 +int64 = i64 -# ---------------------------------------- +# ======================================== +# Unsigned integer types +# ======================================== -uint1 = qd_python_core.DataType_u1 -"""1-bit unsigned integer data type. Same as booleans. -""" -# ---------------------------------------- +class u1(PrimitiveBase): + """1-bit unsigned integer data type. Same as booleans.""" -u1 = uint1 -"""Alias for :const:`~quadrants.types.primitive_types.uint1` -""" + cxx = u1_cxx -# ---------------------------------------- -u8 = uint8 -"""Alias for :const:`~quadrants.types.primitive_types.uint8` -""" +class u8(PrimitiveBase): + """8-bit unsigned integer data type.""" -# ---------------------------------------- + cxx = u8_cxx -uint16 = qd_python_core.DataType_u16 -"""16-bit unsigned integer data type. -""" -# ---------------------------------------- +class u16(PrimitiveBase): + """16-bit unsigned integer data type.""" -u16 = uint16 -"""Alias for :const:`~quadrants.types.primitive_types.uint16` -""" + cxx = u16_cxx -# ---------------------------------------- -uint32 = qd_python_core.DataType_u32 -"""32-bit unsigned integer data type. -""" +class u32(PrimitiveBase): + """32-bit unsigned integer data type.""" -# ---------------------------------------- + cxx = u32_cxx -u32 = uint32 -"""Alias for :const:`~quadrants.types.primitive_types.uint32` -""" -# ---------------------------------------- +class u64(PrimitiveBase): + """64-bit unsigned integer data type.""" -uint64 = qd_python_core.DataType_u64 -"""64-bit unsigned integer data type. -""" + cxx = u64_cxx -# ---------------------------------------- -u64 = uint64 -"""Alias for :const:`~quadrants.types.primitive_types.uint64` -""" +uint1 = u1 +uint8 = u8 +uint16 = u16 +uint32 = u32 +uint64 = u64 -# ---------------------------------------- +# ======================================== +# Ref type (unchanged) +# ======================================== class RefType: @@ -165,6 +224,10 @@ def ref(tp): return RefType(tp) +# ======================================== +# Type sets for fast lookup +# ======================================== + real_types = {f16, f32, f64, float} real_type_ids = {id(t) for t in real_types} @@ -172,7 +235,13 @@ def ref(tp): integer_type_ids = {id(t) for t in integer_types} all_types = real_types | integer_types -type_ids = {id(t) for t in all_types} +_py_type_ids = {id(t) for t in all_types} + +_all_cxx = {f16_cxx, f32_cxx, f64_cxx, i8_cxx, i16_cxx, i32_cxx, i64_cxx, u1_cxx, u8_cxx, u16_cxx, u32_cxx, u64_cxx} +cxx_type_ids = {id(t) for t in _all_cxx} + +# Combined set: matches both Python classes and DataTypeCxx instances +type_ids = _py_type_ids | cxx_type_ids _python_primitive_types = Union[int, float, bool, str, None] diff --git a/python/quadrants/types/quant.py b/python/quadrants/types/quant.py index b67e79ea12..780c2f2088 100644 --- a/python/quadrants/types/quant.py +++ b/python/quadrants/types/quant.py @@ -3,12 +3,23 @@ For more details, read https://yuanming.quadrants.graphics/publication/2021-quanquadrants/quanquadrants.pdf. """ +from typing import Any + from quadrants._lib.utils import qd_python_core as _qd_python_core -from quadrants.types.primitive_types import i32 +from quadrants.types.primitive_types import PrimitiveBase, i32 _type_factory = _qd_python_core.get_type_factory_instance() +def _to_ptr(compute: Any) -> Any: + """Convert a dtype (Python class or DataTypeCxx) to a Type pointer for C++ APIs.""" + if isinstance(compute, type) and issubclass(compute, PrimitiveBase): + compute = compute.cxx + if isinstance(compute, _qd_python_core.DataTypeCxx): + return compute.get_ptr() + return compute + + def int(bits, signed=True, compute=None): # pylint: disable=W0622 """Generates a quantized type for integers. @@ -24,8 +35,7 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622 from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_ip if signed else impl.get_runtime().default_up - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) return _type_factory.get_quant_int_type(bits, signed, compute) @@ -46,8 +56,7 @@ def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None): from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_fp - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) # TODO: handle cases with bits > 32 underlying_type = int(bits=bits, signed=signed, compute=i32) if scale is None: @@ -74,8 +83,7 @@ def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622 from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_fp - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) # Exponent is always unsigned exp_type = int(bits=exp, signed=False, compute=i32) # TODO: handle cases with frac > 32 diff --git a/python/quadrants/types/utils.py b/python/quadrants/types/utils.py index 0803085e2b..e268279d8d 100644 --- a/python/quadrants/types/utils.py +++ b/python/quadrants/types/utils.py @@ -1,11 +1,35 @@ +from typing import Any + from quadrants._lib import core as qd_python_core +from quadrants._lib.core.quadrants_python import DataTypeCxx +from quadrants.types.primitive_types import PrimitiveBase + +_is_signed_cxx = qd_python_core.is_signed +_is_integral_cxx = qd_python_core.is_integral +_is_real_cxx = qd_python_core.is_real +_is_tensor_cxx = qd_python_core.is_tensor + + +def _cook_if_needed(dt: Any) -> DataTypeCxx: + if isinstance(dt, type) and issubclass(dt, PrimitiveBase): + return dt.cxx + return dt # type: ignore[return-value] + + +def is_signed(dt: Any) -> bool: + return _is_signed_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] + + +def is_integral(dt: Any) -> bool: + return _is_integral_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] + -is_signed = qd_python_core.is_signed +def is_real(dt: Any) -> bool: + return _is_real_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] -is_integral = qd_python_core.is_integral -is_real = qd_python_core.is_real +def is_tensor(dt: Any) -> bool: + return _is_tensor_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] -is_tensor = qd_python_core.is_tensor __all__ = ["is_signed", "is_integral", "is_real", "is_tensor"] diff --git a/tests/python/test_binding.py b/tests/python/test_binding.py index 4bf7fdb754..0e6338787f 100644 --- a/tests/python/test_binding.py +++ b/tests/python/test_binding.py @@ -5,7 +5,7 @@ def test_binding(): qd.init() quadrants_lang = qd._lib.core print(quadrants_lang.BinaryOpType.mul) - one = quadrants_lang.make_const_expr_int(qd.i32, 1) - two = quadrants_lang.make_const_expr_int(qd.i32, 2) + one = quadrants_lang.make_const_expr_int(qd.i32.cxx, 1) + two = quadrants_lang.make_const_expr_int(qd.i32.cxx, 2) expr = quadrants_lang.make_binary_op_expr(quadrants_lang.BinaryOpType.add, one, two) print(quadrants_lang.make_global_store_stmt(None, None))