Skip to content

Commit fd95813

Browse files
lint
1 parent 9c56443 commit fd95813

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ jobs:
5656
curl -L -O https://tiker.net/ci-support-v0
5757
. ./ci-support-v0
5858
59+
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
60+
echo "- cupy" >> "$CONDA_ENVIRONMENT"
61+
5962
build_py_project_in_conda_env
6063
python -m pip install mypy pytest
6164
./run-mypy.sh

arraycontext/impl/cupy/__init__.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
THE SOFTWARE.
3434
"""
3535

36-
from typing import Any
36+
from typing import Any, overload
37+
38+
import numpy as np
3739

3840
import loopy as lp
3941
from pytools.tag import ToTagSetConvertible
@@ -44,6 +46,7 @@
4446
ArrayContext,
4547
ArrayOrContainerOrScalar,
4648
ArrayOrContainerOrScalarT,
49+
ContainerOrScalarT,
4750
NumpyOrContainerOrScalar,
4851
UntransformedCodeWarning,
4952
)
@@ -83,12 +86,28 @@ def _get_fake_numpy_namespace(self):
8386
def clone(self):
8487
return type(self)()
8588

89+
@overload
90+
def from_numpy(self, array: np.ndarray) -> Array:
91+
...
92+
93+
@overload
94+
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
95+
...
96+
8697
def from_numpy(self,
8798
array: NumpyOrContainerOrScalar
8899
) -> ArrayOrContainerOrScalar:
89100
import cupy as cp
90101
return cp.array(array)
91102

103+
@overload
104+
def to_numpy(self, array: Array) -> np.ndarray:
105+
...
106+
107+
@overload
108+
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
109+
...
110+
92111
def to_numpy(self,
93112
array: ArrayOrContainerOrScalar
94113
) -> NumpyOrContainerOrScalar:

arraycontext/impl/cupy/fake_numpy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
14
__copyright__ = """
25
Copyright (C) 2024 University of Illinois Board of Trustees
36
"""
@@ -164,7 +167,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
164167
[(true_ary if kx_i == ky_i else false_ary)
165168
and self.array_equal(x_i, y_i)
166169
for (kx_i, x_i), (ky_i, y_i)
167-
in zip(serialized_x, serialized_y)],
170+
in zip(serialized_x, serialized_y, strict=True)],
168171
true_ary)
169172

170173
def arange(self, *args, **kwargs):
@@ -176,14 +179,14 @@ def linspace(self, *args, **kwargs):
176179
return cp.linspace(*args, **kwargs)
177180

178181
def zeros_like(self, ary):
179-
if isinstance(ary, (int, float, complex)):
182+
if isinstance(ary, int | float | complex):
180183
import cupy as cp
181184
# Cupy does not support zeros_like with scalar arguments
182185
ary = cp.array(ary)
183186
return rec_map_array_container(cp.zeros_like, ary)
184187

185188
def ones_like(self, ary):
186-
if isinstance(ary, (int, float, complex)):
189+
if isinstance(ary, int | float | complex):
187190
import cupy as cp
188191
# Cupy does not support ones_like with scalar arguments
189192
ary = cp.array(ary)

0 commit comments

Comments
 (0)