Skip to content

Commit 33c5d9b

Browse files
inducermajosm
authored andcommitted
Towards typing of outlining
1 parent ed09753 commit 33c5d9b

File tree

1 file changed

+35
-28
lines changed

1 file changed

+35
-28
lines changed

arraycontext/impl/pytato/outline.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,36 @@
3131
import itertools
3232
from collections.abc import Callable, Mapping
3333
from dataclasses import dataclass
34-
from typing import Any
34+
from typing import cast
3535

3636
import numpy as np
3737
from immutabledict import immutabledict
3838

3939
import pytato as pt
40+
from pymbolic import Scalar
4041
from pytools.tag import Tag
4142

4243
from arraycontext.container import is_array_container_type
4344
from arraycontext.container.traversal import rec_keyed_map_array_container
44-
from arraycontext.context import ArrayOrContainer, ArrayT
45+
from arraycontext.context import (
46+
Array,
47+
ArrayOrContainer,
48+
ArrayOrContainerTc,
49+
ArrayT,
50+
)
4551
from arraycontext.impl.pytato import _BasePytatoArrayContext
4652

4753

48-
def _get_arg_id_to_arg(args: tuple[Any, ...],
49-
kwargs: Mapping[str, Any]
50-
) -> immutabledict[tuple[Any, ...], Any]:
54+
def _get_arg_id_to_arg(args: tuple[object, ...],
55+
kwargs: Mapping[str, object]
56+
) -> immutabledict[tuple[object, ...], object]:
5157
"""
5258
Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
5359
to argument values. See
5460
:attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
5561
representation.
5662
"""
57-
arg_id_to_arg: dict[tuple[Any, ...], Any] = {}
63+
arg_id_to_arg: dict[tuple[object, ...], object] = {}
5864

5965
for kw, arg in itertools.chain(enumerate(args),
6066
kwargs.items()):
@@ -64,7 +70,7 @@ def _get_arg_id_to_arg(args: tuple[Any, ...],
6470
# do not make scalars as placeholders since we inline them.
6571
pass
6672
elif is_array_container_type(arg.__class__):
67-
def id_collector(keys: tuple[Any, ...], ary: ArrayT) -> ArrayT:
73+
def id_collector(keys: tuple[object, ...], ary: ArrayT) -> ArrayT:
6874
if np.isscalar(ary):
6975
pass
7076
else:
@@ -85,21 +91,21 @@ def id_collector(keys: tuple[Any, ...], ary: ArrayT) -> ArrayT:
8591

8692

8793
def _get_input_arg_id_str(
88-
arg_id: tuple[Any, ...], prefix: str | None = None) -> str:
94+
arg_id: tuple[object, ...], prefix: str | None = None) -> str:
8995
if prefix is None:
9096
prefix = ""
9197
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier
9298
return f"_actx_{prefix}_in_{_ary_container_key_stringifier(arg_id)}"
9399

94100

95-
def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str:
101+
def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
96102
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier
97103
return f"_actx_out_{_ary_container_key_stringifier(arg_id)}"
98104

99105

100106
def _get_arg_id_to_placeholder(
101-
arg_id_to_arg: Mapping[tuple[Any, ...], Any],
102-
prefix: str | None = None) -> immutabledict[tuple[Any, ...], pt.Placeholder]:
107+
arg_id_to_arg: Mapping[tuple[object, ...], object],
108+
prefix: str | None = None) -> immutabledict[tuple[object, ...], pt.Placeholder]:
103109
"""
104110
Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
105111
for each argument in *arg_id_to_arg*. See
@@ -115,27 +121,26 @@ def _get_arg_id_to_placeholder(
115121

116122

117123
def _call_with_placeholders(
118-
f: Callable[..., Any],
119-
args: tuple[Any],
120-
kwargs: Mapping[str, Any],
121-
arg_id_to_placeholder: Mapping[tuple[Any, ...], pt.Placeholder]) -> Any:
124+
f: Callable[..., object],
125+
args: tuple[object],
126+
kwargs: Mapping[str, object],
127+
arg_id_to_placeholder: Mapping[tuple[object, ...], pt.Placeholder]) -> object:
122128
"""
123129
Construct placeholders analogous to *args* and *kwargs* and call *f*.
124130
"""
125131
def get_placeholder_replacement(
126-
arg: ArrayOrContainer | None, key: tuple[Any, ...]
127-
) -> ArrayOrContainer | None:
132+
arg: ArrayOrContainerTc | Scalar | None, key: tuple[object, ...]
133+
) -> ArrayOrContainerTc | Scalar | None:
128134
if arg is None:
129135
return None
130136
elif np.isscalar(arg):
131-
return arg
137+
return cast(Scalar, arg)
132138
elif isinstance(arg, pt.Array):
133-
return arg_id_to_placeholder[key]
139+
return cast(ArrayOrContainerTc, arg_id_to_placeholder[key])
134140
elif is_array_container_type(arg.__class__):
135-
def _rec_to_placeholder(keys: tuple[Any, ...], ary: pt.Array) -> pt.Array:
141+
def _rec_to_placeholder(keys: tuple[object, ...], ary: ArrayT) -> ArrayT:
136142
result = get_placeholder_replacement(ary, key + keys)
137-
assert isinstance(result, pt.Array)
138-
return result
143+
return cast(ArrayT, result)
139144

140145
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
141146
else:
@@ -157,7 +162,7 @@ def _unpack_output(
157162
elif is_array_container_type(output.__class__):
158163
unpacked_output = {}
159164

160-
def _unpack_container(key: tuple[Any, ...], ary: ArrayT) -> ArrayT:
165+
def _unpack_container(key: tuple[object, ...], ary: ArrayT) -> ArrayT:
161166
key_str = _get_output_arg_id_str(key)
162167
unpacked_output[key_str] = ary
163168
return ary
@@ -171,7 +176,7 @@ def _unpack_container(key: tuple[Any, ...], ary: ArrayT) -> ArrayT:
171176

172177
def _pack_output(
173178
output_template: ArrayOrContainer,
174-
unpacked_output: pt.Array | immutabledict[str, pt.Array]
179+
unpacked_output: Array | immutabledict[str, Array]
175180
) -> ArrayOrContainer:
176181
"""
177182
Pack *unpacked_output* into array containers according to *output_template*.
@@ -182,7 +187,7 @@ def _pack_output(
182187
elif is_array_container_type(output_template.__class__):
183188
assert isinstance(unpacked_output, immutabledict)
184189

185-
def _pack_into_container(key: tuple[Any, ...], ary: pt.Array) -> pt.Array:
190+
def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array:
186191
key_str = _get_output_arg_id_str(key)
187192
return unpacked_output[key_str]
188193

@@ -194,10 +199,10 @@ def _pack_into_container(key: tuple[Any, ...], ary: pt.Array) -> pt.Array:
194199
@dataclass(frozen=True)
195200
class OutlinedCall:
196201
actx: _BasePytatoArrayContext
197-
f: Callable[..., Any]
202+
f: Callable[..., object]
198203
tags: frozenset[Tag]
199204

200-
def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
205+
def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
201206
arg_id_to_arg = _get_arg_id_to_arg(args, kwargs)
202207

203208
if __debug__:
@@ -257,7 +262,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
257262
call_site_output = func_def(**call_bindings)
258263

259264
assert isinstance(call_site_output, pt.Array | immutabledict)
260-
return _pack_output(output, call_site_output)
265+
# FIXME: pt.Array is not an actx Array
266+
return _pack_output(cast("Array | immutabledict[str, Array]", output),
267+
cast("Array | immutabledict[str, Array]", call_site_output))
261268

262269

263270
# vim: foldmethod=marker

0 commit comments

Comments
 (0)