Skip to content

Commit dee0ca4

Browse files
add device<->host transfer mappers (#282)
* implement TransferTo{Host,Device}Mapper * do not support transfer to same type * use {to,from}_numpy * lint fixes * add comment * ruff * disable spurious pylint warning * add to docs * more doc fixes * datawrapper doc * more doc fixes * Improve mapper names * Bring back numpy actx constructor docs --------- Co-authored-by: Andreas Kloeckner <inform@tiker.net>
1 parent a59a9c8 commit dee0ca4

File tree

4 files changed

+156
-7
lines changed

4 files changed

+156
-7
lines changed

arraycontext/impl/numpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

33

4-
"""
4+
__doc__ = """
55
.. currentmodule:: arraycontext
66
7-
A mod :`numpy`-based array context.
7+
A :mod:`numpy`-based array context.
88
99
.. autoclass:: NumpyArrayContext
1010
"""

arraycontext/impl/pytato/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
__doc__ = """
55
.. currentmodule:: arraycontext
66
7-
A :mod:`pytato`-based array context defers the evaluation of an array until its
7+
A :mod:`pytato`-based array context defers the evaluation of an array until it is
88
frozen. The execution contexts for the evaluations are specific to an
9-
:class:`~arraycontext.ArrayContext` type. For ex.
9+
:class:`~arraycontext.ArrayContext` type. For example,
1010
:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to
1111
JIT-compile and execute the array expressions.
1212
13-
Following :mod:`pytato`-based array context are provided:
13+
The following :mod:`pytato`-based array contexts are provided:
1414
1515
.. autoclass:: PytatoPyOpenCLArrayContext
1616
.. autoclass:: PytatoJAXArrayContext
@@ -20,6 +20,12 @@
2020
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2121
2222
.. automodule:: arraycontext.impl.pytato.compile
23+
24+
25+
Utils
26+
^^^^^
27+
28+
.. automodule:: arraycontext.impl.pytato.utils
2329
"""
2430
__copyright__ = """
2531
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -227,7 +233,7 @@ def get_target(self):
227233

228234
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
229235
"""
230-
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
236+
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
231237
the arrays targeting OpenCL for offloading operations.
232238
233239
.. attribute:: queue

arraycontext/impl/pytato/utils.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
__doc__ = """
2+
.. autofunction:: transfer_from_numpy
3+
.. autofunction:: transfer_to_numpy
4+
"""
5+
6+
17
__copyright__ = """
28
Copyright (C) 2021 University of Illinois Board of Trustees
39
"""
@@ -22,6 +28,7 @@
2228
THE SOFTWARE.
2329
"""
2430

31+
2532
from collections.abc import Mapping
2633
from typing import TYPE_CHECKING, Any, cast
2734

@@ -36,9 +43,10 @@
3643
make_placeholder,
3744
)
3845
from pytato.target.loopy import LoopyPyOpenCLTarget
39-
from pytato.transform import CopyMapper
46+
from pytato.transform import ArrayOrNames, CopyMapper
4047
from pytools import UniqueNameGenerator, memoize_method
4148

49+
from arraycontext import ArrayContext
4250
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
4351

4452

@@ -125,4 +133,86 @@ def get_loopy_target(self) -> "lp.PyOpenCLTarget":
125133

126134
# }}}
127135

136+
137+
# {{{ Transfer mappers
138+
139+
class TransferFromNumpyMapper(CopyMapper):
140+
"""A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper`
141+
instances to be device arrays, using
142+
:meth:`~arraycontext.ArrayContext.from_numpy`.
143+
"""
144+
def __init__(self, actx: ArrayContext) -> None:
145+
super().__init__()
146+
self.actx = actx
147+
148+
def map_data_wrapper(self, expr: DataWrapper) -> Array:
149+
import numpy as np
150+
151+
if not isinstance(expr.data, np.ndarray):
152+
raise ValueError("TransferFromNumpyMapper: tried to transfer data that "
153+
"is already on the device")
154+
155+
# Ideally, this code should just do
156+
# return self.actx.from_numpy(expr.data).tagged(expr.tags),
157+
# but there seems to be no way to transfer the non_equality_tags in that case.
158+
new_dw = self.actx.from_numpy(expr.data)
159+
assert isinstance(new_dw, DataWrapper)
160+
161+
# https://github.com/pylint-dev/pylint/issues/3893
162+
# pylint: disable=unexpected-keyword-arg
163+
return DataWrapper(
164+
data=new_dw.data,
165+
shape=expr.shape,
166+
axes=expr.axes,
167+
tags=expr.tags,
168+
non_equality_tags=expr.non_equality_tags)
169+
170+
171+
class TransferToNumpyMapper(CopyMapper):
172+
"""A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper`
173+
instances to be :class:`numpy.ndarray` instances, using
174+
:meth:`~arraycontext.ArrayContext.to_numpy`.
175+
"""
176+
def __init__(self, actx: ArrayContext) -> None:
177+
super().__init__()
178+
self.actx = actx
179+
180+
def map_data_wrapper(self, expr: DataWrapper) -> Array:
181+
import numpy as np
182+
183+
import arraycontext.impl.pyopencl.taggable_cl_array as tga
184+
if not isinstance(expr.data, tga.TaggableCLArray):
185+
raise ValueError("TransferToNumpyMapper: tried to transfer data that "
186+
"is already on the host")
187+
188+
np_data = self.actx.to_numpy(expr.data)
189+
assert isinstance(np_data, np.ndarray)
190+
191+
# https://github.com/pylint-dev/pylint/issues/3893
192+
# pylint: disable=unexpected-keyword-arg
193+
return DataWrapper(
194+
data=np_data,
195+
shape=expr.shape,
196+
axes=expr.axes,
197+
tags=expr.tags,
198+
non_equality_tags=expr.non_equality_tags)
199+
200+
201+
def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
202+
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
203+
instances to be device arrays, using
204+
:meth:`~arraycontext.ArrayContext.from_numpy`.
205+
"""
206+
return TransferFromNumpyMapper(actx)(expr)
207+
208+
209+
def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
210+
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
211+
instances to be :class:`numpy.ndarray` instances, using
212+
:meth:`~arraycontext.ArrayContext.to_numpy`.
213+
"""
214+
return TransferToNumpyMapper(actx)(expr)
215+
216+
# }}}
217+
128218
# vim: foldmethod=marker

test/test_pytato_arraycontext.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,59 @@ def twice(x):
192192
assert res == 198
193193

194194

195+
def test_transfer(actx_factory):
196+
import numpy as np
197+
198+
import pytato as pt
199+
actx = actx_factory()
200+
201+
# {{{ simple tests
202+
203+
a = actx.from_numpy(np.array([0, 1, 2, 3])).tagged(FooTag())
204+
205+
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
206+
assert isinstance(a.data, TaggableCLArray)
207+
208+
from arraycontext.impl.pytato.utils import transfer_from_numpy, transfer_to_numpy
209+
210+
ah = transfer_to_numpy(a, actx)
211+
assert ah != a
212+
assert a.tags == ah.tags
213+
assert a.non_equality_tags == ah.non_equality_tags
214+
assert isinstance(ah.data, np.ndarray)
215+
216+
with pytest.raises(ValueError):
217+
_ahh = transfer_to_numpy(ah, actx)
218+
219+
ad = transfer_from_numpy(ah, actx)
220+
assert isinstance(ad.data, TaggableCLArray)
221+
assert ad != ah
222+
assert ad != a # copied DataWrappers compare unequal
223+
assert ad.tags == ah.tags
224+
assert ad.non_equality_tags == ah.non_equality_tags
225+
assert np.array_equal(a.data.get(), ad.data.get())
226+
227+
with pytest.raises(ValueError):
228+
_add = transfer_from_numpy(ad, actx)
229+
230+
# }}}
231+
232+
# {{{ test with DictOfNamedArrays
233+
234+
dag = pt.make_dict_of_named_arrays({
235+
"a_expr": a + 2
236+
})
237+
238+
dagh = transfer_to_numpy(dag, actx)
239+
assert dagh != dag
240+
assert isinstance(dagh["a_expr"].expr.bindings["_in0"].data, np.ndarray)
241+
242+
daghd = transfer_from_numpy(dagh, actx)
243+
assert isinstance(daghd["a_expr"].expr.bindings["_in0"].data, TaggableCLArray)
244+
245+
# }}}
246+
247+
195248
if __name__ == "__main__":
196249
import sys
197250
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)