3737import logging
3838from collections .abc import Callable , Hashable , Mapping
3939from dataclasses import dataclass , field
40- from typing import Any
40+ from typing import Any , overload
4141
4242import numpy as np
4343from immutabledict import immutabledict
4444
45+ import pyopencl .array as cla
4546import pytato as pt
4647from pytools import ProcessLogger , to_identifier
4748from pytools .tag import Tag
4849
4950from arraycontext .container import ArrayContainer , is_array_container_type
5051from arraycontext .container .traversal import rec_keyed_map_array_container
51- from arraycontext .context import ArrayT
52+ from arraycontext .impl .pyopencl .taggable_cl_array import (
53+ TaggableCLArray ,
54+ )
5255from arraycontext .impl .pytato import (
5356 PytatoJAXArrayContext ,
5457 PytatoPyOpenCLArrayContext ,
@@ -110,7 +113,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
110113
111114# {{{ utilities
112115
113- def _ary_container_key_stringifier (keys : tuple [Any , ...]) -> str :
116+ def _ary_container_key_stringifier (keys : tuple [object , ...]) -> str :
114117 """
115118 Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
116119 array-container's component's key. Goals of this routine:
@@ -119,12 +122,12 @@ def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str:
119122 * Stringified key must a valid identifier according to :meth:`str.isidentifier`
120123 * (informal) Shorter identifiers are preferred
121124 """
122- def _rec_str (key : Any ) -> str :
125+ def _rec_str (key : object ) -> str :
123126 if isinstance (key , str | int ):
124127 return str (key )
125128 elif isinstance (key , tuple ):
126129 # t in '_actx_t': stands for tuple
127- return "_actx_t" + "_" .join (_rec_str (k ) for k in key ) + "_actx_endt"
130+ return "_actx_t" + "_" .join (_rec_str (k ) for k in key ) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
128131 else :
129132 raise NotImplementedError ("Key-stringication unimplemented for "
130133 f"'{ type (key ).__name__ } '." )
@@ -175,7 +178,28 @@ def id_collector(keys, ary):
175178 return immutabledict (arg_id_to_arg ), immutabledict (arg_id_to_descr )
176179
177180
178- def _to_input_for_compiled (ary : ArrayT , actx : PytatoPyOpenCLArrayContext ):
181+ @overload
182+ def _to_input_for_compiled (
183+ ary : pt .Array , actx : PytatoPyOpenCLArrayContext ) -> pt .Array :
184+ ...
185+
186+
187+ @overload
188+ def _to_input_for_compiled (
189+ ary : TaggableCLArray , actx : PytatoPyOpenCLArrayContext ) -> TaggableCLArray :
190+ ...
191+
192+
193+ @overload
194+ def _to_input_for_compiled (
195+ ary : cla .Array , actx : PytatoPyOpenCLArrayContext
196+ ) -> cla .Array :
197+ ...
198+
199+
200+ def _to_input_for_compiled (
201+ ary : pt .Array | TaggableCLArray | cla .Array ,
202+ actx : PytatoPyOpenCLArrayContext ) -> pt .Array | TaggableCLArray | cla .Array :
179203 """
180204 Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder`
181205 in :meth:`LazilyCompilingFunctionCaller.__call__`.
@@ -185,19 +209,14 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
185209 - Metadata Inference that is supplied via *actx*\' s
186210 :meth:`PytatoPyOpenCLArrayContext.transform_dag`.
187211 """
188- import pyopencl .array as cla
189-
190- from arraycontext .impl .pyopencl .taggable_cl_array import (
191- TaggableCLArray ,
192- to_tagged_cl_array ,
193- )
212+ from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
194213 if isinstance (ary , pt .Array ):
195214 dag = pt .make_dict_of_named_arrays ({"_actx_out" : ary })
196215 # Transform the DAG to give metadata inference a chance to do its job
197216 return actx .transform_dag (dag )["_actx_out" ].expr
198217 elif isinstance (ary , TaggableCLArray ):
199218 return ary
200- elif isinstance ( ary , cla . Array ) :
219+ else :
201220 from warnings import warn
202221 warn ("Passing pyopencl.array.Array to a compiled callable"
203222 " is deprecated and will stop working in 2023."
@@ -207,8 +226,6 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
207226 return to_tagged_cl_array (ary ,
208227 axes = None ,
209228 tags = frozenset ())
210- else :
211- raise NotImplementedError (type (ary ))
212229
213230
214231def _get_f_placeholder_args (arg , kw , arg_id_to_name , actx ):
@@ -230,7 +247,7 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
230247 axes = arg .axes ,
231248 tags = arg .tags )
232249 elif is_array_container_type (arg .__class__ ):
233- def _rec_to_placeholder (keys , ary ):
250+ def _rec_to_placeholder (keys , ary : pt . Array ):
234251 index = (kw , * keys )
235252 name = arg_id_to_name [index ]
236253 # Transform the DAG to give metadata inference a chance to do its job
0 commit comments