diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 3665cf6d..bfbc14f7 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,7 +29,7 @@ """ import sys -from .context import ArrayContext, Array, Scalar +from .context import ArrayContext, Array, Scalar, tag_axes from .transform_metadata import (CommonSubexpressionTag, ElementwiseMapKernelTag) @@ -78,7 +78,7 @@ __all__ = ( - "ArrayContext", "Scalar", "Array", + "ArrayContext", "Scalar", "Array", "tag_axes", "CommonSubexpressionTag", "ElementwiseMapKernelTag", diff --git a/arraycontext/context.py b/arraycontext/context.py index cc157a4f..b6278e20 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -80,6 +80,20 @@ .. autoclass:: Array .. autoclass:: Scalar .. autoclass:: ArrayContext +.. autofunction:: tag_axes + +Internal typing helpers (do not import) +--------------------------------------- + +.. currentmodule:: arraycontext.context + +This is only here because the documentation tool wants it. + +.. class:: SelfType + +.. class:: ArrayT + + A type variable, with a lower bound of :class:`Array`. """ @@ -109,8 +123,8 @@ from abc import ABC, abstractmethod from typing import ( - Any, Callable, Dict, Optional, Tuple, Union, - TYPE_CHECKING) + Any, Callable, Dict, Optional, Tuple, Union, Mapping, + TYPE_CHECKING, TypeVar) import numpy as np from pytools import memoize_method @@ -129,6 +143,8 @@ except ImportError: from typing_extensions import Protocol # type: ignore[misc] +SelfType = TypeVar("SelfType") + class Array(Protocol): """A :class:`~typing.Protocol` for the array type supported by @@ -150,6 +166,9 @@ def dtype(self) -> "np.dtype[Any]": ... +ArrayT = TypeVar("ArrayT", bound=Array) + + class Scalar(Protocol): """A :class:`~typing.Protocol` for the scalar type supported by :class:`ArrayContext`. @@ -322,7 +341,7 @@ def thaw(self, array: Array) -> Array: @abstractmethod def tag(self, tags: ToTagSetConvertible, - array: Array) -> Array: + array: ArrayT) -> ArrayT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* with the *tags* applied. *array* itself is not modified. @@ -335,7 +354,7 @@ def tag(self, @abstractmethod def tag_axis(self, iaxis: int, tags: ToTagSetConvertible, - array: Array) -> Array: + array: ArrayT) -> ArrayT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* in which axis number *iaxis* has the *tags* applied. *array* itself is not modified. @@ -406,7 +425,7 @@ def einsum(self, return self.tag(tagged, out_ary) @abstractmethod - def clone(self) -> "ArrayContext": + def clone(self: SelfType) -> SelfType: """If possible, return a version of *self* that is semantically equivalent (i.e. implements all array operations in the same way) but is a separate object. May return *self* if that is not possible. @@ -470,4 +489,23 @@ def permits_advanced_indexing(self) -> bool: # }}} + +# {{{ tagging helpers + +def tag_axes( + actx: ArrayContext, + dim_to_tags: Mapping[int, ToTagSetConvertible], + ary: ArrayT) -> ArrayT: + """ + Return a copy of *ary* with the axes in *dim_to_tags* tagged with their + corresponding tags. Equivalent to repeated application of + :meth:`ArrayContext.tag_axis`. + """ + for iaxis, tags in dim_to_tags.items(): + ary = actx.tag_axis(iaxis, tags, ary) + + return ary + +# }}} + # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 08489872..154af2f9 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -35,8 +35,9 @@ FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, + EagerJAXArrayContext, ArrayContainer, - to_numpy) + to_numpy, tag_axes) from arraycontext import ( # noqa: F401 pytest_generate_tests_for_array_contexts, ) @@ -1442,7 +1443,7 @@ def _twice(x): # }}} -# {{{ +# {{{ test_taggable_cl_array_tags def test_taggable_cl_array_tags(actx_factory): actx = actx_factory() @@ -1497,6 +1498,27 @@ def test_to_numpy_on_frozen_arrays(actx_factory): np.testing.assert_allclose(to_numpy(u, actx), 1) +def test_tagging(actx_factory): + actx = actx_factory() + + if isinstance(actx, EagerJAXArrayContext): + pytest.skip("Eager JAX has no tagging support") + + from pytools.tag import Tag + + class ExampleTag(Tag): + pass + + ary = tag_axes(actx, {0: ExampleTag()}, + actx.tag( + ExampleTag(), + actx.zeros((20, 20), dtype=np.float64))) + + assert ary.tags_of_type(ExampleTag) + assert ary.axes[0].tags_of_type(ExampleTag) + assert not ary.axes[1].tags_of_type(ExampleTag) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: