Skip to content

Commit ce9736c

Browse files
committed
Numpy actx: cache execuctor
1 parent a63c263 commit ce9736c

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

arraycontext/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def to_numpy(self,
339339

340340
@abstractmethod
341341
def call_loopy(self,
342-
program: "loopy.TranslationUnit",
342+
t_unit: "loopy.TranslationUnit",
343343
**kwargs: Any) -> Dict[str, Array]:
344344
"""Execute the :mod:`loopy` program *program* on the arguments
345345
*kwargs*.

arraycontext/impl/numpy/__init__.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
14
"""
25
.. currentmodule:: arraycontext
36
@@ -30,6 +33,7 @@
3033
THE SOFTWARE.
3134
"""
3235

36+
from collections.abc import Mapping
3337
from typing import Any, Dict
3438

3539
import numpy as np
@@ -39,6 +43,7 @@
3943

4044
from arraycontext.container.traversal import rec_map_array_container, with_array_context
4145
from arraycontext.context import (
46+
Array,
4247
ArrayContext,
4348
ArrayOrContainerOrScalar,
4449
ArrayOrContainerOrScalarT,
@@ -62,10 +67,12 @@ class NumpyArrayContext(ArrayContext):
6267
6368
.. automethod:: __init__
6469
"""
70+
71+
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
72+
6573
def __init__(self) -> None:
6674
super().__init__()
67-
self._loopy_transform_cache: \
68-
Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
75+
self._loopy_transform_cache = {}
6976

7077
array_types = (NumpyNonObjectArray,)
7178

@@ -88,17 +95,18 @@ def to_numpy(self,
8895
) -> NumpyOrContainerOrScalar:
8996
return array
9097

91-
def call_loopy(self, t_unit, **kwargs):
98+
def call_loopy(
99+
self,
100+
t_unit: lp.TranslationUnit, **kwargs: Any
101+
) -> dict[str, Array]:
92102
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
93103
try:
94-
t_unit = self._loopy_transform_cache[t_unit]
104+
executor = self._loopy_transform_cache[t_unit]
95105
except KeyError:
96-
orig_t_unit = t_unit
97-
t_unit = self.transform_loopy_program(t_unit)
98-
self._loopy_transform_cache[orig_t_unit] = t_unit
99-
del orig_t_unit
106+
executor = self.transform_loopy_program(t_unit).executor()
107+
self._loopy_transform_cache[t_unit] = executor
100108

101-
_, result = t_unit(**kwargs)
109+
_, result = executor(**kwargs)
102110

103111
return result
104112

0 commit comments

Comments
 (0)