4545from arraycontext .context import (
4646 Array ,
4747 ArrayOrContainer ,
48- ArrayOrContainerTc ,
4948 ArrayT ,
5049)
5150from arraycontext .impl .pytato import _BasePytatoArrayContext
5251
5352
5453def _get_arg_id_to_arg (args : tuple [object , ...],
5554 kwargs : Mapping [str , object ]
56- ) -> immutabledict [tuple [object , ...], object ]:
55+ ) -> immutabledict [tuple [object , ...], pt . Array ]:
5756 """
5857 Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
5958 to argument values. See
@@ -104,7 +103,7 @@ def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
104103
105104
106105def _get_arg_id_to_placeholder (
107- arg_id_to_arg : Mapping [tuple [object , ...], object ],
106+ arg_id_to_arg : Mapping [tuple [object , ...], pt . Array ],
108107 prefix : str | None = None ) -> immutabledict [tuple [object , ...], pt .Placeholder ]:
109108 """
110109 Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
@@ -122,25 +121,25 @@ def _get_arg_id_to_placeholder(
122121
123122def _call_with_placeholders (
124123 f : Callable [..., object ],
125- args : tuple [object ],
124+ args : tuple [object , ... ],
126125 kwargs : Mapping [str , object ],
127126 arg_id_to_placeholder : Mapping [tuple [object , ...], pt .Placeholder ]) -> object :
128127 """
129128 Construct placeholders analogous to *args* and *kwargs* and call *f*.
130129 """
131130 def get_placeholder_replacement (
132- arg : ArrayOrContainerTc | Scalar | None , key : tuple [object , ...]
133- ) -> ArrayOrContainerTc | Scalar | None :
131+ arg : ArrayOrContainer | Scalar | None , key : tuple [object , ...]
132+ ) -> ArrayOrContainer | Scalar | None :
134133 if arg is None :
135134 return None
136135 elif np .isscalar (arg ):
137136 return cast (Scalar , arg )
138137 elif isinstance (arg , pt .Array ):
139- return cast ( ArrayOrContainerTc , arg_id_to_placeholder [key ])
138+ return arg_id_to_placeholder [key ]
140139 elif is_array_container_type (arg .__class__ ):
141- def _rec_to_placeholder (keys : tuple [ object , ...], ary : ArrayT ) -> ArrayT :
142- result = get_placeholder_replacement ( ary , key + keys )
143- return cast (ArrayT , result )
140+ def _rec_to_placeholder (
141+ keys : tuple [ object , ...], ary : Array ) -> Array :
142+ return cast ("Array" , get_placeholder_replacement ( ary , key + keys ) )
144143
145144 return rec_keyed_map_array_container (_rec_to_placeholder , arg )
146145 else :
@@ -176,7 +175,7 @@ def _unpack_container(key: tuple[object, ...], ary: ArrayT) -> ArrayT:
176175
177176def _pack_output (
178177 output_template : ArrayOrContainer ,
179- unpacked_output : Array | immutabledict [str , Array ]
178+ unpacked_output : pt . Array | immutabledict [str , pt . Array ]
180179 ) -> ArrayOrContainer :
181180 """
182181 Pack *unpacked_output* into array containers according to *output_template*.
@@ -187,9 +186,9 @@ def _pack_output(
187186 elif is_array_container_type (output_template .__class__ ):
188187 assert isinstance (unpacked_output , immutabledict )
189188
190- def _pack_into_container (key : tuple [object , ...], ary : Array ) -> Array :
189+ def _pack_into_container (key : tuple [object , ...], ary : Array ) -> Array : # pyright: ignore[reportUnusedParameter]
191190 key_str = _get_output_arg_id_str (key )
192- return unpacked_output [key_str ]
191+ return unpacked_output [key_str ] # type: ignore[index]
193192
194193 return rec_keyed_map_array_container (_pack_into_container , output_template )
195194 else :
@@ -262,9 +261,6 @@ def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
262261 call_site_output = func_def (** call_bindings )
263262
264263 assert isinstance (call_site_output , pt .Array | immutabledict )
265- # FIXME: pt.Array is not an actx Array
266- return _pack_output (cast ("Array | immutabledict[str, Array]" , output ),
267- cast ("Array | immutabledict[str, Array]" , call_site_output ))
268-
264+ return _pack_output (output , call_site_output )
269265
270266# vim: foldmethod=marker
0 commit comments