Skip to content

Commit 3c9ec8c

Browse files
committed
Implement actx.np.zeros
1 parent 0e3f321 commit 3c9ec8c

File tree

8 files changed

+40
-19
lines changed

8 files changed

+40
-19
lines changed

arraycontext/context.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
TypeVar,
172172
Union,
173173
)
174+
from warnings import warn
174175

175176
import numpy as np
176177

@@ -249,10 +250,6 @@ class ArrayContext(ABC):
249250
250251
.. versionadded:: 2020.2
251252
252-
.. automethod:: empty
253-
.. automethod:: zeros
254-
.. automethod:: empty_like
255-
.. automethod:: zeros_like
256253
.. automethod:: from_numpy
257254
.. automethod:: to_numpy
258255
.. automethod:: call_loopy
@@ -293,9 +290,9 @@ class ArrayContext(ABC):
293290
def __init__(self) -> None:
294291
self.np = self._get_fake_numpy_namespace()
295292

293+
@abstractmethod
296294
def _get_fake_numpy_namespace(self) -> Any:
297-
from .fake_numpy import BaseFakeNumpyNamespace
298-
return BaseFakeNumpyNamespace(self)
295+
...
299296

300297
def __hash__(self) -> int:
301298
raise TypeError(f"unhashable type: '{type(self).__name__}'")
@@ -306,22 +303,23 @@ def empty(self,
306303
dtype: "np.dtype[Any]") -> Array:
307304
pass
308305

309-
@abstractmethod
310306
def zeros(self,
311307
shape: Union[int, Tuple[int, ...]],
312308
dtype: "np.dtype[Any]") -> Array:
313-
pass
309+
warn(f"{type(self).__name__}.zeros is deprecated and will stop "
310+
"working in 2025. Use actx.np.zeros instead.",
311+
DeprecationWarning, stacklevel=2)
312+
313+
return self.np.zeros(shape, dtype)
314314

315315
def empty_like(self, ary: Array) -> Array:
316-
from warnings import warn
317316
warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
318317
"working in 2023. Prefer actx.np.zeros_like instead.",
319318
DeprecationWarning, stacklevel=2)
320319

321320
return self.empty(shape=ary.shape, dtype=ary.dtype)
322321

323322
def zeros_like(self, ary: Array) -> Array:
324-
from warnings import warn
325323
warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
326324
"working in 2023. Use actx.np.zeros_like instead.",
327325
DeprecationWarning, stacklevel=2)

arraycontext/fake_numpy.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
import operator
27+
from abc import ABC, abstractmethod
2728
from typing import Any
2829

2930
import numpy as np
@@ -34,7 +35,7 @@
3435

3536
# {{{ BaseFakeNumpyNamespace
3637

37-
class BaseFakeNumpyNamespace:
38+
class BaseFakeNumpyNamespace(ABC):
3839
def __init__(self, array_context):
3940
self._array_context = array_context
4041
self.linalg = self._get_fake_numpy_linalg_namespace()
@@ -95,6 +96,14 @@ def _get_fake_numpy_linalg_namespace(self):
9596
# "interp",
9697
})
9798

99+
@abstractmethod
100+
def zeros(self, shape, dtype):
101+
...
102+
103+
@abstractmethod
104+
def zeros_like(self, ary):
105+
...
106+
98107
def conjugate(self, x):
99108
# NOTE: conjugate distributes over object arrays, but it looks for a
100109
# `conjugate` ufunc, while some implementations only have the shorter

arraycontext/impl/jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _wrapper(ary):
9090
def empty(self, shape, dtype):
9191
from warnings import warn
9292
warn(f"{type(self).__name__}.empty is deprecated and will stop "
93-
"working in 2023. Prefer actx.zeros instead.",
93+
"working in 2023. Prefer actx.np.zeros instead.",
9494
DeprecationWarning, stacklevel=2)
9595

9696
import jax.numpy as jnp

arraycontext/impl/jax/fake_numpy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def __getattr__(self, name):
5656

5757
# {{{ array creation routines
5858

59+
def zeros(self, shape, dtype):
60+
return jnp.zeros(shape=shape, dtype=dtype)
61+
5962
def empty_like(self, ary):
6063
from warnings import warn
6164
warn(f"{type(self._array_context).__name__}.np.empty_like is "

arraycontext/impl/pyopencl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _wrapper(ary):
201201
def empty(self, shape, dtype):
202202
from warnings import warn
203203
warn(f"{type(self).__name__}.empty is deprecated and will stop "
204-
"working in 2023. Prefer actx.zeros instead.",
204+
"working in 2023. Prefer actx.np.zeros instead.",
205205
DeprecationWarning, stacklevel=2)
206206

207207
import arraycontext.impl.pyopencl.taggable_cl_array as tga

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,15 @@
3939
rec_multimap_reduce_array_container,
4040
)
4141
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
42+
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
4243
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
4344

4445

4546
try:
4647
import pyopencl as cl # noqa: F401
4748
import pyopencl.array as cl_array
49+
50+
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
4851
except ImportError:
4952
pass
5053

@@ -60,6 +63,11 @@ def _get_fake_numpy_linalg_namespace(self):
6063

6164
# {{{ array creation routines
6265

66+
def zeros(self, shape, dtype) -> TaggableCLArray:
67+
import arraycontext.impl.pyopencl.taggable_cl_array as tga
68+
return tga.zeros(self._array_context.queue, shape, dtype,
69+
allocator=self._array_context.allocator)
70+
6371
def empty_like(self, ary):
6472
from warnings import warn
6573
warn(f"{type(self._array_context).__name__}.np.empty_like is "

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def __getattr__(self, name):
8484

8585
# {{{ array creation routines
8686

87+
def zeros(self, shape, dtype):
88+
return pt.zeros(shape, dtype)
89+
8790
def zeros_like(self, ary):
8891
def _zeros_like(array):
8992
return self._array_context.zeros(

test/test_arraycontext.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,7 +1367,7 @@ def test_leaf_array_type_broadcasting(actx_factory):
13671367
# test support for https://github.com/inducer/arraycontext/issues/49
13681368
actx = actx_factory()
13691369

1370-
foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, )))
1370+
foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )))
13711371
bar = foo + 4
13721372
baz = foo + actx.from_numpy(4*np.ones((3, )))
13731373
qux = actx.from_numpy(4*np.ones((3, ))) + foo
@@ -1510,7 +1510,7 @@ def _twice(x):
15101510

15111511
actx = actx_factory()
15121512
ones = actx.thaw(actx.freeze(
1513-
actx.zeros(shape=(10, 4), dtype=np.float64) + 1
1513+
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
15141514
))
15151515
np.testing.assert_allclose(actx.to_numpy(_twice(ones)),
15161516
actx.to_numpy(actx.compile(_twice)(ones)))
@@ -1573,7 +1573,7 @@ def test_taggable_cl_array_tags(actx_factory):
15731573
def test_to_numpy_on_frozen_arrays(actx_factory):
15741574
# See https://github.com/inducer/arraycontext/issues/159
15751575
actx = actx_factory()
1576-
u = actx.freeze(actx.zeros(10, dtype="float64")+1)
1576+
u = actx.freeze(actx.np.zeros(10, dtype="float64")+1)
15771577
np.testing.assert_allclose(actx.to_numpy(u), 1)
15781578
np.testing.assert_allclose(actx.to_numpy(u), 1)
15791579

@@ -1592,7 +1592,7 @@ class ExampleTag(Tag):
15921592
ary = tag_axes(actx, {0: ExampleTag()},
15931593
actx.tag(
15941594
ExampleTag(),
1595-
actx.zeros((20, 20), dtype=np.float64)))
1595+
actx.np.zeros((20, 20), dtype=np.float64)))
15961596

15971597
assert ary.tags_of_type(ExampleTag)
15981598
assert ary.axes[0].tags_of_type(ExampleTag)
@@ -1606,11 +1606,11 @@ def test_compile_anonymous_function(actx_factory):
16061606
actx = actx_factory()
16071607
f = actx.compile(lambda x: 2*x+40)
16081608
np.testing.assert_allclose(
1609-
actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
1609+
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
16101610
42)
16111611
f = actx.compile(partial(lambda x: 2*x+40))
16121612
np.testing.assert_allclose(
1613-
actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
1613+
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
16141614
42)
16151615

16161616

0 commit comments

Comments
 (0)