Skip to content

Commit 26162cb

Browse files
committed
Tweak is_array_container_type logic following typing improvements
1 parent a914f86 commit 26162cb

File tree

4 files changed

+27
-75
lines changed

4 files changed

+27
-75
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4129,30 +4129,6 @@
41294129
"lineCount": 1
41304130
}
41314131
},
4132-
{
4133-
"code": "reportUnknownMemberType",
4134-
"range": {
4135-
"startColumn": 15,
4136-
"endColumn": 30,
4137-
"lineCount": 1
4138-
}
4139-
},
4140-
{
4141-
"code": "reportAny",
4142-
"range": {
4143-
"startColumn": 59,
4144-
"endColumn": 63,
4145-
"lineCount": 1
4146-
}
4147-
},
4148-
{
4149-
"code": "reportAny",
4150-
"range": {
4151-
"startColumn": 67,
4152-
"endColumn": 73,
4153-
"lineCount": 1
4154-
}
4155-
},
41564132
{
41574133
"code": "reportUnnecessaryComparison",
41584134
"range": {
@@ -6189,30 +6165,6 @@
61896165
"lineCount": 1
61906166
}
61916167
},
6192-
{
6193-
"code": "reportUnknownVariableType",
6194-
"range": {
6195-
"startColumn": 12,
6196-
"endColumn": 16,
6197-
"lineCount": 1
6198-
}
6199-
},
6200-
{
6201-
"code": "reportOperatorIssue",
6202-
"range": {
6203-
"startColumn": 19,
6204-
"endColumn": 77,
6205-
"lineCount": 1
6206-
}
6207-
},
6208-
{
6209-
"code": "reportUnknownArgumentType",
6210-
"range": {
6211-
"startColumn": 62,
6212-
"endColumn": 66,
6213-
"lineCount": 1
6214-
}
6215-
},
62166168
{
62176169
"code": "reportPrivateImportUsage",
62186170
"range": {
@@ -9979,14 +9931,6 @@
99799931
"lineCount": 3
99809932
}
99819933
},
9982-
{
9983-
"code": "reportUnknownArgumentType",
9984-
"range": {
9985-
"startColumn": 29,
9986-
"endColumn": 75,
9987-
"lineCount": 1
9988-
}
9989-
},
99909934
{
99919935
"code": "reportUnknownArgumentType",
99929936
"range": {
@@ -10043,14 +9987,6 @@
100439987
"lineCount": 1
100449988
}
100459989
},
10046-
{
10047-
"code": "reportUnknownVariableType",
10048-
"range": {
10049-
"startColumn": 8,
10050-
"endColumn": 14,
10051-
"lineCount": 1
10052-
}
10053-
},
100549990
{
100559991
"code": "reportUnknownArgumentType",
100569992
"range": {

arraycontext/container/__init__.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,17 @@
8989
.. class:: SerializedContainer
9090
9191
:canonical: arraycontext.SerializedContainer
92+
93+
References
94+
----------
95+
96+
.. class:: GenericAlias
97+
98+
See :class:`types.GenericAlias`.
99+
100+
.. class:: UnionType
101+
102+
See :class:`types.UnionType`.
92103
"""
93104

94105
from __future__ import annotations
@@ -120,7 +131,6 @@
120131

121132
from collections.abc import Hashable, Sequence
122133
from functools import singledispatch
123-
from types import GenericAlias, UnionType
124134
from typing import (
125135
TYPE_CHECKING,
126136
TypeAlias,
@@ -133,18 +143,23 @@
133143
import numpy as np
134144
from typing_extensions import TypeIs
135145

136-
from pytools.obj_array import ObjectArrayND as ObjectArrayND
146+
from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND
137147

138148
from arraycontext.typing import (
149+
ArithArrayContainer,
139150
ArrayContainer,
140151
ArrayContainerT,
141152
ArrayOrArithContainer,
142153
ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar,
143154
ArrayOrContainerOrScalar,
155+
_UserDefinedArithArrayContainer,
156+
_UserDefinedArrayContainer,
144157
)
145158

146159

147160
if TYPE_CHECKING:
161+
from types import GenericAlias, UnionType
162+
148163
from pymbolic.geometric_algebra import CoeffT, MultiVector
149164

150165
from arraycontext.context import ArrayContext
@@ -217,17 +232,21 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
217232
function will say that :class:`numpy.ndarray` is an array container
218233
type, only object arrays *actually are* array containers.
219234
"""
220-
if cls is ArrayContainer:
235+
if cls is ArrayContainer or cls is ArithArrayContainer:
221236
return True
222237

223-
while isinstance(cls, GenericAlias):
224-
cls = get_origin(cls)
238+
origin = get_origin(cls)
239+
if origin is not None:
240+
cls = origin # pyright: ignore[reportAny]
225241

226242
assert isinstance(cls, type), (
227243
f"must pass a {type!r}, not a '{cls!r}'")
228244

229245
return (
230-
cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
246+
cls is ObjectArray
247+
or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
248+
or cls is _UserDefinedArrayContainer
249+
or cls is _UserDefinedArithArrayContainer
231250
or (serialize_container.dispatch(cls)
232251
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
233252

arraycontext/container/dataclass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666

6767
if TYPE_CHECKING:
6868
from collections.abc import Mapping, Sequence
69+
from types import GenericAlias, UnionType
6970

7071

7172
T = TypeVar("T")
@@ -81,7 +82,7 @@ class _Field(NamedTuple):
8182
type: type
8283

8384

84-
def _is_array_or_container_type(tp: type, /) -> bool:
85+
def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool:
8586
if tp is np.ndarray:
8687
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
8788
"This is deprecated and will stop working in 2026. "

arraycontext/context.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@
8484
8585
A :class:`typing.ParamSpec` representing the arguments of a function
8686
being :meth:`ArrayContext.outline`\ d.
87-
88-
References
89-
----------
90-
9187
"""
9288

9389
from __future__ import annotations

0 commit comments

Comments
 (0)