Skip to content

Commit 2882aa0

Browse files
committed
Type improvements in with_container_arithmetic
1 parent 7127131 commit 2882aa0

File tree

3 files changed

+48
-23
lines changed

3 files changed

+48
-23
lines changed

arraycontext/container/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def serialize_container(
208208
@singledispatch
209209
def deserialize_container(
210210
template: ArrayContainerT,
211-
serialized: SerializedContainer) -> ArrayContainerT:
211+
serialized: SerializedContainer) -> ArrayContainerT: # pyright: ignore[reportUnusedParameter]
212212
"""Deserialize a sequence into an array container following a *template*.
213213
214214
:param template: an instance of an existing object that

arraycontext/container/arithmetic.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from dataclasses import dataclass, field
4747
from functools import partialmethod
4848
from numbers import Number
49-
from typing import TYPE_CHECKING, Any, TypeVar
49+
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
5050
from warnings import warn
5151

5252
import numpy as np
@@ -66,7 +66,7 @@
6666

6767

6868
if TYPE_CHECKING:
69-
from collections.abc import Callable
69+
from collections.abc import Callable, Mapping
7070

7171
from arraycontext.context import ArrayContext
7272
from arraycontext.typing import (
@@ -82,6 +82,19 @@
8282
TypeT = TypeVar("TypeT", bound=type)
8383

8484

85+
class _HasInitArraysSerialization(Protocol):
86+
@classmethod
87+
def _serialize_init_arrays_code(cls, instance_name: str) -> Mapping[str, str]:
88+
...
89+
90+
@classmethod
91+
def _deserialize_init_arrays_code(cls,
92+
tmpl_instance_name: str,
93+
args: Mapping[str, str]
94+
) -> str:
95+
...
96+
97+
8598
@enum.unique
8699
class _OpClass(enum.Enum):
87100
ARITHMETIC = enum.auto()
@@ -254,11 +267,15 @@ class methods ``_deserialize_init_arrays_code`` and
254267
structure type, the implementation might look like this::
255268
256269
@classmethod
257-
def _serialize_init_arrays_code(cls, instance_name):
270+
def _serialize_init_arrays_code(cls,
271+
instance_name: str) -> Mapping[str, str]:
258272
return {"u": f"{instance_name}.u", "v": f"{instance_name}.v"}
259273
260274
@classmethod
261-
def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
275+
def _deserialize_init_arrays_code(cls,
276+
tmpl_instance_name: str,
277+
args: Mapping[str, str]
278+
) -> str:
262279
return f"u={args['u']}, v={args['v']}"
263280
264281
:func:`dataclass_array_container` automatically generates an appropriate
@@ -366,7 +383,7 @@ def numpy_pred(name: str) -> str:
366383
def numpy_pred(name: str) -> str:
367384
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
368385
else:
369-
def numpy_pred(name: str) -> str:
386+
def numpy_pred(name: str) -> str: # pyright: ignore[reportUnusedParameter]
370387
return "False" # optimized away
371388

372389
if np.ndarray in container_types_bcast_across and bcasts_across_obj_array:
@@ -383,7 +400,7 @@ def numpy_pred(name: str) -> str:
383400
else [old_ct])
384401
)
385402

386-
desired_op_classes = set()
403+
desired_op_classes: set[_OpClass] = set()
387404
if arithmetic:
388405
desired_op_classes.add(_OpClass.ARITHMETIC)
389406
if matmul:
@@ -399,7 +416,7 @@ def numpy_pred(name: str) -> str:
399416

400417
# }}}
401418

402-
def wrap(cls: Any) -> Any:
419+
def wrap(cls: TypeT) -> TypeT:
403420
if not hasattr(cls, "__array_ufunc__"):
404421
warn(f"{cls} does not have __array_ufunc__ set. "
405422
"This will cause numpy to attempt broadcasting, in a way that "
@@ -533,15 +550,16 @@ def tup_str(t: tuple[str, ...]) -> str:
533550

534551
# {{{ unary operators
535552

553+
cls_init_arg_ser = cast("type[_HasInitArraysSerialization]", cls)
536554
for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER:
537555
if op_cls not in desired_op_classes:
538556
continue
539557

540558
fname = f"_{cls.__name__.lower()}_{dunder_name}"
541-
init_args = cls._deserialize_init_arrays_code("arg1", {
559+
init_args = cls_init_arg_ser._deserialize_init_arrays_code("arg1", {
542560
key_arg1: _format_unary_op_str(op_str, expr_arg1)
543561
for key_arg1, expr_arg1 in
544-
cls._serialize_init_arrays_code("arg1").items()
562+
cls_init_arg_ser._serialize_init_arrays_code("arg1").items()
545563
})
546564

547565
gen(f"""
@@ -572,24 +590,28 @@ def {fname}(arg1):
572590

573591
continue
574592

575-
zip_init_args = cls._deserialize_init_arrays_code("arg1", {
593+
zip_init_args = cls_init_arg_ser._deserialize_init_arrays_code("arg1", {
576594
same_key(key_arg1, key_arg2):
577595
_format_binary_op_str(op_str, expr_arg1, expr_arg2)
578596
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
579-
cls._serialize_init_arrays_code("arg1").items(),
580-
cls._serialize_init_arrays_code("arg2").items(),
597+
cls_init_arg_ser._serialize_init_arrays_code("arg1").items(),
598+
cls_init_arg_ser._serialize_init_arrays_code("arg2").items(),
581599
strict=True)
582600
})
583-
bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
584-
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
585-
for key_arg1, expr_arg1 in
586-
cls._serialize_init_arrays_code("arg1").items()
587-
})
588-
bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", {
589-
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
590-
for key_arg2, expr_arg2 in
591-
cls._serialize_init_arrays_code("arg2").items()
592-
})
601+
bcast_init_args_arg1_is_outer = \
602+
cls_init_arg_ser._deserialize_init_arrays_code(
603+
"arg1", {
604+
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
605+
for key_arg1, expr_arg1 in
606+
cls_init_arg_ser._serialize_init_arrays_code("arg1").items()
607+
})
608+
bcast_init_args_arg2_is_outer = \
609+
cls_init_arg_ser._deserialize_init_arrays_code(
610+
"arg2", {
611+
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
612+
for key_arg2, expr_arg2 in
613+
cls_init_arg_ser._serialize_init_arrays_code("arg2").items()
614+
})
593615

594616
# {{{ "forward" binary operators
595617

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ extend-ignore-re = [
130130
[tool.typos.default.extend-words]
131131
"nd" = "nd"
132132

133+
# short for 'serialization'
134+
"ser" = "ser"
135+
133136
[tool.basedpyright]
134137
reportImplicitStringConcatenation = "none"
135138
reportUnnecessaryIsInstance = "none"

0 commit comments

Comments
 (0)