1+ from __future__ import annotations
2+
3+
14"""
25.. currentmodule:: arraycontext
36
3033THE SOFTWARE.
3134"""
3235
36+ from collections .abc import Mapping
3337from typing import Any , Dict
3438
3539import numpy as np
3943
4044from arraycontext .container .traversal import rec_map_array_container , with_array_context
4145from arraycontext .context import (
46+ Array ,
4247 ArrayContext ,
4348 ArrayOrContainerOrScalar ,
4449 ArrayOrContainerOrScalarT ,
@@ -62,10 +67,12 @@ class NumpyArrayContext(ArrayContext):
6267
6368 .. automethod:: __init__
6469 """
70+
71+ _loopy_transform_cache : dict [lp .TranslationUnit , lp .ExecutorBase ]
72+
6573 def __init__ (self ) -> None :
6674 super ().__init__ ()
67- self ._loopy_transform_cache : \
68- Dict [lp .TranslationUnit , lp .TranslationUnit ] = {}
75+ self ._loopy_transform_cache = {}
6976
7077 array_types = (NumpyNonObjectArray ,)
7178
@@ -88,17 +95,18 @@ def to_numpy(self,
8895 ) -> NumpyOrContainerOrScalar :
8996 return array
9097
91- def call_loopy (self , t_unit , ** kwargs ):
98+ def call_loopy (
99+ self ,
100+ t_unit : lp .TranslationUnit , ** kwargs : Any
101+ ) -> dict [str , Array ]:
92102 t_unit = t_unit .copy (target = lp .ExecutableCTarget ())
93103 try :
94- t_unit = self ._loopy_transform_cache [t_unit ]
104+ executor = self ._loopy_transform_cache [t_unit ]
95105 except KeyError :
96- orig_t_unit = t_unit
97- t_unit = self .transform_loopy_program (t_unit )
98- self ._loopy_transform_cache [orig_t_unit ] = t_unit
99- del orig_t_unit
106+ executor = self .transform_loopy_program (t_unit ).executor ()
107+ self ._loopy_transform_cache [t_unit ] = executor
100108
101- _ , result = t_unit (** kwargs )
109+ _ , result = executor (** kwargs )
102110
103111 return result
104112
0 commit comments