Skip to content

Commit 19bb4fe

Browse files
majosminducer
andauthored
Array type checking fixes (#317)
Fix some type-checking issues and make generic Array compatible with pytato arrays --------- Co-authored-by: Andreas Kloeckner <inform@tiker.net>
1 parent 561ac08 commit 19bb4fe

File tree

8 files changed

+133
-587
lines changed

8 files changed

+133
-587
lines changed

.basedpyright/baseline.json

Lines changed: 1 addition & 529 deletions
Large diffs are not rendered by default.

arraycontext/container/traversal.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777

7878
import numpy as np
7979

80+
from pymbolic.typing import Integer
81+
8082
from arraycontext.container import (
8183
ArrayContainer,
8284
NotAnArrayContainerError,
@@ -91,7 +93,6 @@
9193
ArrayOrContainer,
9294
ArrayOrContainerOrScalar,
9395
ArrayOrContainerT,
94-
ArrayT,
9596
ScalarLike,
9697
)
9798

@@ -400,21 +401,20 @@ def keyed_map_array_container(
400401

401402

402403
def rec_keyed_map_array_container(
403-
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
404+
f: Callable[[tuple[SerializationKey, ...], Array], Array],
404405
ary: ArrayOrContainer) -> ArrayOrContainer:
405406
"""
406407
Works similarly to :func:`rec_map_array_container`, except that *f* also
407408
takes in a traversal path to the leaf array. The traversal path argument is
408409
passed in as a tuple of identifiers of the arrays traversed before reaching
409410
the current array.
410411
"""
411-
412412
def rec(keys: tuple[SerializationKey, ...],
413-
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
413+
ary_: ArrayOrContainer) -> ArrayOrContainer:
414414
try:
415415
iterable = serialize_container(ary_)
416416
except NotAnArrayContainerError:
417-
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
417+
return cast(ArrayOrContainer, f(keys, cast(Array, ary_)))
418418
else:
419419
return deserialize_container(ary_, [
420420
(key, rec((*keys, key), subary)) for key, subary in iterable
@@ -777,7 +777,7 @@ def unflatten(
777777
checks are skipped.
778778
"""
779779
# NOTE: https://github.com/python/mypy/issues/7057
780-
offset = 0
780+
offset: int = 0
781781
common_dtype = None
782782

783783
def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
@@ -790,7 +790,11 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
790790

791791
# {{{ validate subary
792792

793-
if (offset + template_subary_c.size) > ary.size:
793+
if (
794+
isinstance(offset, Integer)
795+
and isinstance(template_subary_c.size, Integer)
796+
and isinstance(ary.size, Integer)
797+
and (offset + template_subary_c.size) > ary.size):
794798
raise ValueError("'template' and 'ary' sizes do not match: "
795799
"'template' is too large") from None
796800

@@ -813,6 +817,12 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
813817

814818
# {{{ reshape
815819

820+
if not isinstance(template_subary_c.size, Integer):
821+
raise NotImplementedError(
822+
"unflatten is not implemented for arrays with array-valued "
823+
"size.") from None
824+
825+
# FIXME: Not sure how to make the slicing part work for Array-valued sizes
816826
flat_subary = ary[offset:offset + template_subary_c.size]
817827
try:
818828
subary = actx.np.reshape(flat_subary,
@@ -871,15 +881,15 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
871881

872882

873883
def flat_size_and_dtype(
874-
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
884+
ary: ArrayOrContainer) -> tuple[Array | Integer, np.dtype[Any] | None]:
875885
"""
876886
:returns: a tuple ``(size, dtype)`` that would be the length and
877887
:class:`numpy.dtype` of the one-dimensional array returned by
878888
:func:`flatten`.
879889
"""
880890
common_dtype = None
881891

882-
def _flat_size(subary: ArrayOrContainer) -> int:
892+
def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
883893
nonlocal common_dtype
884894

885895
try:

arraycontext/context.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
import numpy as np
175175
from typing_extensions import Self
176176

177+
from pymbolic.typing import Integer, Scalar as _Scalar
177178
from pytools import memoize_method
178179
from pytools.tag import ToTagSetConvertible
179180

@@ -202,11 +203,11 @@ class Array(Protocol):
202203
"""
203204

204205
@property
205-
def shape(self) -> tuple[int, ...]:
206+
def shape(self) -> tuple[Array | Integer, ...]:
206207
...
207208

208209
@property
209-
def size(self) -> int:
210+
def size(self) -> Array | Integer:
210211
...
211212

212213
@property
@@ -220,24 +221,27 @@ def dtype(self) -> np.dtype[Any]:
220221
def __getitem__(self, index: Any) -> Array:
221222
...
222223

223-
# some basic arithmetic that's supposed to work
224-
def __neg__(self) -> Self: ...
225-
def __abs__(self) -> Self: ...
226-
def __add__(self, other: Self | ScalarLike) -> Self: ...
227-
def __radd__(self, other: Self | ScalarLike) -> Self: ...
228-
def __sub__(self, other: Self | ScalarLike) -> Self: ...
229-
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
230-
def __mul__(self, other: Self | ScalarLike) -> Self: ...
231-
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
232-
def __pow__(self, other: Self | ScalarLike) -> Self: ...
233-
def __rpow__(self, other: Self | ScalarLike) -> Self: ...
234-
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
235-
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...
224+
# Some basic arithmetic that's supposed to work
225+
# Need to return Array instead of Self because for some array types, arithmetic
226+
# operations on one subtype may result in a different subtype.
227+
# For example, pytato arrays: <Placeholder> + 1 -> <IndexLambda>
228+
def __neg__(self) -> Array: ...
229+
def __abs__(self) -> Array: ...
230+
def __add__(self, other: Self | ScalarLike) -> Array: ...
231+
def __radd__(self, other: Self | ScalarLike) -> Array: ...
232+
def __sub__(self, other: Self | ScalarLike) -> Array: ...
233+
def __rsub__(self, other: Self | ScalarLike) -> Array: ...
234+
def __mul__(self, other: Self | ScalarLike) -> Array: ...
235+
def __rmul__(self, other: Self | ScalarLike) -> Array: ...
236+
def __pow__(self, other: Self | ScalarLike) -> Array: ...
237+
def __rpow__(self, other: Self | ScalarLike) -> Array: ...
238+
def __truediv__(self, other: Self | ScalarLike) -> Array: ...
239+
def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
236240

237241

238242
# deprecated, use ScalarLike instead
239-
ScalarLike: TypeAlias = int | float | complex | np.generic
240-
Scalar = ScalarLike
243+
Scalar = _Scalar
244+
ScalarLike = Scalar
241245
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)
242246

243247
# NOTE: I'm kind of not sure about the *Tc versions of these type variables.

arraycontext/impl/pyopencl/taggable_cl_array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class TaggableCLArray(cla.Array, Taggable):
7676
record application-specific metadata to drive the optimizations in
7777
:meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
7878
"""
79+
tags: frozenset[Tag]
80+
axes: tuple[Axis, ...]
81+
7982
def __init__(self, cq, shape, dtype, order="C", allocator=None,
8083
data=None, offset=0, strides=None, events=None, _flags=None,
8184
_fast=False, _size=None, _context=None, _queue=None,

arraycontext/impl/pytato/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def from_numpy(self, array):
437437

438438
import arraycontext.impl.pyopencl.taggable_cl_array as tga
439439

440-
def _from_numpy(ary):
440+
def _from_numpy(ary: np.ndarray[Any, Any]) -> pt.Array:
441441
return pt.make_data_wrapper(
442442
tga.to_device(self.queue, ary, allocator=self.allocator)
443443
)
@@ -654,10 +654,11 @@ def thaw(self, array):
654654
import arraycontext.impl.pyopencl.taggable_cl_array as tga
655655
from .utils import get_pt_axes_from_cl_axes
656656

657-
def _thaw(ary):
658-
return pt.make_data_wrapper(ary.with_queue(self.queue),
659-
axes=get_pt_axes_from_cl_axes(ary.axes),
660-
tags=ary.tags)
657+
def _thaw(ary: tga.TaggableCLArray) -> pt.Array:
658+
return pt.make_data_wrapper(
659+
ary.with_queue(self.queue),
660+
axes=get_pt_axes_from_cl_axes(ary.axes),
661+
tags=ary.tags)
661662

662663
return with_array_context(
663664
self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
@@ -668,7 +669,7 @@ def freeze_thaw(self, array):
668669

669670
import arraycontext.impl.pyopencl.taggable_cl_array as tga
670671

671-
def _ft(ary):
672+
def _ft(ary: tga.TaggableCLArray | pt.Array) -> tga.TaggableCLArray | pt.Array:
672673
if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)):
673674
return ary
674675
else:
@@ -848,7 +849,7 @@ def from_numpy(self, array):
848849
import jax
849850
import pytato as pt
850851

851-
def _from_numpy(ary):
852+
def _from_numpy(ary: np.ndarray[Any, Any]) -> pt.Array:
852853
return pt.make_data_wrapper(jax.device_put(ary))
853854

854855
return with_array_context(
@@ -904,7 +905,7 @@ def _record_leaf_ary_in_dict(key: tuple[Any, ...],
904905

905906
# }}}
906907

907-
def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
908+
def _to_frozen(key: tuple[Any, ...], ary: pt.Array) -> jnp.ndarray:
908909
key_str = "_ary" + _ary_container_key_stringifier(key)
909910
return key_to_frozen_subary[key_str]
910911

@@ -931,9 +932,10 @@ def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
931932
actx=None)
932933

933934
def thaw(self, array):
935+
import jax.numpy as jnp
934936
import pytato as pt
935937

936-
def _thaw(ary):
938+
def _thaw(ary: jnp.ndarray) -> pt.Array:
937939
return pt.make_data_wrapper(ary)
938940

939941
return with_array_context(

arraycontext/impl/pytato/compile.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,21 @@
3737
import logging
3838
from collections.abc import Callable, Hashable, Mapping
3939
from dataclasses import dataclass, field
40-
from typing import Any
40+
from typing import Any, overload
4141

4242
import numpy as np
4343
from immutabledict import immutabledict
4444

45+
import pyopencl.array as cla
4546
import pytato as pt
4647
from pytools import ProcessLogger, to_identifier
4748
from pytools.tag import Tag
4849

4950
from arraycontext.container import ArrayContainer, is_array_container_type
5051
from arraycontext.container.traversal import rec_keyed_map_array_container
51-
from arraycontext.context import ArrayT
52+
from arraycontext.impl.pyopencl.taggable_cl_array import (
53+
TaggableCLArray,
54+
)
5255
from arraycontext.impl.pytato import (
5356
PytatoJAXArrayContext,
5457
PytatoPyOpenCLArrayContext,
@@ -110,7 +113,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
110113

111114
# {{{ utilities
112115

113-
def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str:
116+
def _ary_container_key_stringifier(keys: tuple[object, ...]) -> str:
114117
"""
115118
Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
116119
array-container's component's key. Goals of this routine:
@@ -119,12 +122,12 @@ def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str:
119122
* Stringified key must a valid identifier according to :meth:`str.isidentifier`
120123
* (informal) Shorter identifiers are preferred
121124
"""
122-
def _rec_str(key: Any) -> str:
125+
def _rec_str(key: object) -> str:
123126
if isinstance(key, str | int):
124127
return str(key)
125128
elif isinstance(key, tuple):
126129
# t in '_actx_t': stands for tuple
127-
return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt"
130+
return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
128131
else:
129132
raise NotImplementedError("Key-stringication unimplemented for "
130133
f"'{type(key).__name__}'.")
@@ -175,7 +178,28 @@ def id_collector(keys, ary):
175178
return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr)
176179

177180

178-
def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
181+
@overload
182+
def _to_input_for_compiled(
183+
ary: pt.Array, actx: PytatoPyOpenCLArrayContext) -> pt.Array:
184+
...
185+
186+
187+
@overload
188+
def _to_input_for_compiled(
189+
ary: TaggableCLArray, actx: PytatoPyOpenCLArrayContext) -> TaggableCLArray:
190+
...
191+
192+
193+
@overload
194+
def _to_input_for_compiled(
195+
ary: cla.Array, actx: PytatoPyOpenCLArrayContext
196+
) -> cla.Array:
197+
...
198+
199+
200+
def _to_input_for_compiled(
201+
ary: pt.Array | TaggableCLArray | cla.Array,
202+
actx: PytatoPyOpenCLArrayContext) -> pt.Array | TaggableCLArray | cla.Array:
179203
"""
180204
Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder`
181205
in :meth:`LazilyCompilingFunctionCaller.__call__`.
@@ -185,19 +209,14 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
185209
- Metadata Inference that is supplied via *actx*\'s
186210
:meth:`PytatoPyOpenCLArrayContext.transform_dag`.
187211
"""
188-
import pyopencl.array as cla
189-
190-
from arraycontext.impl.pyopencl.taggable_cl_array import (
191-
TaggableCLArray,
192-
to_tagged_cl_array,
193-
)
212+
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
194213
if isinstance(ary, pt.Array):
195214
dag = pt.make_dict_of_named_arrays({"_actx_out": ary})
196215
# Transform the DAG to give metadata inference a chance to do its job
197216
return actx.transform_dag(dag)["_actx_out"].expr
198217
elif isinstance(ary, TaggableCLArray):
199218
return ary
200-
elif isinstance(ary, cla.Array):
219+
else:
201220
from warnings import warn
202221
warn("Passing pyopencl.array.Array to a compiled callable"
203222
" is deprecated and will stop working in 2023."
@@ -207,8 +226,6 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
207226
return to_tagged_cl_array(ary,
208227
axes=None,
209228
tags=frozenset())
210-
else:
211-
raise NotImplementedError(type(ary))
212229

213230

214231
def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
@@ -230,7 +247,7 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
230247
axes=arg.axes,
231248
tags=arg.tags)
232249
elif is_array_container_type(arg.__class__):
233-
def _rec_to_placeholder(keys, ary):
250+
def _rec_to_placeholder(keys, ary: pt.Array):
234251
index = (kw, *keys)
235252
name = arg_id_to_name[index]
236253
# Transform the DAG to give metadata inference a chance to do its job

pyproject.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ reportImplicitStringConcatenation = "none"
129129
reportUnnecessaryIsInstance = "none"
130130
reportUnusedCallResult = "none"
131131
reportExplicitAny = "none"
132+
reportPrivateUsage = "hint"
133+
reportUnusedFunction = "none"
132134

133135
# This reports even cycles that are qualified by 'if TYPE_CHECKING'. Not what
134136
# we care about at this moment.
@@ -148,3 +150,14 @@ reportPrivateUsage = "none"
148150
reportMissingTypeStubs = "hint"
149151
reportAny = "hint"
150152

153+
[[tool.basedpyright.executionEnvironments]]
154+
root = "examples"
155+
reportUnknownArgumentType = "hint"
156+
reportUnknownMemberType = "hint"
157+
reportUnknownVariableType = "hint"
158+
reportUnknownParameterType = "hint"
159+
reportMissingTypeArgument = "hint"
160+
reportPrivateUsage = "none"
161+
reportMissingTypeStubs = "hint"
162+
reportAny = "hint"
163+

0 commit comments

Comments
 (0)