Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,17 @@ on:
- cron: '17 3 * * 0'

jobs:
flake8:
name: Flake8
ruff:
name: Ruff
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
-
uses: actions/setup-python@v5
with:
# matches compat target in setup.py
python-version: '3.8'
- name: "Main Script"
run: |
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
. ./prepare-and-run-flake8.sh "$(basename $GITHUB_REPOSITORY)" test examples
pip install ruff
ruff check

pylint:
name: Pylint
Expand Down
8 changes: 4 additions & 4 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ Documentation:
tags:
- python3

Flake8:
Ruff:
script:
- curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
- . ./prepare-and-run-flake8.sh "$CI_PROJECT_NAME" test examples
- pipx install ruff
- ruff check
tags:
- python3
- docker-runner
except:
- tags

Expand Down
208 changes: 118 additions & 90 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,86 +28,126 @@
THE SOFTWARE.
"""

import sys

from .container import (
ArrayContainer, ArrayContainerT, NotAnArrayContainerError, deserialize_container,
get_container_context_opt, get_container_context_recursively,
get_container_context_recursively_opt, is_array_container,
is_array_container_type, register_multivector_as_array_container,
serialize_container)
ArrayContainer,
ArrayContainerT,
NotAnArrayContainerError,
deserialize_container,
get_container_context_opt,
get_container_context_recursively,
get_container_context_recursively_opt,
is_array_container,
is_array_container_type,
register_multivector_as_array_container,
serialize_container,
)
from .container.arithmetic import with_container_arithmetic
from .container.dataclass import dataclass_array_container
from .container.traversal import (
flat_size_and_dtype, flatten, freeze, from_numpy, map_array_container,
map_reduce_array_container, mapped_over_array_containers,
multimap_array_container, multimap_reduce_array_container,
multimapped_over_array_containers, outer, rec_map_array_container,
rec_map_reduce_array_container, rec_multimap_array_container,
rec_multimap_reduce_array_container, stringify_array_container_tree, thaw,
to_numpy, unflatten, with_array_context)
flat_size_and_dtype,
flatten,
freeze,
from_numpy,
map_array_container,
map_reduce_array_container,
mapped_over_array_containers,
multimap_array_container,
multimap_reduce_array_container,
multimapped_over_array_containers,
outer,
rec_map_array_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
rec_multimap_reduce_array_container,
stringify_array_container_tree,
thaw,
to_numpy,
unflatten,
with_array_context,
)
from .context import (
Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike,
tag_axes)
Array,
ArrayContext,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrContainerT,
ArrayT,
Scalar,
ScalarLike,
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .loopy import make_loopy_program
# deprecated, remove in 2022.
from .metadata import _FirstAxisIsElementsTag
from .pytest import (
PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory,
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
pytest_generate_tests_for_pyopencl_array_context)
pytest_generate_tests_for_pyopencl_array_context,
)
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag


__all__ = (
"ArrayContext", "Scalar", "Array",
"Scalar", "ScalarLike",
"Array", "ArrayT",
"ArrayOrContainer", "ArrayOrContainerT",
"ArrayOrContainerOrScalar", "ArrayOrContainerOrScalarT",
"tag_axes",

"CommonSubexpressionTag",
"ElementwiseMapKernelTag",

"ArrayContainer", "ArrayContainerT",
"NotAnArrayContainerError",
"is_array_container", "is_array_container_type",
"get_container_context_opt",
"get_container_context_recursively_opt",
"get_container_context_recursively",
"serialize_container", "deserialize_container",
"register_multivector_as_array_container",
"with_container_arithmetic",
"dataclass_array_container",

"stringify_array_container_tree",
"map_array_container", "multimap_array_container",
"rec_map_array_container", "rec_multimap_array_container",
"mapped_over_array_containers",
"multimapped_over_array_containers",
"map_reduce_array_container", "multimap_reduce_array_container",
"rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
"thaw", "freeze",
"flatten", "unflatten", "flat_size_and_dtype",
"from_numpy", "to_numpy", "with_array_context",
"outer",

"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
"PytatoJAXArrayContext",
"EagerJAXArrayContext",

"make_loopy_program",

"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"pytest_generate_tests_for_array_contexts",
"pytest_generate_tests_for_pyopencl_array_context"
)
"Array",
"Array",
"ArrayContainer",
"ArrayContainerT",
"ArrayContext",
"ArrayOrContainer",
"ArrayOrContainerOrScalar",
"ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"ArrayT",
"CommonSubexpressionTag",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
"NotAnArrayContainerError",
"PyOpenCLArrayContext",
"PytatoJAXArrayContext",
"PytatoPyOpenCLArrayContext",
"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"Scalar",
"Scalar",
"ScalarLike",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
"flatten",
"freeze",
"from_numpy",
"get_container_context_opt",
"get_container_context_recursively",
"get_container_context_recursively_opt",
"is_array_container",
"is_array_container_type",
"make_loopy_program",
"map_array_container",
"map_reduce_array_container",
"mapped_over_array_containers",
"multimap_array_container",
"multimap_reduce_array_container",
"multimapped_over_array_containers",
"outer",
"pytest_generate_tests_for_array_contexts",
"pytest_generate_tests_for_pyopencl_array_context",
"rec_map_array_container",
"rec_map_reduce_array_container",
"rec_multimap_array_container",
"rec_multimap_reduce_array_container",
"register_multivector_as_array_container",
"serialize_container",
"stringify_array_container_tree",
"tag_axes",
"thaw",
"to_numpy",
"unflatten",
"with_array_context",
"with_container_arithmetic"
)


# {{{ deprecation handling
Expand All @@ -127,33 +167,21 @@ def _deprecated_acf():
"get_container_context": (
"get_container_context_opt",
get_container_context_opt, 2022),
"FirstAxisIsElementsTag": (
"meshmode.transform_metadata.FirstAxisIsElementsTag",
_FirstAxisIsElementsTag, 2022),
"_acf": ("<no replacement yet>", _deprecated_acf, 2022),
"DeviceArray": ("Array", Array, 2023),
"DeviceScalar": ("Scalar", Scalar, 2023),
}

if sys.version_info >= (3, 7):
def __getattr__(name):
replacement_and_obj = _depr_name_to_replacement_and_obj.get(name, None)
if replacement_and_obj is not None:
replacement, obj, year = replacement_and_obj
from warnings import warn
warn(f"'arraycontext.{name}' is deprecated. "
f"Use '{replacement}' instead. "
f"'arraycontext.{name}' will continue to work until {year}.",
DeprecationWarning, stacklevel=2)
return obj
else:
raise AttributeError(name)
else:
FirstAxisIsElementsTag = _FirstAxisIsElementsTag
_acf = _deprecated_acf
get_container_context = get_container_context_opt
DeviceArray = Array
DeviceScalar = Scalar

def __getattr__(name):
replacement_and_obj = _depr_name_to_replacement_and_obj.get(name, None)
if replacement_and_obj is not None:
replacement, obj, year = replacement_and_obj
from warnings import warn
warn(f"'arraycontext.{name}' is deprecated. "
f"Use '{replacement}' instead. "
f"'arraycontext.{name}' will continue to work until {year}.",
DeprecationWarning, stacklevel=2)
return obj
else:
raise AttributeError(name)

# }}}

Expand Down
8 changes: 4 additions & 4 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,17 @@ def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayCont
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
# (Though clearly there exists a dependency via loopy.)

def _serialize_multivec_as_container(mv: "MultiVector") -> Iterable[Tuple[Any, Any]]:
def _serialize_multivec_as_container(mv: MultiVector) -> Iterable[Tuple[Any, Any]]:
return list(mv.data.items())


def _deserialize_multivec_as_container(template: "MultiVector",
iterable: Iterable[Tuple[Any, Any]]) -> "MultiVector":
def _deserialize_multivec_as_container(template: MultiVector,
iterable: Iterable[Tuple[Any, Any]]) -> MultiVector:
from pymbolic.geometric_algebra import MultiVector
return MultiVector(dict(iterable), space=template.space)


def _get_container_context_opt_from_multivec(mv: "MultiVector") -> None:
def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
return None


Expand Down
2 changes: 1 addition & 1 deletion arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def tup_str(t: Tuple[str, ...]) -> str:
if not t:
return "()"
else:
return "(%s,)" % ", ".join(t)
return "({},)".format(", ".join(t))

gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}")
gen(f"cls._bcast_numpy_array = {bcast_numpy_array}")
Expand Down
4 changes: 3 additions & 1 deletion arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def is_array_field(f: Field) -> bool:
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias, _SpecialForm)
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
# NOTE: anything except a Union is not allowed
raise TypeError(
Expand Down
Loading