4646from dataclasses import dataclass , field
4747from functools import partialmethod
4848from numbers import Number
49- from typing import TYPE_CHECKING , Any , TypeVar
49+ from typing import TYPE_CHECKING , Protocol , TypeVar , cast
5050from warnings import warn
5151
5252import numpy as np
6666
6767
6868if TYPE_CHECKING :
69- from collections .abc import Callable
69+ from collections .abc import Callable , Mapping
7070
7171 from arraycontext .context import ArrayContext
7272 from arraycontext .typing import (
8282TypeT = TypeVar ("TypeT" , bound = type )
8383
8484
85+ class _HasInitArraysSerialization (Protocol ):
86+ @classmethod
87+ def _serialize_init_arrays_code (cls , instance_name : str ) -> Mapping [str , str ]:
88+ ...
89+
90+ @classmethod
91+ def _deserialize_init_arrays_code (cls ,
92+ tmpl_instance_name : str ,
93+ args : Mapping [str , str ]
94+ ) -> str :
95+ ...
96+
97+
8598@enum .unique
8699class _OpClass (enum .Enum ):
87100 ARITHMETIC = enum .auto ()
@@ -254,11 +267,15 @@ class methods ``_deserialize_init_arrays_code`` and
254267 structure type, the implementation might look like this::
255268
256269 @classmethod
257- def _serialize_init_arrays_code(cls, instance_name):
270+ def _serialize_init_arrays_code(cls,
271+ instance_name: str) -> Mapping[str, str]:
258272 return {"u": f"{instance_name}.u", "v": f"{instance_name}.v"}
259273
260274 @classmethod
261- def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
275+ def _deserialize_init_arrays_code(cls,
276+ tmpl_instance_name: str,
277+ args: Mapping[str, str]
278+ ) -> str:
262279 return f"u={args['u']}, v={args['v']}"
263280
264281 :func:`dataclass_array_container` automatically generates an appropriate
@@ -366,7 +383,7 @@ def numpy_pred(name: str) -> str:
366383 def numpy_pred (name : str ) -> str :
367384 return f"isinstance({ name } , np.ndarray) and { name } .dtype.char == 'O'"
368385 else :
369- def numpy_pred (name : str ) -> str :
386+ def numpy_pred (name : str ) -> str : # pyright: ignore[reportUnusedParameter]
370387 return "False" # optimized away
371388
372389 if np .ndarray in container_types_bcast_across and bcasts_across_obj_array :
@@ -383,7 +400,7 @@ def numpy_pred(name: str) -> str:
383400 else [old_ct ])
384401 )
385402
386- desired_op_classes = set ()
403+ desired_op_classes : set [ _OpClass ] = set ()
387404 if arithmetic :
388405 desired_op_classes .add (_OpClass .ARITHMETIC )
389406 if matmul :
@@ -399,7 +416,7 @@ def numpy_pred(name: str) -> str:
399416
400417 # }}}
401418
402- def wrap (cls : Any ) -> Any :
419+ def wrap (cls : TypeT ) -> TypeT :
403420 if not hasattr (cls , "__array_ufunc__" ):
404421 warn (f"{ cls } does not have __array_ufunc__ set. "
405422 "This will cause numpy to attempt broadcasting, in a way that "
@@ -533,15 +550,16 @@ def tup_str(t: tuple[str, ...]) -> str:
533550
534551 # {{{ unary operators
535552
553+ cls_init_arg_ser = cast ("type[_HasInitArraysSerialization]" , cls )
536554 for dunder_name , op_str , op_cls in _UNARY_OP_AND_DUNDER :
537555 if op_cls not in desired_op_classes :
538556 continue
539557
540558 fname = f"_{ cls .__name__ .lower ()} _{ dunder_name } "
541- init_args = cls ._deserialize_init_arrays_code ("arg1" , {
559+ init_args = cls_init_arg_ser ._deserialize_init_arrays_code ("arg1" , {
542560 key_arg1 : _format_unary_op_str (op_str , expr_arg1 )
543561 for key_arg1 , expr_arg1 in
544- cls ._serialize_init_arrays_code ("arg1" ).items ()
562+ cls_init_arg_ser ._serialize_init_arrays_code ("arg1" ).items ()
545563 })
546564
547565 gen (f"""
@@ -572,24 +590,28 @@ def {fname}(arg1):
572590
573591 continue
574592
575- zip_init_args = cls ._deserialize_init_arrays_code ("arg1" , {
593+ zip_init_args = cls_init_arg_ser ._deserialize_init_arrays_code ("arg1" , {
576594 same_key (key_arg1 , key_arg2 ):
577595 _format_binary_op_str (op_str , expr_arg1 , expr_arg2 )
578596 for (key_arg1 , expr_arg1 ), (key_arg2 , expr_arg2 ) in zip (
579- cls ._serialize_init_arrays_code ("arg1" ).items (),
580- cls ._serialize_init_arrays_code ("arg2" ).items (),
597+ cls_init_arg_ser ._serialize_init_arrays_code ("arg1" ).items (),
598+ cls_init_arg_ser ._serialize_init_arrays_code ("arg2" ).items (),
581599 strict = True )
582600 })
583- bcast_init_args_arg1_is_outer = cls ._deserialize_init_arrays_code ("arg1" , {
584- key_arg1 : _format_binary_op_str (op_str , expr_arg1 , "arg2" )
585- for key_arg1 , expr_arg1 in
586- cls ._serialize_init_arrays_code ("arg1" ).items ()
587- })
588- bcast_init_args_arg2_is_outer = cls ._deserialize_init_arrays_code ("arg2" , {
589- key_arg2 : _format_binary_op_str (op_str , "arg1" , expr_arg2 )
590- for key_arg2 , expr_arg2 in
591- cls ._serialize_init_arrays_code ("arg2" ).items ()
592- })
601+ bcast_init_args_arg1_is_outer = \
602+ cls_init_arg_ser ._deserialize_init_arrays_code (
603+ "arg1" , {
604+ key_arg1 : _format_binary_op_str (op_str , expr_arg1 , "arg2" )
605+ for key_arg1 , expr_arg1 in
606+ cls_init_arg_ser ._serialize_init_arrays_code ("arg1" ).items ()
607+ })
608+ bcast_init_args_arg2_is_outer = \
609+ cls_init_arg_ser ._deserialize_init_arrays_code (
610+ "arg2" , {
611+ key_arg2 : _format_binary_op_str (op_str , "arg1" , expr_arg2 )
612+ for key_arg2 , expr_arg2 in
613+ cls_init_arg_ser ._serialize_init_arrays_code ("arg2" ).items ()
614+ })
593615
594616 # {{{ "forward" binary operators
595617
0 commit comments