diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index aba089a6..192e5270 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -9,14 +9,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -91,14 +83,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -123,14 +107,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportCallIssue", "range": { @@ -10643,30 +10619,6 @@ "lineCount": 1 } }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -10675,14 +10627,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 42, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -10707,14 +10651,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 53, - "endColumn": 62, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -10827,14 +10763,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 48, - "endColumn": 57, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -10939,46 +10867,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 30, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -11589,30 +11477,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 42, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -12245,70 +12109,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 40, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 40, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 70, - "endColumn": 78, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 70, - "endColumn": 78, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 45, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 45, - "endColumn": 53, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -12365,22 +12165,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 16, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 16, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -13093,22 +12877,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -13125,14 +12893,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 55, - "endColumn": 58, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -13325,22 +13085,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 48, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -13478,31 +13222,7 @@ } }, { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 40, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportAny", + "code": "reportAny", "range": { "startColumn": 15, "endColumn": 22, @@ -14071,38 +13791,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 17, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 56, - "endColumn": 57, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 47, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 38, - "endColumn": 41, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -14239,14 +13927,6 @@ "lineCount": 1 } }, - { - "code": "reportInvalidTypeVarUse", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -14343,14 +14023,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 37, - "endColumn": 40, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -14367,22 +14039,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 40, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 40, - "endColumn": 48, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -14415,22 +14071,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 38, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 38, - "endColumn": 41, - "lineCount": 1 - } - }, { "code": "reportUnknownVariableType", "range": { @@ -14447,14 +14087,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 41, - "endColumn": 44, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -14471,22 +14103,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 44, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 44, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -17393,14 +17009,6 @@ "lineCount": 1 } }, - { - "code": "reportInvalidCast", - "range": { - "startColumn": 15, - "endColumn": 43, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -17707,14 +17315,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -19753,14 +19353,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 8, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -21427,30 +21019,6 @@ "lineCount": 1 } }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 25, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 25, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 14, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportAttributeAccessIssue", "range": { @@ -21459,70 +21027,6 @@ "lineCount": 1 } }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 43, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 43, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 22, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 37, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 42, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 22, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 38, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 43, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -21735,22 +21239,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 30, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -21783,14 +21271,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportGeneralTypeIssues", "range": { @@ -21855,14 +21335,6 @@ "lineCount": 1 } }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 4, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ef7141f5..69d050f7 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -77,6 +77,8 @@ import numpy as np +from pymbolic.typing import Integer + from arraycontext.container import ( ArrayContainer, NotAnArrayContainerError, @@ -91,7 +93,6 @@ ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerT, - ArrayT, ScalarLike, ) @@ -400,7 +401,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT], + f: Callable[[tuple[SerializationKey, ...], Array], Array], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -408,13 +409,12 @@ def rec_keyed_map_array_container( passed in as a tuple of identifiers of the arrays traversed before reaching the current array. """ - def rec(keys: tuple[SerializationKey, ...], - ary_: ArrayOrContainerT) -> ArrayOrContainerT: + ary_: ArrayOrContainer) -> ArrayOrContainer: try: iterable = serialize_container(ary_) except NotAnArrayContainerError: - return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_))) + return cast(ArrayOrContainer, f(keys, cast(Array, ary_))) else: return deserialize_container(ary_, [ (key, rec((*keys, key), subary)) for key, subary in iterable @@ -777,7 +777,7 @@ def unflatten( checks are skipped. """ # NOTE: https://github.com/python/mypy/issues/7057 - offset = 0 + offset: int = 0 common_dtype = None def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: @@ -790,7 +790,11 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: # {{{ validate subary - if (offset + template_subary_c.size) > ary.size: + if ( + isinstance(offset, Integer) + and isinstance(template_subary_c.size, Integer) + and isinstance(ary.size, Integer) + and (offset + template_subary_c.size) > ary.size): raise ValueError("'template' and 'ary' sizes do not match: " "'template' is too large") from None @@ -813,6 +817,12 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: # {{{ reshape + if not isinstance(template_subary_c.size, Integer): + raise NotImplementedError( + "unflatten is not implemented for arrays with array-valued " + "size.") from None + + # FIXME: Not sure how to make the slicing part work for Array-valued sizes flat_subary = ary[offset:offset + template_subary_c.size] try: subary = actx.np.reshape(flat_subary, @@ -871,7 +881,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: def flat_size_and_dtype( - ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]: + ary: ArrayOrContainer) -> tuple[Array | Integer, np.dtype[Any] | None]: """ :returns: a tuple ``(size, dtype)`` that would be the length and :class:`numpy.dtype` of the one-dimensional array returned by @@ -879,7 +889,7 @@ def flat_size_and_dtype( """ common_dtype = None - def _flat_size(subary: ArrayOrContainer) -> int: + def _flat_size(subary: ArrayOrContainer) -> Array | Integer: nonlocal common_dtype try: diff --git a/arraycontext/context.py b/arraycontext/context.py index 7e4659ae..f751413c 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -174,6 +174,7 @@ import numpy as np from typing_extensions import Self +from pymbolic.typing import Integer, Scalar as _Scalar from pytools import memoize_method from pytools.tag import ToTagSetConvertible @@ -202,11 +203,11 @@ class Array(Protocol): """ @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> tuple[Array | Integer, ...]: ... @property - def size(self) -> int: + def size(self) -> Array | Integer: ... @property @@ -220,24 +221,27 @@ def dtype(self) -> np.dtype[Any]: def __getitem__(self, index: Any) -> Array: ... - # some basic arithmetic that's supposed to work - def __neg__(self) -> Self: ... - def __abs__(self) -> Self: ... - def __add__(self, other: Self | ScalarLike) -> Self: ... - def __radd__(self, other: Self | ScalarLike) -> Self: ... - def __sub__(self, other: Self | ScalarLike) -> Self: ... - def __rsub__(self, other: Self | ScalarLike) -> Self: ... - def __mul__(self, other: Self | ScalarLike) -> Self: ... - def __rmul__(self, other: Self | ScalarLike) -> Self: ... - def __pow__(self, other: Self | ScalarLike) -> Self: ... - def __rpow__(self, other: Self | ScalarLike) -> Self: ... - def __truediv__(self, other: Self | ScalarLike) -> Self: ... - def __rtruediv__(self, other: Self | ScalarLike) -> Self: ... + # Some basic arithmetic that's supposed to work + # Need to return Array instead of Self because for some array types, arithmetic + # operations on one subtype may result in a different subtype. + # For example, pytato arrays: + 1 -> + def __neg__(self) -> Array: ... + def __abs__(self) -> Array: ... + def __add__(self, other: Self | ScalarLike) -> Array: ... + def __radd__(self, other: Self | ScalarLike) -> Array: ... + def __sub__(self, other: Self | ScalarLike) -> Array: ... + def __rsub__(self, other: Self | ScalarLike) -> Array: ... + def __mul__(self, other: Self | ScalarLike) -> Array: ... + def __rmul__(self, other: Self | ScalarLike) -> Array: ... + def __pow__(self, other: Self | ScalarLike) -> Array: ... + def __rpow__(self, other: Self | ScalarLike) -> Array: ... + def __truediv__(self, other: Self | ScalarLike) -> Array: ... + def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... # deprecated, use ScalarLike instead -ScalarLike: TypeAlias = int | float | complex | np.generic -Scalar = ScalarLike +Scalar = _Scalar +ScalarLike = Scalar ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) # NOTE: I'm kind of not sure about the *Tc versions of these type variables. diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 6d27eb74..79a108e2 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -76,6 +76,9 @@ class TaggableCLArray(cla.Array, Taggable): record application-specific metadata to drive the optimizations in :meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`. """ + tags: frozenset[Tag] + axes: tuple[Axis, ...] + def __init__(self, cq, shape, dtype, order="C", allocator=None, data=None, offset=0, strides=None, events=None, _flags=None, _fast=False, _size=None, _context=None, _queue=None, diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index fff683dd..d55e86da 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -437,7 +437,7 @@ def from_numpy(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga - def _from_numpy(ary): + def _from_numpy(ary: np.ndarray[Any, Any]) -> pt.Array: return pt.make_data_wrapper( tga.to_device(self.queue, ary, allocator=self.allocator) ) @@ -654,10 +654,11 @@ def thaw(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga from .utils import get_pt_axes_from_cl_axes - def _thaw(ary): - return pt.make_data_wrapper(ary.with_queue(self.queue), - axes=get_pt_axes_from_cl_axes(ary.axes), - tags=ary.tags) + def _thaw(ary: tga.TaggableCLArray) -> pt.Array: + return pt.make_data_wrapper( + ary.with_queue(self.queue), + axes=get_pt_axes_from_cl_axes(ary.axes), + tags=ary.tags) return with_array_context( self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), @@ -668,7 +669,7 @@ def freeze_thaw(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga - def _ft(ary): + def _ft(ary: tga.TaggableCLArray | pt.Array) -> tga.TaggableCLArray | pt.Array: if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)): return ary else: @@ -848,7 +849,7 @@ def from_numpy(self, array): import jax import pytato as pt - def _from_numpy(ary): + def _from_numpy(ary: np.ndarray[Any, Any]) -> pt.Array: return pt.make_data_wrapper(jax.device_put(ary)) return with_array_context( @@ -904,7 +905,7 @@ def _record_leaf_ary_in_dict(key: tuple[Any, ...], # }}} - def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray: + def _to_frozen(key: tuple[Any, ...], ary: pt.Array) -> jnp.ndarray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -931,9 +932,10 @@ def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray: actx=None) def thaw(self, array): + import jax.numpy as jnp import pytato as pt - def _thaw(ary): + def _thaw(ary: jnp.ndarray) -> pt.Array: return pt.make_data_wrapper(ary) return with_array_context( diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 5a4340e3..e78c4e62 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -37,18 +37,21 @@ import logging from collections.abc import Callable, Hashable, Mapping from dataclasses import dataclass, field -from typing import Any +from typing import Any, overload import numpy as np from immutabledict import immutabledict +import pyopencl.array as cla import pytato as pt from pytools import ProcessLogger, to_identifier from pytools.tag import Tag from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext.container.traversal import rec_keyed_map_array_container -from arraycontext.context import ArrayT +from arraycontext.impl.pyopencl.taggable_cl_array import ( + TaggableCLArray, +) from arraycontext.impl.pytato import ( PytatoJAXArrayContext, PytatoPyOpenCLArrayContext, @@ -110,7 +113,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: +def _ary_container_key_stringifier(keys: tuple[object, ...]) -> str: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an array-container's component's key. Goals of this routine: @@ -119,12 +122,12 @@ def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str: * Stringified key must a valid identifier according to :meth:`str.isidentifier` * (informal) Shorter identifiers are preferred """ - def _rec_str(key: Any) -> str: + def _rec_str(key: object) -> str: if isinstance(key, str | int): return str(key) elif isinstance(key, tuple): # t in '_actx_t': stands for tuple - return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" + return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] else: raise NotImplementedError("Key-stringication unimplemented for " f"'{type(key).__name__}'.") @@ -175,7 +178,28 @@ def id_collector(keys, ary): return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr) -def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): +@overload +def _to_input_for_compiled( + ary: pt.Array, actx: PytatoPyOpenCLArrayContext) -> pt.Array: + ... + + +@overload +def _to_input_for_compiled( + ary: TaggableCLArray, actx: PytatoPyOpenCLArrayContext) -> TaggableCLArray: + ... + + +@overload +def _to_input_for_compiled( + ary: cla.Array, actx: PytatoPyOpenCLArrayContext + ) -> cla.Array: + ... + + +def _to_input_for_compiled( + ary: pt.Array | TaggableCLArray | cla.Array, + actx: PytatoPyOpenCLArrayContext) -> pt.Array | TaggableCLArray | cla.Array: """ Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder` in :meth:`LazilyCompilingFunctionCaller.__call__`. @@ -185,19 +209,14 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): - Metadata Inference that is supplied via *actx*\'s :meth:`PytatoPyOpenCLArrayContext.transform_dag`. """ - import pyopencl.array as cla - - from arraycontext.impl.pyopencl.taggable_cl_array import ( - TaggableCLArray, - to_tagged_cl_array, - ) + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array if isinstance(ary, pt.Array): dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) # Transform the DAG to give metadata inference a chance to do its job return actx.transform_dag(dag)["_actx_out"].expr elif isinstance(ary, TaggableCLArray): return ary - elif isinstance(ary, cla.Array): + else: from warnings import warn warn("Passing pyopencl.array.Array to a compiled callable" " is deprecated and will stop working in 2023." @@ -207,8 +226,6 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): return to_tagged_cl_array(ary, axes=None, tags=frozenset()) - else: - raise NotImplementedError(type(ary)) def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): @@ -230,7 +247,7 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): axes=arg.axes, tags=arg.tags) elif is_array_container_type(arg.__class__): - def _rec_to_placeholder(keys, ary): + def _rec_to_placeholder(keys, ary: pt.Array): index = (kw, *keys) name = arg_id_to_name[index] # Transform the DAG to give metadata inference a chance to do its job diff --git a/pyproject.toml b/pyproject.toml index fa75e4de..510f6709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,6 +129,8 @@ reportImplicitStringConcatenation = "none" reportUnnecessaryIsInstance = "none" reportUnusedCallResult = "none" reportExplicitAny = "none" +reportPrivateUsage = "hint" +reportUnusedFunction = "none" # This reports even cycles that are qualified by 'if TYPE_CHECKING'. Not what # we care about at this moment. @@ -148,3 +150,14 @@ reportPrivateUsage = "none" reportMissingTypeStubs = "hint" reportAny = "hint" +[[tool.basedpyright.executionEnvironments]] +root = "examples" +reportUnknownArgumentType = "hint" +reportUnknownMemberType = "hint" +reportUnknownVariableType = "hint" +reportUnknownParameterType = "hint" +reportMissingTypeArgument = "hint" +reportPrivateUsage = "none" +reportMissingTypeStubs = "hint" +reportAny = "hint" + diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index 188aa9de..053117b7 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -25,10 +25,12 @@ """ import logging +from typing import Any import numpy as np import pytest +import pytato as pt from pytools.tag import Tag from arraycontext import ( @@ -41,6 +43,25 @@ logger = logging.getLogger(__name__) +# {{{ type checking + +def verify_is_dag(dag: Any) -> pt.DictOfNamedArrays: + assert isinstance(dag, pt.DictOfNamedArrays) + return dag + + +def verify_is_idx_lambda(ary: Any) -> pt.IndexLambda: + assert isinstance(ary, pt.IndexLambda) + return ary + + +def verify_is_data_wrapper(ary: Any) -> pt.DataWrapper: + assert isinstance(ary, pt.DataWrapper) + return ary + +# }}} + + # {{{ pytato-array context fixture class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext): @@ -198,7 +219,6 @@ def twice(x): def test_transfer(actx_factory): import numpy as np - import pytato as pt actx = actx_factory() # {{{ simple tests @@ -219,7 +239,7 @@ def test_transfer(actx_factory): with pytest.raises(ValueError): _ahh = transfer_to_numpy(ah, actx) - ad = transfer_from_numpy(ah, actx) + ad = verify_is_data_wrapper(transfer_from_numpy(ah, actx)) assert isinstance(ad.data, TaggableCLArray) assert ad != ah assert ad != a # copied DataWrappers compare unequal @@ -238,12 +258,18 @@ def test_transfer(actx_factory): "a_expr": a + 2 }) - dagh = transfer_to_numpy(dag, actx) + dagh = verify_is_dag(transfer_to_numpy(dag, actx)) assert dagh != dag - assert isinstance(dagh["a_expr"].expr.bindings["_in0"].data, np.ndarray) + bndh = verify_is_data_wrapper( + verify_is_idx_lambda( + dagh["a_expr"].expr).bindings["_in0"]) + assert isinstance(bndh.data, np.ndarray) - daghd = transfer_from_numpy(dagh, actx) - assert isinstance(daghd["a_expr"].expr.bindings["_in0"].data, TaggableCLArray) + daghd = verify_is_dag(transfer_from_numpy(dagh, actx)) + bndhd = verify_is_data_wrapper( + verify_is_idx_lambda( + daghd["a_expr"].expr).bindings["_in0"]) + assert isinstance(bndhd.data, TaggableCLArray) # }}} @@ -254,7 +280,6 @@ def test_pass_args_compiled_func(actx_factory): import loopy as lp import pyopencl as cl import pyopencl.array - import pytato as pt def twice(x, y, a): return 2 * x * y * a