3131import itertools
3232from collections .abc import Callable , Mapping
3333from dataclasses import dataclass
34- from typing import Any
34+ from typing import cast
3535
3636import numpy as np
3737from immutabledict import immutabledict
3838
3939import pytato as pt
40+ from pymbolic import Scalar
4041from pytools .tag import Tag
4142
4243from arraycontext .container import is_array_container_type
4344from arraycontext .container .traversal import rec_keyed_map_array_container
44- from arraycontext .context import ArrayOrContainer , ArrayT
45+ from arraycontext .context import (
46+ Array ,
47+ ArrayOrContainer ,
48+ ArrayOrContainerTc ,
49+ ArrayT ,
50+ )
4551from arraycontext .impl .pytato import _BasePytatoArrayContext
4652
4753
48- def _get_arg_id_to_arg (args : tuple [Any , ...],
49- kwargs : Mapping [str , Any ]
50- ) -> immutabledict [tuple [Any , ...], Any ]:
54+ def _get_arg_id_to_arg (args : tuple [object , ...],
55+ kwargs : Mapping [str , object ]
56+ ) -> immutabledict [tuple [object , ...], object ]:
5157 """
5258 Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
5359 to argument values. See
5460 :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
5561 representation.
5662 """
57- arg_id_to_arg : dict [tuple [Any , ...], Any ] = {}
63+ arg_id_to_arg : dict [tuple [object , ...], object ] = {}
5864
5965 for kw , arg in itertools .chain (enumerate (args ),
6066 kwargs .items ()):
@@ -64,7 +70,7 @@ def _get_arg_id_to_arg(args: tuple[Any, ...],
6470 # do not make scalars as placeholders since we inline them.
6571 pass
6672 elif is_array_container_type (arg .__class__ ):
67- def id_collector (keys : tuple [Any , ...], ary : ArrayT ) -> ArrayT :
73+ def id_collector (keys : tuple [object , ...], ary : ArrayT ) -> ArrayT :
6874 if np .isscalar (ary ):
6975 pass
7076 else :
@@ -85,21 +91,21 @@ def id_collector(keys: tuple[Any, ...], ary: ArrayT) -> ArrayT:
8591
8692
8793def _get_input_arg_id_str (
88- arg_id : tuple [Any , ...], prefix : str | None = None ) -> str :
94+ arg_id : tuple [object , ...], prefix : str | None = None ) -> str :
8995 if prefix is None :
9096 prefix = ""
9197 from arraycontext .impl .pytato .utils import _ary_container_key_stringifier
9298 return f"_actx_{ prefix } _in_{ _ary_container_key_stringifier (arg_id )} "
9399
94100
95- def _get_output_arg_id_str (arg_id : tuple [Any , ...]) -> str :
101+ def _get_output_arg_id_str (arg_id : tuple [object , ...]) -> str :
96102 from arraycontext .impl .pytato .utils import _ary_container_key_stringifier
97103 return f"_actx_out_{ _ary_container_key_stringifier (arg_id )} "
98104
99105
100106def _get_arg_id_to_placeholder (
101- arg_id_to_arg : Mapping [tuple [Any , ...], Any ],
102- prefix : str | None = None ) -> immutabledict [tuple [Any , ...], pt .Placeholder ]:
107+ arg_id_to_arg : Mapping [tuple [object , ...], object ],
108+ prefix : str | None = None ) -> immutabledict [tuple [object , ...], pt .Placeholder ]:
103109 """
104110 Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
105111 for each argument in *arg_id_to_arg*. See
@@ -115,27 +121,26 @@ def _get_arg_id_to_placeholder(
115121
116122
117123def _call_with_placeholders (
118- f : Callable [..., Any ],
119- args : tuple [Any ],
120- kwargs : Mapping [str , Any ],
121- arg_id_to_placeholder : Mapping [tuple [Any , ...], pt .Placeholder ]) -> Any :
124+ f : Callable [..., object ],
125+ args : tuple [object ],
126+ kwargs : Mapping [str , object ],
127+ arg_id_to_placeholder : Mapping [tuple [object , ...], pt .Placeholder ]) -> object :
122128 """
123129 Construct placeholders analogous to *args* and *kwargs* and call *f*.
124130 """
125131 def get_placeholder_replacement (
126- arg : ArrayOrContainer | None , key : tuple [Any , ...]
127- ) -> ArrayOrContainer | None :
132+ arg : ArrayOrContainerTc | Scalar | None , key : tuple [object , ...]
133+ ) -> ArrayOrContainerTc | Scalar | None :
128134 if arg is None :
129135 return None
130136 elif np .isscalar (arg ):
131- return arg
137+ return cast ( Scalar , arg )
132138 elif isinstance (arg , pt .Array ):
133- return arg_id_to_placeholder [key ]
139+ return cast ( ArrayOrContainerTc , arg_id_to_placeholder [key ])
134140 elif is_array_container_type (arg .__class__ ):
135- def _rec_to_placeholder (keys : tuple [Any , ...], ary : pt . Array ) -> pt . Array :
141+ def _rec_to_placeholder (keys : tuple [object , ...], ary : ArrayT ) -> ArrayT :
136142 result = get_placeholder_replacement (ary , key + keys )
137- assert isinstance (result , pt .Array )
138- return result
143+ return cast (ArrayT , result )
139144
140145 return rec_keyed_map_array_container (_rec_to_placeholder , arg )
141146 else :
@@ -157,7 +162,7 @@ def _unpack_output(
157162 elif is_array_container_type (output .__class__ ):
158163 unpacked_output = {}
159164
160- def _unpack_container (key : tuple [Any , ...], ary : ArrayT ) -> ArrayT :
165+ def _unpack_container (key : tuple [object , ...], ary : ArrayT ) -> ArrayT :
161166 key_str = _get_output_arg_id_str (key )
162167 unpacked_output [key_str ] = ary
163168 return ary
@@ -171,7 +176,7 @@ def _unpack_container(key: tuple[Any, ...], ary: ArrayT) -> ArrayT:
171176
172177def _pack_output (
173178 output_template : ArrayOrContainer ,
174- unpacked_output : pt . Array | immutabledict [str , pt . Array ]
179+ unpacked_output : Array | immutabledict [str , Array ]
175180 ) -> ArrayOrContainer :
176181 """
177182 Pack *unpacked_output* into array containers according to *output_template*.
@@ -182,7 +187,7 @@ def _pack_output(
182187 elif is_array_container_type (output_template .__class__ ):
183188 assert isinstance (unpacked_output , immutabledict )
184189
185- def _pack_into_container (key : tuple [Any , ...], ary : pt . Array ) -> pt . Array :
190+ def _pack_into_container (key : tuple [object , ...], ary : Array ) -> Array :
186191 key_str = _get_output_arg_id_str (key )
187192 return unpacked_output [key_str ]
188193
@@ -194,10 +199,10 @@ def _pack_into_container(key: tuple[Any, ...], ary: pt.Array) -> pt.Array:
194199@dataclass (frozen = True )
195200class OutlinedCall :
196201 actx : _BasePytatoArrayContext
197- f : Callable [..., Any ]
202+ f : Callable [..., object ]
198203 tags : frozenset [Tag ]
199204
200- def __call__ (self , * args : Any , ** kwargs : Any ) -> ArrayOrContainer :
205+ def __call__ (self , * args : object , ** kwargs : object ) -> ArrayOrContainer :
201206 arg_id_to_arg = _get_arg_id_to_arg (args , kwargs )
202207
203208 if __debug__ :
@@ -257,7 +262,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
257262 call_site_output = func_def (** call_bindings )
258263
259264 assert isinstance (call_site_output , pt .Array | immutabledict )
260- return _pack_output (output , call_site_output )
265+ # FIXME: pt.Array is not an actx Array
266+ return _pack_output (cast ("Array | immutabledict[str, Array]" , output ),
267+ cast ("Array | immutabledict[str, Array]" , call_site_output ))
261268
262269
263270# vim: foldmethod=marker
0 commit comments