diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 192e5270..b5a86728 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -970,8 +970,8 @@ { "code": "reportDeprecated", "range": { - "startColumn": 31, - "endColumn": 36, + "startColumn": 46, + "endColumn": 51, "lineCount": 1 } }, @@ -6265,7 +6265,7 @@ "code": "reportInvalidCast", "range": { "startColumn": 28, - "endColumn": 72, + "endColumn": 74, "lineCount": 1 } }, @@ -17355,6 +17355,38 @@ "lineCount": 1 } }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 11, + "endColumn": 46, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 41, + "endColumn": 45, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 11, + "endColumn": 44, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 39, + "endColumn": 43, + "lineCount": 1 + } + }, { "code": "reportPrivateUsage", "range": { @@ -17461,6 +17493,38 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 12, + "endColumn": 19, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 12, + "endColumn": 22, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 24, + "endColumn": 35, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 17, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportUnknownParameterType", "range": { @@ -19452,32 +19516,24 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 36, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 50, - "endColumn": 58, + "startColumn": 12, + "endColumn": 20, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 60, - "endColumn": 66, + "startColumn": 12, + "endColumn": 18, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 68, - "endColumn": 73, + "startColumn": 12, + "endColumn": 17, "lineCount": 1 } }, @@ -19524,32 +19580,24 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 31, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 53, + "startColumn": 12, + "endColumn": 20, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 55, - "endColumn": 61, + "startColumn": 12, + "endColumn": 18, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 63, - "endColumn": 68, + "startColumn": 12, + "endColumn": 17, "lineCount": 1 } }, @@ -19577,14 +19625,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 20, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportUnusedVariable", "range": { @@ -19617,14 +19657,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 26, - "endColumn": 38, - "lineCount": 1 - } - }, { "code": "reportUnusedVariable", "range": { @@ -19657,14 +19689,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 22, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -19689,14 +19713,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 20, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -19721,14 +19737,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 44, - "endColumn": 56, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -19764,32 +19772,16 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 34, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 48, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 31, - "endColumn": 43, + "startColumn": 69, + "endColumn": 71, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 45, - "endColumn": 53, + "startColumn": 66, + "endColumn": 74, "lineCount": 1 } }, @@ -19841,14 +19833,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 21, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -19900,64 +19884,24 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 49, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 63, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 48, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 62, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 47, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 61, - "endColumn": 65, + "startColumn": 12, + "endColumn": 16, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 40, - "endColumn": 52, + "startColumn": 12, + "endColumn": 16, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 23, - "endColumn": 35, + "startColumn": 82, + "endColumn": 86, "lineCount": 1 } }, @@ -20041,14 +19985,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 28, - "endColumn": 40, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -20281,14 +20217,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 30, - "endColumn": 42, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -20345,14 +20273,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 51, - "endColumn": 65, - "lineCount": 1 - } - }, { "code": "reportUnusedExpression", "range": { @@ -20393,14 +20313,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 31, - "endColumn": 43, - "lineCount": 1 - } - }, { "code": "reportUnusedExpression", "range": { @@ -20420,32 +20332,16 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 24, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 38, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 33, - "endColumn": 45, + "startColumn": 59, + "endColumn": 62, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 47, - "endColumn": 53, + "startColumn": 68, + "endColumn": 74, "lineCount": 1 } }, @@ -20506,58 +20402,50 @@ } }, { - "code": "reportMissingParameterType", - "range": { - "startColumn": 41, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", + "code": "reportArgumentType", "range": { - "startColumn": 33, - "endColumn": 45, + "startColumn": 27, + "endColumn": 50, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 26, - "endColumn": 38, + "startColumn": 45, + "endColumn": 49, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 22, - "endColumn": 34, + "startColumn": 52, + "endColumn": 60, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 36, - "endColumn": 44, + "startColumn": 50, + "endColumn": 54, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 23, - "endColumn": 35, + "startColumn": 57, + "endColumn": 65, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 37, - "endColumn": 41, + "startColumn": 58, + "endColumn": 62, "lineCount": 1 } }, @@ -20618,66 +20506,66 @@ } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 22, - "endColumn": 34, + "startColumn": 38, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 36, - "endColumn": 48, + "startColumn": 38, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 29, - "endColumn": 41, + "startColumn": 38, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 45, - "endColumn": 57, + "startColumn": 38, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 15, - "endColumn": 20, + "startColumn": 38, + "endColumn": 39, "lineCount": 1 } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 22, - "endColumn": 25, + "startColumn": 38, + "endColumn": 39, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 28, - "endColumn": 40, + "startColumn": 15, + "endColumn": 20, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 41, - "endColumn": 53, + "startColumn": 22, + "endColumn": 25, "lineCount": 1 } }, @@ -20713,14 +20601,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 15, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -20745,22 +20625,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 36, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 43, - "endColumn": 55, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -20769,38 +20633,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 35, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 17, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 36, - "endColumn": 48, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -20836,24 +20668,16 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 18, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 36, + "startColumn": 53, + "endColumn": 57, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 38, - "endColumn": 44, + "startColumn": 59, + "endColumn": 65, "lineCount": 1 } } @@ -20899,14 +20723,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 37, - "endColumn": 49, - "lineCount": 1 - } - }, { "code": "reportAttributeAccessIssue", "range": { @@ -20915,14 +20731,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 24, - "endColumn": 36, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { @@ -20966,16 +20774,8 @@ { "code": "reportMissingParameterType", "range": { - "startColumn": 31, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 59, + "startColumn": 66, + "endColumn": 80, "lineCount": 1 } }, @@ -20988,10 +20788,10 @@ } }, { - "code": "reportMissingParameterType", + "code": "reportAttributeAccessIssue", "range": { - "startColumn": 18, - "endColumn": 30, + "startColumn": 48, + "endColumn": 54, "lineCount": 1 } }, @@ -21027,14 +20827,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 33, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportUnusedImport", "range": { @@ -21137,7 +20929,7 @@ "code": "reportInvalidCast", "range": { "startColumn": 25, - "endColumn": 52, + "endColumn": 54, "lineCount": 1 } }, @@ -21145,7 +20937,7 @@ "code": "reportInvalidCast", "range": { "startColumn": 21, - "endColumn": 53, + "endColumn": 55, "lineCount": 1 } }, @@ -21153,7 +20945,7 @@ "code": "reportInvalidCast", "range": { "startColumn": 21, - "endColumn": 53, + "endColumn": 55, "lineCount": 1 } } diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 4abe4a54..ed983f92 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -75,6 +75,7 @@ from .context import ( Array, ArrayContext, + ArrayContextFactory, ArrayOrArithContainer, ArrayOrArithContainerOrScalar, ArrayOrArithContainerOrScalarT, @@ -107,6 +108,7 @@ "ArrayContainer", "ArrayContainerT", "ArrayContext", + "ArrayContextFactory", "ArrayOrArithContainer", "ArrayOrArithContainerOrScalar", "ArrayOrArithContainerOrScalarT", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 1ebc2cb2..78d64844 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -40,10 +40,6 @@ :canonical: arraycontext.ArrayContainerT -.. class:: ArrayOrContainerT - - :canonical: arraycontext.ArrayOrContainerT - .. class:: SerializationKey :canonical: arraycontext.SerializationKey @@ -90,13 +86,12 @@ import numpy as np from typing_extensions import Self -from arraycontext.context import ArrayContext, ArrayOrScalar - if TYPE_CHECKING: from pymbolic.geometric_algebra import MultiVector from arraycontext import ArrayOrContainer + from arraycontext.context import ArrayContext, ArrayOrScalar # {{{ ArrayContainer diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 06ffe738..87313fd9 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -37,11 +37,10 @@ import enum import operator -from collections.abc import Callable from dataclasses import dataclass, field from functools import partialmethod from numbers import Number -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from warnings import warn import numpy as np @@ -51,7 +50,12 @@ deserialize_container, serialize_container, ) -from arraycontext.context import ArrayContext, ArrayOrContainer + + +if TYPE_CHECKING: + from collections.abc import Callable + + from arraycontext.context import ArrayContext, ArrayOrContainer # {{{ with_container_arithmetic diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 21ed61f3..82d307f6 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -31,13 +31,16 @@ THE SOFTWARE. """ -from collections.abc import Mapping, Sequence from dataclasses import fields, is_dataclass -from typing import NamedTuple, Union, get_args, get_origin +from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin from arraycontext.container import is_array_container_type +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + # {{{ dataclass containers class _Field(NamedTuple): diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 69d050f7..1ddc7a24 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -70,9 +70,8 @@ THE SOFTWARE. """ -from collections.abc import Callable, Iterable from functools import partial, singledispatch, update_wrapper -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from warnings import warn import numpy as np @@ -87,14 +86,19 @@ get_container_context_recursively_opt, serialize_container, ) -from arraycontext.context import ( - Array, - ArrayContext, - ArrayOrContainer, - ArrayOrContainerOrScalar, - ArrayOrContainerT, - ScalarLike, -) + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from arraycontext.context import ( + Array, + ArrayContext, + ArrayOrContainer, + ArrayOrContainerOrScalar, + ArrayOrContainerT, + ScalarLike, + ) # {{{ array container traversal helpers @@ -414,7 +418,7 @@ def rec(keys: tuple[SerializationKey, ...], try: iterable = serialize_container(ary_) except NotAnArrayContainerError: - return cast(ArrayOrContainer, f(keys, cast(Array, ary_))) + return cast("ArrayOrContainer", f(keys, cast("Array", ary_))) else: return deserialize_container(ary_, [ (key, rec((*keys, key), subary)) for key, subary in iterable @@ -699,7 +703,7 @@ def _flatten(subary: ArrayOrContainer) -> list[Array]: try: iterable = serialize_container(subary) except NotAnArrayContainerError: - subary_c = cast(Array, subary) + subary_c = cast("Array", subary) if common_dtype is None: common_dtype = subary_c.dtype @@ -786,7 +790,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: try: iterable = serialize_container(template_subary) except NotAnArrayContainerError: - template_subary_c = cast(Array, template_subary) + template_subary_c = cast("Array", template_subary) # {{{ validate subary @@ -877,7 +881,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: raise ValueError("'template' and 'ary' sizes do not match: " "'ary' is too large") - return cast(ArrayOrContainerT, result) + return cast("ArrayOrContainerT", result) def flat_size_and_dtype( @@ -895,7 +899,7 @@ def _flat_size(subary: ArrayOrContainer) -> Array | Integer: try: iterable = serialize_container(subary) except NotAnArrayContainerError: - subary_c = cast(Array, subary) + subary_c = cast("Array", subary) if common_dtype is None: common_dtype = subary_c.dtype diff --git a/arraycontext/context.py b/arraycontext/context.py index f751413c..d064392d 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -85,6 +85,11 @@ Types and Type Variables for Arrays and Containers -------------------------------------------------- +.. autodata:: ScalarLike + :noindex: + + A type alias of :data:`pymbolic.Scalar`. + .. autoclass:: Array .. autodata:: ArrayT @@ -176,11 +181,11 @@ from pymbolic.typing import Integer, Scalar as _Scalar from pytools import memoize_method -from pytools.tag import ToTagSetConvertible if TYPE_CHECKING: import loopy + from pytools.tag import ToTagSetConvertible from arraycontext.container import ArithArrayContainer, ArrayContainer @@ -254,7 +259,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... # # For now, they're purposefully not in the main arraycontext.* name space. ArrayT = TypeVar("ArrayT", bound=Array) -ArrayOrScalar: TypeAlias = "Array | ScalarLike" +ArrayOrScalar: TypeAlias = Array | ScalarLike ArrayOrContainer: TypeAlias = "Array | ArrayContainer" ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer" ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) @@ -610,6 +615,9 @@ def permits_advanced_indexing(self) -> bool: # }}} +ArrayContextFactory: TypeAlias = Callable[[], ArrayContext] + + # {{{ tagging helpers def tag_axes( diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index d862b01d..248d6630 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -29,16 +29,21 @@ THE SOFTWARE. """ -from collections.abc import Callable -import numpy as np +from typing import TYPE_CHECKING -from pytools.tag import ToTagSetConvertible +import numpy as np from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike +if TYPE_CHECKING: + from collections.abc import Callable + + from pytools.tag import ToTagSetConvertible + + class EagerJAXArrayContext(ArrayContext): """ A :class:`ArrayContext` that uses diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 7acf4fab..b60ae038 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -25,6 +25,7 @@ THE SOFTWARE. """ from functools import partial, reduce +from typing import TYPE_CHECKING import numpy as np @@ -39,10 +40,13 @@ rec_map_reduce_array_container, rec_multimap_array_container, ) -from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace +if TYPE_CHECKING: + from arraycontext.context import Array, ArrayOrContainer + + class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): # Everything is implemented in the base class for now. pass diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index f9d6c541..a887f2dc 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -33,12 +33,11 @@ THE SOFTWARE. """ -from typing import Any, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np import loopy as lp -from pytools.tag import ToTagSetConvertible from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ( @@ -52,6 +51,10 @@ ) +if TYPE_CHECKING: + from pytools.tag import ToTagSetConvertible + + class NumpyNonObjectArrayMetaclass(type): def __instancecheck__(cls, instance: Any) -> bool: return isinstance(instance, np.ndarray) and instance.dtype != object diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 582ccda9..b4209b17 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -26,7 +26,7 @@ """ from functools import partial, reduce -from typing import cast +from typing import TYPE_CHECKING, cast import numpy as np @@ -37,13 +37,16 @@ rec_multimap_array_container, rec_multimap_reduce_array_container, ) -from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import ( BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace, ) +if TYPE_CHECKING: + from arraycontext.context import Array, ArrayOrContainer + + class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): # Everything is implemented in the base class for now. pass @@ -150,7 +153,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: return false_ary return np.logical_and.reduce( [(true_ary if kx_i == ky_i else false_ary) - and cast(np.ndarray, self.array_equal(x_i, y_i)) + and cast("np.ndarray", self.array_equal(x_i, y_i)) for (kx_i, x_i), (ky_i, y_i) in zip(serialized_x, serialized_y, strict=True)], initial=true_ary) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 19c9faea..9e792d01 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -31,14 +31,11 @@ THE SOFTWARE. """ -from collections.abc import Callable from typing import TYPE_CHECKING, Literal from warnings import warn import numpy as np -from pytools.tag import ToTagSetConvertible - from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ( Array, @@ -50,9 +47,12 @@ if TYPE_CHECKING: + from collections.abc import Callable + import loopy as lp import pyopencl as cl import pyopencl.array as cl_array + from pytools.tag import ToTagSetConvertible # {{{ PyOpenCLArrayContext diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 17b55470..1661a75c 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -31,6 +31,7 @@ import operator from functools import partial, reduce +from typing import TYPE_CHECKING import numpy as np @@ -41,12 +42,15 @@ rec_multimap_array_container, rec_multimap_reduce_array_container, ) -from arraycontext.context import Array as actx_Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace -from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace +if TYPE_CHECKING: + from arraycontext.context import Array as actx_Array, ArrayOrContainer + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + + # {{{ fake numpy class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 79a108e2..579d3291 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -7,10 +7,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np -from numpy.typing import DTypeLike import pyopencl as cl import pyopencl.array as cla @@ -18,6 +17,10 @@ from pytools.tag import Tag, Taggable, ToTagSetConvertible +if TYPE_CHECKING: + from numpy.typing import DTypeLike + + # {{{ utils @dataclass(frozen=True, eq=True) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index d55e86da..270230bd 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -53,7 +53,6 @@ import abc import sys -from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -74,13 +73,17 @@ if TYPE_CHECKING: + from collections.abc import Callable + + import jax.numpy as jnp import loopy as lp import pyopencl as cl import pyopencl.array as cl_array import pytato + import pytato as pt if getattr(sys, "_BUILDING_SPHINX_DOCS", False): - import pyopencl as cl + pass import logging @@ -150,7 +153,6 @@ def __init__( """ super().__init__() - import pytato as pt self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} self._dag_transform_cache: dict[ pt.DictOfNamedArrays, @@ -932,7 +934,6 @@ def _to_frozen(key: tuple[Any, ...], ary: pt.Array) -> jnp.ndarray: actx=None) def thaw(self, array): - import jax.numpy as jnp import pytato as pt def _thaw(ary: jnp.ndarray) -> pt.Array: diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e78c4e62..1ea47d6d 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -35,14 +35,12 @@ import abc import itertools import logging -from collections.abc import Callable, Hashable, Mapping from dataclasses import dataclass, field -from typing import Any, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np from immutabledict import immutabledict -import pyopencl.array as cla import pytato as pt from pytools import ProcessLogger, to_identifier from pytools.tag import Tag @@ -59,6 +57,12 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable, Hashable, Mapping + + import pyopencl.array as cla + + logger = logging.getLogger(__name__) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 5b864e6c..98d3e438 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ from functools import partial, reduce -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -37,11 +37,14 @@ rec_map_reduce_array_container, rec_multimap_array_container, ) -from arraycontext.context import Array, ArrayOrContainer from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace +if TYPE_CHECKING: + from arraycontext.context import Array, ArrayOrContainer + + class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): # Everything is implemented in the base class for now. pass @@ -196,7 +199,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: if x.shape != y.shape: return false_ary else: - return pt.all(cast(pt.Array, pt.equal(x, y))) + return pt.all(cast("pt.Array", pt.equal(x, y))) else: if len(serialized_x) != len(serialized_y): return false_ary @@ -209,7 +212,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array: in zip(serialized_x, serialized_y, strict=True)], true_ary) - return cast(Array, rec_equal(a, b)) + return cast("Array", rec_equal(a, b)) # }}} diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 8c6e7f5c..005e5987 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -38,7 +38,6 @@ """ -from collections.abc import Mapping from typing import TYPE_CHECKING, Any, cast import pytools @@ -56,14 +55,17 @@ from pytato.transform import ArrayOrNames, CopyMapper from pytools import UniqueNameGenerator, memoize_method -from arraycontext import ArrayContext from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis -from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext if TYPE_CHECKING: + from collections.abc import Mapping + import loopy as lp + from arraycontext import ArrayContext + from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext + class _DatawrapperToBoundPlaceholderMapper(CopyMapper): """ @@ -90,8 +92,10 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.bound_arguments[name] = expr.data return make_placeholder( name=name, - shape=tuple(cast(Array, self.rec(s)) if isinstance(s, Array) else s - for s in expr.shape), + shape=tuple( + cast("Array", self.rec(s)) + if isinstance(s, Array) else s + for s in expr.shape), dtype=expr.dtype, axes=expr.axes, tags=expr.tags) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index af579324..a9fc8d7b 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -29,8 +29,7 @@ THE SOFTWARE. """ -from collections.abc import Mapping -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar import numpy as np @@ -42,6 +41,10 @@ from arraycontext.fake_numpy import BaseFakeNumpyNamespace +if TYPE_CHECKING: + from collections.abc import Mapping + + # {{{ loopy _DEFAULT_LOOPY_OPTIONS = lp.Options( diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 760fc103..1e5b5374 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -33,11 +33,15 @@ THE SOFTWARE. """ -from collections.abc import Callable, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from arraycontext import NumpyArrayContext -from arraycontext.context import ArrayContext + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from arraycontext.context import ArrayContext # {{{ array context factories diff --git a/doc/other.rst b/doc/other.rst index e483f1b3..fc13932f 100644 --- a/doc/other.rst +++ b/doc/other.rst @@ -26,3 +26,33 @@ References .. class:: Allocator See :class:`pyopencl.array.Allocator`. + +.. currentmodule:: np + +.. class:: ndarray + + See :class:`numpy.ndarray`. + +.. currentmodule:: dummy_refs + +.. class:: ToTagSetConvertible + + See :mod:`pytools.tag`. + +.. class:: ArrayOrNames + + A type alias in :mod:`pytato` allowing + :class:`pytato.Array` and + :class:`pytato.AbstractResultWithNamedArrays`. + +.. class:: Integer + + A type alias allowing integers. + +.. class:: ScalarLike + + See :class:`arraycontext.ScalarLike`. + +.. class:: ArrayOrContainerOrScalar + + See :class:`arraycontext.ArrayOrContainerOrScalar`. diff --git a/pyproject.toml b/pyproject.toml index 510f6709..5be437d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ extend-select = [ "UP", # pyupgrade "W", # pycodestyle "SIM", + "TC", ] extend-ignore = [ "C90", # McCabe complexity diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 31fa9e79..a1bdd625 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -34,6 +34,7 @@ from pytools.tag import Tag from arraycontext import ( + ArrayContextFactory, BcastUntilActxArray, EagerJAXArrayContext, NumpyArrayContext, @@ -248,7 +249,11 @@ def assert_close_to_numpy_in_containers(actx, op, args): ("sum", 1, np.complex64), ("isnan", 1, np.float64), ]) -def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): +def test_array_context_np_workalike( + actx_factory: ArrayContextFactory, + sym_name, + n_args, + dtype): actx = actx_factory() if not hasattr(actx.np, sym_name): pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") @@ -281,7 +286,11 @@ def evaluate(np_, *args_): ("ones_like", 1, np.float64), ("ones_like", 1, np.complex128), ]) -def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): +def test_array_context_np_like( + actx_factory: ArrayContextFactory, + sym_name, + n_args, + dtype): actx = actx_factory() ndofs = 512 @@ -311,7 +320,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): # {{{ array manipulations -def test_actx_stack(actx_factory): +def test_actx_stack(actx_factory: ArrayContextFactory): rng = np.random.default_rng() actx = actx_factory() @@ -323,7 +332,7 @@ def test_actx_stack(actx_factory): actx, lambda _np, *_args: _np.stack(_args), args) -def test_actx_concatenate(actx_factory): +def test_actx_concatenate(actx_factory: ArrayContextFactory): rng = np.random.default_rng() actx = actx_factory() @@ -334,7 +343,7 @@ def test_actx_concatenate(actx_factory): actx, lambda _np, *_args: _np.concatenate(_args), args) -def test_actx_reshape(actx_factory): +def test_actx_reshape(actx_factory: ArrayContextFactory): rng = np.random.default_rng() actx = actx_factory() @@ -344,7 +353,7 @@ def test_actx_reshape(actx_factory): (rng.normal(size=(2, 3)), new_shape)) -def test_actx_ravel(actx_factory): +def test_actx_ravel(actx_factory: ArrayContextFactory): from numpy.random import default_rng actx = actx_factory() rng = default_rng() @@ -359,7 +368,7 @@ def test_actx_ravel(actx_factory): # {{{ arithmetic same as numpy -def test_dof_array_arithmetic_same_as_numpy(actx_factory): +def test_dof_array_arithmetic_same_as_numpy(actx_factory: ArrayContextFactory): rng = np.random.default_rng() actx = actx_factory() @@ -513,7 +522,7 @@ def get_imag(ary): # {{{ reductions same as numpy @pytest.mark.parametrize("op", ["sum", "min", "max"]) -def test_reductions_same_as_numpy(actx_factory, op): +def test_reductions_same_as_numpy(actx_factory: ArrayContextFactory, op): rng = np.random.default_rng() actx = actx_factory() @@ -528,7 +537,7 @@ def test_reductions_same_as_numpy(actx_factory, op): @pytest.mark.parametrize("sym_name", ["any", "all"]) -def test_any_all_same_as_numpy(actx_factory, sym_name): +def test_any_all_same_as_numpy(actx_factory: ArrayContextFactory, sym_name): actx = actx_factory() if not hasattr(actx.np, sym_name): pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") @@ -545,7 +554,7 @@ def test_any_all_same_as_numpy(actx_factory, sym_name): lambda _np, *_args: getattr(_np, sym_name)(*_args), [1 - ary_all]) -def test_array_equal(actx_factory): +def test_array_equal(actx_factory: ArrayContextFactory): actx = actx_factory() sym_name = "array_equal" @@ -590,7 +599,9 @@ def test_array_equal(actx_factory): "ij->ji", "ii->i", ]) -def test_array_context_einsum_array_manipulation(actx_factory, spec): +def test_array_context_einsum_array_manipulation( + actx_factory: ArrayContextFactory, + spec): actx = actx_factory() rng = np.random.default_rng() @@ -605,7 +616,9 @@ def test_array_context_einsum_array_manipulation(actx_factory, spec): "ij,ji->ij", "ij,kj->ik", ]) -def test_array_context_einsum_array_matmatprods(actx_factory, spec): +def test_array_context_einsum_array_matmatprods( + actx_factory: ArrayContextFactory, + spec): actx = actx_factory() rng = np.random.default_rng() @@ -619,7 +632,7 @@ def test_array_context_einsum_array_matmatprods(actx_factory, spec): @pytest.mark.parametrize("spec", [ "im,mj,k->ijk" ]) -def test_array_context_einsum_array_tripleprod(actx_factory, spec): +def test_array_context_einsum_array_tripleprod(actx_factory: ArrayContextFactory, spec): actx = actx_factory() rng = np.random.default_rng() @@ -639,7 +652,7 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec): # {{{ array container classes for test -def test_container_map_on_device_scalar(actx_factory): +def test_container_map_on_device_scalar(actx_factory: ArrayContextFactory): actx = actx_factory() expected_sizes = [1, 2, 4, 4, 4] @@ -665,7 +678,7 @@ def test_container_map_on_device_scalar(actx_factory): assert result == size -def test_container_map(actx_factory): +def test_container_map(actx_factory: ArrayContextFactory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, _bcast_dc_of_dofs = \ _get_test_containers(actx) @@ -719,7 +732,7 @@ def check_leaf(x): # }}} -def test_container_multimap(actx_factory): +def test_container_multimap(actx_factory: ArrayContextFactory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, _bcast_dc_of_dofs = \ _get_test_containers(actx) @@ -786,7 +799,7 @@ def check_leaf(a, subary1, b, subary2): # }}} -def test_container_arithmetic(actx_factory): +def test_container_arithmetic(actx_factory: ArrayContextFactory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ _get_test_containers(actx) @@ -840,7 +853,7 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): # }}} -def test_container_freeze_thaw(actx_factory): +def test_container_freeze_thaw(actx_factory: ArrayContextFactory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, _bcast_dc_of_dofs = \ _get_test_containers(actx) @@ -885,7 +898,7 @@ def test_container_freeze_thaw(actx_factory): @pytest.mark.parametrize("ord", [2, np.inf]) -def test_container_norm(actx_factory, ord): +def test_container_norm(actx_factory: ArrayContextFactory, ord): actx = actx_factory() from pytools.obj_array import make_obj_array @@ -907,7 +920,7 @@ def test_container_norm(actx_factory, ord): [(127, 67), (18, 0)], # tests 0-sized arrays [(64, 7), (154, 12)] ]) -def test_flatten_array_container(actx_factory, shapes): +def test_flatten_array_container(actx_factory: ArrayContextFactory, shapes): actx = actx_factory() from arraycontext import flatten, unflatten @@ -962,7 +975,7 @@ def _checked_flatten(ary, actx, leaf_class=None): return result -def test_flatten_array_container_failure(actx_factory): +def test_flatten_array_container_failure(actx_factory: ArrayContextFactory): actx = actx_factory() from arraycontext import unflatten @@ -983,7 +996,7 @@ def test_flatten_array_container_failure(actx_factory): unflatten(ary, flat_ary[:-1], actx) -def test_flatten_with_leaf_class(actx_factory): +def test_flatten_with_leaf_class(actx_factory: ArrayContextFactory): actx = actx_factory() arys = _get_test_containers(actx, shapes=512) @@ -1010,7 +1023,7 @@ def test_flatten_with_leaf_class(actx_factory): # {{{ test from_numpy and to_numpy -def test_numpy_conversion(actx_factory): +def test_numpy_conversion(actx_factory: ArrayContextFactory): actx = actx_factory() rng = np.random.default_rng() @@ -1046,7 +1059,7 @@ def test_numpy_conversion(actx_factory): # {{{ test actx.np.linalg.norm @pytest.mark.parametrize("norm_ord", [2, np.inf]) -def test_norm_complex(actx_factory, norm_ord): +def test_norm_complex(actx_factory: ArrayContextFactory, norm_ord): actx = actx_factory() a = randn(2000, np.complex128) @@ -1059,7 +1072,7 @@ def test_norm_complex(actx_factory, norm_ord): @pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5]) -def test_norm_ord_none(actx_factory, ndim): +def test_norm_ord_none(actx_factory: ArrayContextFactory, ndim): actx = actx_factory() from numpy.random import default_rng @@ -1087,7 +1100,7 @@ def scale_and_orthogonalize(alpha, vel): return Velocity2D(-scaled_vel.v, scaled_vel.u, actx) -def test_actx_compile(actx_factory): +def test_actx_compile(actx_factory: ArrayContextFactory): actx = actx_factory() rng = np.random.default_rng() @@ -1105,7 +1118,7 @@ def test_actx_compile(actx_factory): np.testing.assert_allclose(result.v, 3.14*v_x) -def test_actx_compile_python_scalar(actx_factory): +def test_actx_compile_python_scalar(actx_factory: ArrayContextFactory): actx = actx_factory() rng = np.random.default_rng() @@ -1123,7 +1136,7 @@ def test_actx_compile_python_scalar(actx_factory): np.testing.assert_allclose(result.v, 3.14*v_x) -def test_actx_compile_kwargs(actx_factory): +def test_actx_compile_kwargs(actx_factory: ArrayContextFactory): actx = actx_factory() rng = np.random.default_rng() @@ -1141,7 +1154,7 @@ def test_actx_compile_kwargs(actx_factory): np.testing.assert_allclose(result.v, 3.14*v_x) -def test_actx_compile_with_tuple_output_keys(actx_factory): +def test_actx_compile_with_tuple_output_keys(actx_factory: ArrayContextFactory): # arraycontext.git<=3c9aee68 would fail due to a bug in output # key stringification logic. actx = actx_factory() @@ -1170,7 +1183,7 @@ def my_rhs(scale, vel): # {{{ test_container_equality -def test_container_equality(actx_factory): +def test_container_equality(actx_factory: ArrayContextFactory): actx = actx_factory() ary_dof, _, _, _dc_of_dofs, bcast_dc_of_dofs = \ @@ -1192,7 +1205,7 @@ def test_container_equality(actx_factory): # {{{ test_leaf_array_type_broadcasting -def test_no_leaf_array_type_broadcasting(actx_factory): +def test_no_leaf_array_type_broadcasting(actx_factory: ArrayContextFactory): from testlib import Foo # test lack of support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() @@ -1283,7 +1296,7 @@ def _actx_allows_scalar_broadcast(actx): # {{{ test outer product -def test_outer(actx_factory): +def test_outer(actx_factory: ArrayContextFactory): actx = actx_factory() a_dof, a_ary_of_dofs, _, _, a_bcast_dc_of_dofs = _get_test_containers(actx) @@ -1372,7 +1385,7 @@ class ArrayContainerWithNumpy: __array_ufunc__ = None -def test_array_container_with_numpy(actx_factory): +def test_array_container_with_numpy(actx_factory: ArrayContextFactory): actx = actx_factory() mystate = ArrayContainerWithNumpy( @@ -1389,7 +1402,7 @@ def test_array_container_with_numpy(actx_factory): # {{{ test_actx_compile_on_pure_array_return -def test_actx_compile_on_pure_array_return(actx_factory): +def test_actx_compile_on_pure_array_return(actx_factory: ArrayContextFactory): def _twice(x): return 2 * x @@ -1410,7 +1423,7 @@ class MySampleTag(Tag): pass -def test_taggable_cl_array_tags(actx_factory): +def test_taggable_cl_array_tags(actx_factory: ArrayContextFactory): actx = actx_factory() if not isinstance(actx, PyOpenCLArrayContext): pytest.skip(f"not relevant for '{type(actx).__name__}'") @@ -1455,7 +1468,7 @@ def test_taggable_cl_array_tags(actx_factory): # }}} -def test_to_numpy_on_frozen_arrays(actx_factory): +def test_to_numpy_on_frozen_arrays(actx_factory: ArrayContextFactory): # See https://github.com/inducer/arraycontext/issues/159 actx = actx_factory() u = actx.freeze(actx.np.zeros(10, dtype="float64")+1) @@ -1463,7 +1476,7 @@ def test_to_numpy_on_frozen_arrays(actx_factory): np.testing.assert_allclose(actx.to_numpy(u), 1) -def test_tagging(actx_factory): +def test_tagging(actx_factory: ArrayContextFactory): actx = actx_factory() if isinstance(actx, NumpyArrayContext | EagerJAXArrayContext): @@ -1484,7 +1497,7 @@ class ExampleTag(Tag): assert not ary.axes[1].tags_of_type(ExampleTag) -def test_compile_anonymous_function(actx_factory): +def test_compile_anonymous_function(actx_factory: ArrayContextFactory): from functools import partial # See https://github.com/inducer/grudge/issues/287 @@ -1511,7 +1524,7 @@ def test_compile_anonymous_function(actx_factory): ((1, 5, 20), {"dtype": np.complex128}), ((1, 5, 20), {"dtype": np.int32}), ]) -def test_linspace(actx_factory, args, kwargs): +def test_linspace(actx_factory: ArrayContextFactory, args, kwargs): if "Jax" in actx_factory.__class__.__name__: pytest.xfail("jax actx does not have arange") diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index 053117b7..a7f48fae 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -23,9 +23,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import logging -from typing import Any +from typing import TYPE_CHECKING, Any, cast import numpy as np import pytest @@ -34,12 +33,17 @@ from pytools.tag import Tag from arraycontext import ( + ArrayContextFactory, PytatoPyOpenCLArrayContext, pytest_generate_tests_for_array_contexts, ) from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory +if TYPE_CHECKING: + from collections.abc import Callable + + logger = logging.getLogger(__name__) @@ -108,7 +112,7 @@ class BazTag(Tag): # }}} -def test_tags_preserved_after_freeze(actx_factory): +def test_tags_preserved_after_freeze(actx_factory: ArrayContextFactory): actx = actx_factory() from arraycontext.impl.pytato import _BasePytatoArrayContext @@ -130,7 +134,7 @@ def test_tags_preserved_after_freeze(actx_factory): assert foo.axes[1].tags_of_type(BazTag) -def test_arg_size_limit(actx_factory): +def test_arg_size_limit(actx_factory: Callable[[], PytatoPyOpenCLArrayContext]): ran_callback = False def my_ctc(what, stage, ir): @@ -154,8 +158,8 @@ def twice(x): @pytest.mark.parametrize("pass_allocator", ["auto_none", "auto_true", "auto_false", "pass_buffer", "pass_svm", "pass_buffer_pool", "pass_svm_pool"]) -def test_pytato_actx_allocator(actx_factory, pass_allocator): - base_actx = actx_factory() +def test_pytato_actx_allocator(actx_factory: ArrayContextFactory, pass_allocator): + base_actx = cast("PytatoPyOpenCLArrayContext", actx_factory()) alloc = None use_memory_pool = None @@ -216,7 +220,7 @@ def twice(x): assert res == 198 -def test_transfer(actx_factory): +def test_transfer(actx_factory: ArrayContextFactory): import numpy as np actx = actx_factory() @@ -274,7 +278,8 @@ def test_transfer(actx_factory): # }}} -def test_pass_args_compiled_func(actx_factory): +def test_pass_args_compiled_func( + actx_factory: Callable[[], PytatoPyOpenCLArrayContext]): import numpy as np import loopy as lp diff --git a/test/test_utils.py b/test/test_utils.py index eeef7723..422d11fe 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -209,13 +209,13 @@ class SomeOtherContainer: extent: float rng = np.random.default_rng(seed=42) - a = ArrayWrapper(ary=cast(Array, rng.random(10))) + a = ArrayWrapper(ary=cast("Array", rng.random(10))) d = SomeContainer( - points=cast(Array, rng.random((2, 10))), + points=cast("Array", rng.random((2, 10))), radius=rng.random(), centers=a) c = SomeContainer( - points=cast(Array, rng.random((2, 10))), + points=cast("Array", rng.random((2, 10))), radius=rng.random(), centers=a) ary = SomeOtherContainer(