From 9bbc16bc810861e161b2b9b7a59bbd1b9f1462a1 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:08:38 -0700 Subject: [PATCH 1/7] [Refactor] Make primitive dtypes Python classes wrapping DataTypeCxx Convert primitive dtypes (f32, i32, etc.) from bare DataTypeCxx module-level variables into Python classes with a PrimitiveMeta metaclass. Each class has a .cxx attribute holding the underlying DataTypeCxx, and the metaclass delegates __eq__, __hash__, __getattr__ for backward compatibility. Update cook_dtype, to_quadrants_type, MAP_TYPE_IDS, and type utility functions to handle the new class-based types. Add PrimitiveBase checks in expr_init and quant.py. --- python/quadrants/lang/impl.py | 3 + python/quadrants/lang/util.py | 107 +++++++--- python/quadrants/types/primitive_types.py | 248 +++++++++++++--------- python/quadrants/types/quant.py | 22 +- python/quadrants/types/utils.py | 32 ++- tests/python/test_binding.py | 4 +- 6 files changed, 272 insertions(+), 144 deletions(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 0036a3bb78..701ec7c376 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -62,6 +62,7 @@ from quadrants.types.enums import SNodeGradType from quadrants.types.ndarray_type import NdarrayType from quadrants.types.primitive_types import ( + PrimitiveBase, all_types, f16, f32, @@ -110,6 +111,8 @@ def expr_init(rhs): return dict((key, expr_init(val)) for key, val in rhs.items()) if isinstance(rhs, _qd_core.DataTypeCxx): return rhs + if isinstance(rhs, type) and issubclass(rhs, PrimitiveBase): + return rhs.cxx if isinstance(rhs, _qd_core.Arch): return rhs if isinstance(rhs, _Ndrange): diff --git a/python/quadrants/lang/util.py b/python/quadrants/lang/util.py index 93536ce9da..17b4333a8e 100644 --- a/python/quadrants/lang/util.py +++ b/python/quadrants/lang/util.py @@ -12,22 +12,59 @@ from quadrants.lang import impl from quadrants.types import Template from quadrants.types.primitive_types import ( + PrimitiveBase, all_types, f16, + f16_cxx, f32, + f32_cxx, f64, + f64_cxx, i8, + i8_cxx, i16, + i16_cxx, i32, + i32_cxx, i64, + i64_cxx, u1, + u1_cxx, u8, + u8_cxx, u16, + u16_cxx, u32, + u32_cxx, u64, + u64_cxx, ) -MAP_TYPE_IDS = {id(dtype): dtype for dtype in all_types} +MAP_TYPE_IDS: dict[int, Any] = {id(dtype): dtype for dtype in all_types} +_all_cxx_objs = ( + f16_cxx, + f32_cxx, + f64_cxx, + i8_cxx, + i16_cxx, + i32_cxx, + i64_cxx, + u1_cxx, + u8_cxx, + u16_cxx, + u32_cxx, + u64_cxx, +) +for _cxx in _all_cxx_objs: + MAP_TYPE_IDS[id(_cxx)] = _cxx + +# Pre-computed id-based cache for cook_dtype hot path. +# Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result. +_cook_cache: dict[int, _qd_core.DataTypeCxx] = {} +for _cls in (f16, f32, f64, i8, i16, i32, i64, u1, u8, u16, u32, u64): + _cook_cache[id(_cls)] = _cls.cxx +for _cxx in _all_cxx_objs: + _cook_cache[id(_cxx)] = _cxx def has_pytorch(): @@ -177,71 +214,74 @@ def to_quadrants_type(dt): dt (DataType): The desired data type to convert. Returns: - DataType: The counterpart data type in quadrants. + DataTypeCxx: The counterpart data type in quadrants (always returns DataTypeCxx). """ _type = type(dt) if _type is int: - return MAP_TYPE_IDS[dt] + return cook_dtype(MAP_TYPE_IDS[dt]) + + if isinstance(dt, type) and issubclass(dt, PrimitiveBase): + return dt.cxx if issubclass(_type, _qd_core.DataTypeCxx): return dt if dt == np.float32: - return f32 + return f32.cxx if dt == np.float64: - return f64 + return f64.cxx if dt == np.int32: - return i32 + return i32.cxx if dt == np.int64: - return i64 + return i64.cxx if dt == np.int8: - return i8 + return i8.cxx if dt == np.int16: - return i16 + return i16.cxx if dt == np.bool_: - return u1 + return u1.cxx if dt == np.uint8: - return u8 + return u8.cxx if dt == np.uint16: - return u16 + return u16.cxx if dt == np.uint32: - return u32 + return u32.cxx if dt == np.uint64: - return u64 + return u64.cxx if dt == np.half: - return f16 + return f16.cxx if has_pytorch(): import torch # pylint: disable=C0415 # pylint: disable=E1101 if dt == torch.float32: - return f32 + return f32.cxx if dt == torch.float64: - return f64 + return f64.cxx if dt == torch.int32: - return i32 + return i32.cxx if dt == torch.int64: - return i64 + return i64.cxx if dt == torch.int8: - return i8 + return i8.cxx if dt == torch.int16: - return i16 + return i16.cxx if dt == torch.bool: - return u1 + return u1.cxx if dt == torch.uint8: - return u8 + return u8.cxx if dt == torch.float16: - return f16 + return f16.cxx if hasattr(torch, "uint16"): if dt == torch.uint16: - return u16 + return u16.cxx if dt == torch.uint32: - return u32 + return u32.cxx if dt == torch.uint64: - return u64 + return u64.cxx raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") @@ -264,8 +304,17 @@ def __hash__(self): def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx: - # Convert Python dtype to CPP dtype + """Convert Python dtype to C++ DataTypeCxx. + + Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances, + and Python builtins (float, int, bool). Uses id-based cache for hot paths. + """ + cached = _cook_cache.get(id(dtype)) + if cached is not None: + return cached _type = type(dtype) + if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase): + return dtype.cxx if issubclass(_type, _qd_core.DataTypeCxx): return dtype if issubclass(_type, _qd_core.Type): @@ -275,7 +324,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx: if dtype is int: return impl.get_runtime().default_ip if dtype is bool: - return u1 + return u1.cxx raise ValueError(f"Invalid data type {dtype}") diff --git a/python/quadrants/types/primitive_types.py b/python/quadrants/types/primitive_types.py index 04b8fc4cb5..c6fcdc856f 100644 --- a/python/quadrants/types/primitive_types.py +++ b/python/quadrants/types/primitive_types.py @@ -1,159 +1,193 @@ -from typing import Union +from typing import ClassVar, Union from quadrants._lib import core as qd_python_core +from quadrants._lib.core.quadrants_python import DataTypeCxx # ======================================== -# real types +# Raw C++ DataType instances (internal use) +# ======================================== + +f16_cxx = qd_python_core.DataType_f16 +f32_cxx = qd_python_core.DataType_f32 +f64_cxx = qd_python_core.DataType_f64 + +i8_cxx = qd_python_core.DataType_i8 +i16_cxx = qd_python_core.DataType_i16 +i32_cxx = qd_python_core.DataType_i32 +i64_cxx = qd_python_core.DataType_i64 + +u1_cxx = qd_python_core.DataType_u1 +u8_cxx = qd_python_core.DataType_u8 +u16_cxx = qd_python_core.DataType_u16 +u32_cxx = qd_python_core.DataType_u32 +u64_cxx = qd_python_core.DataType_u64 + + +# ======================================== +# Metaclass and base class for Python dtype wrappers +# ======================================== -# ---------------------------------------- -float16 = qd_python_core.DataType_f16 -"""16-bit precision floating point data type. -""" +class PrimitiveMeta(type): + """Metaclass that makes dtype classes behave like DataTypeCxx objects. -# ---------------------------------------- + Delegates attribute access and comparisons to the underlying .cxx object, + allowing existing code that does e.g. dtype.to_string() to keep working. + """ -f16 = float16 -"""Alias for :const:`~quadrants.types.primitive_types.float16` -""" + def __eq__(cls, other): + if isinstance(other, PrimitiveMeta): + return cls is other + if isinstance(other, DataTypeCxx): + return cls.cxx == other + return NotImplemented -# ---------------------------------------- + def __ne__(cls, other): + if isinstance(other, PrimitiveMeta): + return cls is not other + if isinstance(other, DataTypeCxx): + return cls.cxx != other + return NotImplemented -float32 = qd_python_core.DataType_f32 -"""32-bit single precision floating point data type. -""" + def __hash__(cls): + return hash(cls.cxx) -# ---------------------------------------- + def __repr__(cls): + return cls.cxx.to_string() -f32 = float32 -"""Alias for :const:`~quadrants.types.primitive_types.float32` -""" + def __getattr__(cls, name): + try: + return getattr(cls.cxx, name) + except AttributeError: + raise AttributeError(f"type object '{cls.__name__}' has no attribute '{name}'") from None -# ---------------------------------------- -float64 = qd_python_core.DataType_f64 -"""64-bit double precision floating point data type. -""" +class PrimitiveBase(metaclass=PrimitiveMeta): + """Base class for all primitive dtype classes. -# ---------------------------------------- + Each subclass has a `cxx` class variable holding the corresponding DataTypeCxx instance. + Subclasses auto-register themselves in the _registry for reverse lookup (DataTypeCxx -> Python class). + """ + + cxx: ClassVar[DataTypeCxx] + _registry: ClassVar[dict[DataTypeCxx, "type[PrimitiveBase]"]] = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if hasattr(cls, "cxx"): + PrimitiveBase._registry[cls.cxx] = cls + + +def cxx_to_py(dtype_cxx: DataTypeCxx) -> "type[PrimitiveBase]": + """Convert a DataTypeCxx to its corresponding Python dtype class.""" + return PrimitiveBase._registry[dtype_cxx] -f64 = float64 -"""Alias for :const:`~quadrants.types.primitive_types.float64` -""" -# ---------------------------------------- # ======================================== -# Integer types +# Floating point types +# ======================================== + + +class f16(PrimitiveBase): + """16-bit precision floating point data type.""" + + cxx = f16_cxx -# ---------------------------------------- -int8 = qd_python_core.DataType_i8 -"""8-bit signed integer data type. -""" +class f32(PrimitiveBase): + """32-bit single precision floating point data type.""" -# ---------------------------------------- + cxx = f32_cxx -i8 = int8 -"""Alias for :const:`~quadrants.types.primitive_types.int8` -""" -# ---------------------------------------- +class f64(PrimitiveBase): + """64-bit double precision floating point data type.""" -int16 = qd_python_core.DataType_i16 -"""16-bit signed integer data type. -""" + cxx = f64_cxx -# ---------------------------------------- -i16 = int16 -"""Alias for :const:`~quadrants.types.primitive_types.int16` -""" +float16 = f16 +float32 = f32 +float64 = f64 -# ---------------------------------------- +# ======================================== +# Signed integer types +# ======================================== + + +class i8(PrimitiveBase): + """8-bit signed integer data type.""" -int32 = qd_python_core.DataType_i32 -"""32-bit signed integer data type. -""" + cxx = i8_cxx -# ---------------------------------------- -i32 = int32 -"""Alias for :const:`~quadrants.types.primitive_types.int32` -""" +class i16(PrimitiveBase): + """16-bit signed integer data type.""" -# ---------------------------------------- + cxx = i16_cxx -int64 = qd_python_core.DataType_i64 -"""64-bit signed integer data type. -""" -# ---------------------------------------- +class i32(PrimitiveBase): + """32-bit signed integer data type.""" -i64 = int64 -"""Alias for :const:`~quadrants.types.primitive_types.int64` -""" + cxx = i32_cxx -# ---------------------------------------- -uint8 = qd_python_core.DataType_u8 -"""8-bit unsigned integer data type. -""" +class i64(PrimitiveBase): + """64-bit signed integer data type.""" -# ---------------------------------------- + cxx = i64_cxx + + +int8 = i8 +int16 = i16 +int32 = i32 +int64 = i64 + +# ======================================== +# Unsigned integer types +# ======================================== -uint1 = qd_python_core.DataType_u1 -"""1-bit unsigned integer data type. Same as booleans. -""" -# ---------------------------------------- +class u1(PrimitiveBase): + """1-bit unsigned integer data type. Same as booleans.""" -u1 = uint1 -"""Alias for :const:`~quadrants.types.primitive_types.uint1` -""" + cxx = u1_cxx -# ---------------------------------------- -u8 = uint8 -"""Alias for :const:`~quadrants.types.primitive_types.uint8` -""" +class u8(PrimitiveBase): + """8-bit unsigned integer data type.""" -# ---------------------------------------- + cxx = u8_cxx -uint16 = qd_python_core.DataType_u16 -"""16-bit unsigned integer data type. -""" -# ---------------------------------------- +class u16(PrimitiveBase): + """16-bit unsigned integer data type.""" -u16 = uint16 -"""Alias for :const:`~quadrants.types.primitive_types.uint16` -""" + cxx = u16_cxx -# ---------------------------------------- -uint32 = qd_python_core.DataType_u32 -"""32-bit unsigned integer data type. -""" +class u32(PrimitiveBase): + """32-bit unsigned integer data type.""" -# ---------------------------------------- + cxx = u32_cxx -u32 = uint32 -"""Alias for :const:`~quadrants.types.primitive_types.uint32` -""" -# ---------------------------------------- +class u64(PrimitiveBase): + """64-bit unsigned integer data type.""" -uint64 = qd_python_core.DataType_u64 -"""64-bit unsigned integer data type. -""" + cxx = u64_cxx -# ---------------------------------------- -u64 = uint64 -"""Alias for :const:`~quadrants.types.primitive_types.uint64` -""" +uint1 = u1 +uint8 = u8 +uint16 = u16 +uint32 = u32 +uint64 = u64 -# ---------------------------------------- +# ======================================== +# Ref type (unchanged) +# ======================================== class RefType: @@ -165,6 +199,10 @@ def ref(tp): return RefType(tp) +# ======================================== +# Type sets for fast lookup +# ======================================== + real_types = {f16, f32, f64, float} real_type_ids = {id(t) for t in real_types} @@ -172,7 +210,13 @@ def ref(tp): integer_type_ids = {id(t) for t in integer_types} all_types = real_types | integer_types -type_ids = {id(t) for t in all_types} +_py_type_ids = {id(t) for t in all_types} + +_all_cxx = {f16_cxx, f32_cxx, f64_cxx, i8_cxx, i16_cxx, i32_cxx, i64_cxx, u1_cxx, u8_cxx, u16_cxx, u32_cxx, u64_cxx} +cxx_type_ids = {id(t) for t in _all_cxx} + +# Combined set: matches both Python classes and DataTypeCxx instances +type_ids = _py_type_ids | cxx_type_ids _python_primitive_types = Union[int, float, bool, str, None] diff --git a/python/quadrants/types/quant.py b/python/quadrants/types/quant.py index b67e79ea12..780c2f2088 100644 --- a/python/quadrants/types/quant.py +++ b/python/quadrants/types/quant.py @@ -3,12 +3,23 @@ For more details, read https://yuanming.quadrants.graphics/publication/2021-quanquadrants/quanquadrants.pdf. """ +from typing import Any + from quadrants._lib.utils import qd_python_core as _qd_python_core -from quadrants.types.primitive_types import i32 +from quadrants.types.primitive_types import PrimitiveBase, i32 _type_factory = _qd_python_core.get_type_factory_instance() +def _to_ptr(compute: Any) -> Any: + """Convert a dtype (Python class or DataTypeCxx) to a Type pointer for C++ APIs.""" + if isinstance(compute, type) and issubclass(compute, PrimitiveBase): + compute = compute.cxx + if isinstance(compute, _qd_python_core.DataTypeCxx): + return compute.get_ptr() + return compute + + def int(bits, signed=True, compute=None): # pylint: disable=W0622 """Generates a quantized type for integers. @@ -24,8 +35,7 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622 from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_ip if signed else impl.get_runtime().default_up - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) return _type_factory.get_quant_int_type(bits, signed, compute) @@ -46,8 +56,7 @@ def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None): from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_fp - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) # TODO: handle cases with bits > 32 underlying_type = int(bits=bits, signed=signed, compute=i32) if scale is None: @@ -74,8 +83,7 @@ def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622 from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_fp - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) # Exponent is always unsigned exp_type = int(bits=exp, signed=False, compute=i32) # TODO: handle cases with frac > 32 diff --git a/python/quadrants/types/utils.py b/python/quadrants/types/utils.py index 0803085e2b..e268279d8d 100644 --- a/python/quadrants/types/utils.py +++ b/python/quadrants/types/utils.py @@ -1,11 +1,35 @@ +from typing import Any + from quadrants._lib import core as qd_python_core +from quadrants._lib.core.quadrants_python import DataTypeCxx +from quadrants.types.primitive_types import PrimitiveBase + +_is_signed_cxx = qd_python_core.is_signed +_is_integral_cxx = qd_python_core.is_integral +_is_real_cxx = qd_python_core.is_real +_is_tensor_cxx = qd_python_core.is_tensor + + +def _cook_if_needed(dt: Any) -> DataTypeCxx: + if isinstance(dt, type) and issubclass(dt, PrimitiveBase): + return dt.cxx + return dt # type: ignore[return-value] + + +def is_signed(dt: Any) -> bool: + return _is_signed_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] + + +def is_integral(dt: Any) -> bool: + return _is_integral_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] + -is_signed = qd_python_core.is_signed +def is_real(dt: Any) -> bool: + return _is_real_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] -is_integral = qd_python_core.is_integral -is_real = qd_python_core.is_real +def is_tensor(dt: Any) -> bool: + return _is_tensor_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] -is_tensor = qd_python_core.is_tensor __all__ = ["is_signed", "is_integral", "is_real", "is_tensor"] diff --git a/tests/python/test_binding.py b/tests/python/test_binding.py index 4bf7fdb754..0e6338787f 100644 --- a/tests/python/test_binding.py +++ b/tests/python/test_binding.py @@ -5,7 +5,7 @@ def test_binding(): qd.init() quadrants_lang = qd._lib.core print(quadrants_lang.BinaryOpType.mul) - one = quadrants_lang.make_const_expr_int(qd.i32, 1) - two = quadrants_lang.make_const_expr_int(qd.i32, 2) + one = quadrants_lang.make_const_expr_int(qd.i32.cxx, 1) + two = quadrants_lang.make_const_expr_int(qd.i32.cxx, 2) expr = quadrants_lang.make_binary_op_expr(quadrants_lang.BinaryOpType.add, one, two) print(quadrants_lang.make_global_store_stmt(None, None)) From 7921408673856cda3ec3044eca8740027663293a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 15:54:00 -0700 Subject: [PATCH 2/7] [Lint] Move imports to module top-level to fix pylint C0415 Move 'import inspect' in exception.py and the 'from quadrants.lang.exception import get_func_signature' in _kernel_impl_dataclass.py from inside functions to the module top-level, fixing the pylint import-outside-toplevel (C0415) errors introduced by the get_func_signature refactor. Made-with: Cursor --- python/quadrants/lang/_kernel_impl_dataclass.py | 3 +-- python/quadrants/lang/exception.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index ebd298dff5..baec2c68ab 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -7,6 +7,7 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) +from quadrants.lang.exception import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata @@ -72,8 +73,6 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - from quadrants.lang.exception import get_func_signature - sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index beaf2eeb0e..8db8c5bffc 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,5 +1,7 @@ # type: ignore +import inspect + from quadrants._lib import core @@ -59,8 +61,6 @@ def get_ret(needed, provided): def get_func_signature(func): """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" - import inspect - try: return inspect.signature(func, eval_str=True) except (NameError, AttributeError) as e: From 1ec365866be096444f422b261945b420048d094e Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 16:39:35 -0700 Subject: [PATCH 3/7] [Test] Require data64 for dual_field_dtype debug-mode test The Vulkan/Metal backends on macOS lack f64 support and crash when running this test, which uses qd.f64 fields. Add require=qd.extension.data64 to skip on backends without double-precision support. Made-with: Cursor --- tests/python/test_ad_basics_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index 760db9a7c4..9dcbd1f060 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -126,7 +126,7 @@ def clear_dual_test(): assert y.dual[None] == 4.0 -@test_utils.test(debug=True) +@test_utils.test(require=qd.extension.data64, debug=True) def test_dual_field_dtype_preserved_in_debug_mode(): """Regression: debug-mode checkbit must not shadow the outer dtype.""" x = qd.field(qd.f64, shape=(), needs_dual=True) From 5b681957f88b2162f09dfe60f2731a5264ec24ce Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 00:02:41 -0700 Subject: [PATCH 4/7] [CI] Sync test_gpu.yml from main to fix CUDA torch version The branch is out of date with main, missing the workflow fix from #428 that pins torch to CUDA 12.8 for CUDA tests. Without this, `pip install torch` pulls torch 2.11.0 built against CUDA 13, which is incompatible with the CUDA 12.8 toolkit on the runner, causing all torch+CUDA dlpack tests to fail. Made-with: Cursor --- .github/workflows/test_gpu.yml | 55 ++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/.github/workflows/test_gpu.yml b/.github/workflows/test_gpu.yml index 9eaeb5f679..c16c3cf748 100644 --- a/.github/workflows/test_gpu.yml +++ b/.github/workflows/test_gpu.yml @@ -71,11 +71,9 @@ jobs: python-version: '3.10' - name: install cuda stuff run: | - sudo apt-get install -y libcusolver-dev-12-8 - sudo apt-get install -y libcusparse-dev-12-8 - ls -lhd /usr/local/cuda* - ls -l /usr/local/cuda/lib64/libcusolver* - ls -l /usr/local/cuda/lib64/libcusparse* + sudo apt-get install -y libcusolver-dev-12-8 libcusolver-12-8 libcusparse-dev-12-8 libcusparse-12-8 libnvjitlink-12-8 libcublas-12-8 + echo "/usr/local/cuda/targets/x86_64-linux/lib" | sudo tee /etc/ld.so.conf.d/cuda-targets.conf + sudo ldconfig - name: install quadrants run: | set -x @@ -84,9 +82,7 @@ jobs: pip install dist/*.whl - name: run tests run: | - ls -lh chmod +x quadrants_cpp_tests - ls -lh export QD_LIB_DIR="$(python -c 'import quadrants as ti; print(ti.__path__[0])' | tail -n 1)/_lib/runtime" ./quadrants_cpp_tests --gtest_filter=-AMDGPU.* test_linux_cuda: @@ -105,26 +101,28 @@ jobs: python-version: '3.10' - name: install cuda stuff run: | - sudo apt-get install -y libcusolver-dev-12-8 - sudo apt-get install -y libcusparse-dev-12-8 - ls -lhd /usr/local/cuda* - ls -l /usr/local/cuda/lib64/libcusolver* - ls -l /usr/local/cuda/lib64/libcusparse* + sudo apt-get install -y libcusolver-dev-12-8 libcusolver-12-8 libcusparse-dev-12-8 libcusparse-12-8 libnvjitlink-12-8 libcublas-12-8 + echo "/usr/local/cuda/targets/x86_64-linux/lib" | sudo tee /etc/ld.so.conf.d/cuda-targets.conf + sudo ldconfig - name: install quadrants run: | set -x mkdir -p dist mv *.whl dist/ pip install dist/*.whl - - name: install torch and requirements_test.txt + - name: install test requirements run: | - pip install torch pip install --group test pip install -r requirements_test_xdist.txt - - name: run tests + - name: run tests (without torch) + run: | + python tests/run_tests.py -r 1 -v --arch cuda -m "not needs_torch" + - name: install torch and run torch tests run: | - ls -lh - python tests/run_tests.py -r 1 -v --arch cuda + # pin to torch 12.8 for now, until we update the driver on the + # github runner gpu nodes + pip install torch --index-url https://download.pytorch.org/whl/cu128 + python tests/run_tests.py -r 1 -v --arch cuda -m needs_torch test_linux_vulkan: name: Test Linux Vulkan runs-on: gpu-t4-4-core @@ -148,14 +146,17 @@ jobs: mkdir -p dist mv *.whl dist/ pip install dist/*.whl - - name: install torch and requirements_test.txt + - name: install test requirements run: | - pip install torch pip install --group test pip install -r requirements_test_xdist.txt - - name: run tests + - name: run tests (without torch) run: | - python tests/run_tests.py -r 3 -v --arch vulkan + python tests/run_tests.py -r 3 -v --arch vulkan -m "not needs_torch" + - name: install torch and run torch tests + run: | + pip install torch + python tests/run_tests.py -r 3 -v --arch vulkan -m needs_torch test_linux_amdgpu: name: Test Linux AMD GPU runs-on: amdgpu @@ -176,14 +177,16 @@ jobs: mkdir -p dist mv *.whl dist/ pip install dist/*.whl - - name: install torch - run: | - pip install --upgrade torch --index-url https://download.pytorch.org/whl/rocm6.4 - name: install test requirements run: | pip install --group test pip install -r requirements_test_xdist.txt - - name: run tests + - name: run tests (without torch) + run: | + export QD_AMDGPU_V520=1 + python tests/run_tests.py -t 4 -r 1 -v --arch amdgpu -m "not needs_torch" + - name: install torch and run torch tests run: | + pip install --upgrade torch --index-url https://download.pytorch.org/whl/rocm6.4 export QD_AMDGPU_V520=1 # uses .cpu() before running asserts on dlpack torch tensors - python tests/run_tests.py -t 4 -r 1 -v --arch amdgpu + python tests/run_tests.py -t 4 -r 1 -v --arch amdgpu -m needs_torch From 8dd34074265ec1c11e6a8b74d812202c28a09627 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 00:41:41 -0700 Subject: [PATCH 5/7] [Type] Make dtype classes callable via metaclass __call__ After the dtype refactor, calling primitive dtype classes (e.g. qd.u32(value), qd.i32(value)) inside @qd.func bodies broke because the new Python classes inherited type.__call__ which tried to instantiate them with no __init__ args. This surfaced as pyright errors in main's _tile16.py once merged with this branch. Add a metaclass __call__ that delegates to cls.cxx(value), preserving the DataTypeCxx-style cast/IR-emit semantics expected by call sites. Made-with: Cursor --- python/quadrants/types/primitive_types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/quadrants/types/primitive_types.py b/python/quadrants/types/primitive_types.py index c6fcdc856f..4c8b4d8547 100644 --- a/python/quadrants/types/primitive_types.py +++ b/python/quadrants/types/primitive_types.py @@ -61,6 +61,9 @@ def __getattr__(cls, name): except AttributeError: raise AttributeError(f"type object '{cls.__name__}' has no attribute '{name}'") from None + def __call__(cls, value): + return cls.cxx(value) + class PrimitiveBase(metaclass=PrimitiveMeta): """Base class for all primitive dtype classes. From 5cd9bf3472e172d2f05195fe7f541f1ac2b4ff3b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 01:22:55 -0700 Subject: [PATCH 6/7] [Type] Fix pyright errors when calling dtype classes (qd.u32(j), etc.) Pyright uses __init__/__new__ signatures (not metaclass __call__) when type-checking class instantiation. With the new PrimitiveBase classes, calls like qd.u32(j) and qd.i32(...) reported "Expected 0 positional arguments" because PrimitiveBase had no __init__. Add a typing-only __init__ stub (never invoked at runtime since PrimitiveMeta.__call__ short-circuits instantiation). Also reference PrimitiveMeta directly in _install_python_backend_dtype_call instead of going through the misleadingly-named ``type(f32)`` indirection, so pyright sees the correct ``(cls, value)`` signature when calling the saved ``_original`` callable. Made-with: Cursor --- python/quadrants/lang/misc.py | 7 +++---- python/quadrants/types/primitive_types.py | 9 ++++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index fa08aad77c..fbd3dc55ac 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -13,7 +13,7 @@ from quadrants.lang.expr import Expr from quadrants.lang.impl import axes, get_runtime from quadrants.profiler.kernel_profiler import get_default_kernel_profiler -from quadrants.types.primitive_types import f32, f64, i32, i64 +from quadrants.types.primitive_types import PrimitiveMeta, f32, f64, i32, i64 warnings.filterwarnings("once", category=DeprecationWarning, module="quadrants") @@ -314,15 +314,14 @@ def _install_python_backend_dtype_call(): return _dtype_call_installed = True - DataTypeCxx = type(f32) - _original = DataTypeCxx.__call__ + _original = PrimitiveMeta.__call__ def _dtype_call(self, value): if impl.is_python_backend(): return float(value) if self in _FLOAT_DTYPES else int(value) return _original(self, value) - DataTypeCxx.__call__ = _dtype_call # type: ignore[assignment] + PrimitiveMeta.__call__ = _dtype_call # type: ignore[assignment] def init( diff --git a/python/quadrants/types/primitive_types.py b/python/quadrants/types/primitive_types.py index 4c8b4d8547..be0b1c0cc2 100644 --- a/python/quadrants/types/primitive_types.py +++ b/python/quadrants/types/primitive_types.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Union +from typing import Any, ClassVar, Union from quadrants._lib import core as qd_python_core from quadrants._lib.core.quadrants_python import DataTypeCxx @@ -75,6 +75,13 @@ class PrimitiveBase(metaclass=PrimitiveMeta): cxx: ClassVar[DataTypeCxx] _registry: ClassVar[dict[DataTypeCxx, "type[PrimitiveBase]"]] = {} + # NOTE: __init__ is never executed at runtime because PrimitiveMeta.__call__ + # short-circuits class instantiation and returns cls.cxx(value) directly. + # This stub exists purely so pyright recognises ``f32(value)`` etc. as a + # legal call site (pyright uses __init__/__new__ signatures, not metaclass + # __call__, when type-checking class instantiation). + def __init__(self, value: Any = None) -> None: ... + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if hasattr(cls, "cxx"): From f683c0cd8961f31f4c498f5072246a649cebe34e Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 02:37:56 -0700 Subject: [PATCH 7/7] [Fix] Accept PrimitiveBase classes in qd._lib.core.data_type_size The C++ binding for data_type_size only accepts DataTypeCxx instances. With the new dtype Python wrapper classes (qd.i8, qd.f32, etc.), callers that pass qd.i8 directly to qd._lib.core.data_type_size hit a TypeError. Monkey-patch the binding to transparently unwrap PrimitiveBase subclasses to their underlying .cxx attribute before calling the C++ implementation. This unblocks the merged-with-main test_shared_array_not_accumulated_across_offloads test which passes qd.i8/qd.f32/qd.u32 directly to the C++ binding. Made-with: Cursor --- python/quadrants/types/primitive_types.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/quadrants/types/primitive_types.py b/python/quadrants/types/primitive_types.py index be0b1c0cc2..d9f62d5ff6 100644 --- a/python/quadrants/types/primitive_types.py +++ b/python/quadrants/types/primitive_types.py @@ -93,6 +93,21 @@ def cxx_to_py(dtype_cxx: DataTypeCxx) -> "type[PrimitiveBase]": return PrimitiveBase._registry[dtype_cxx] +# Wrap C++ helpers that take a DataTypeCxx so they also accept PrimitiveBase +# Python wrapper classes (e.g. qd.i8). The C++ binding only accepts DataTypeCxx, +# so we transparently unwrap PrimitiveBase subclasses to their .cxx attribute. +_orig_data_type_size = qd_python_core.data_type_size + + +def _data_type_size(dtype: Any) -> int: + if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase): + dtype = dtype.cxx + return _orig_data_type_size(dtype) + + +qd_python_core.data_type_size = _data_type_size + + # ======================================== # Floating point types # ========================================