Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"""

import sys
from .context import ArrayContext, Array, Scalar
from .context import ArrayContext, Array, Scalar, tag_axes

from .transform_metadata import (CommonSubexpressionTag,
ElementwiseMapKernelTag)
Expand Down Expand Up @@ -78,7 +78,7 @@


__all__ = (
"ArrayContext", "Scalar", "Array",
"ArrayContext", "Scalar", "Array", "tag_axes",

"CommonSubexpressionTag",
"ElementwiseMapKernelTag",
Expand Down
48 changes: 43 additions & 5 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@
.. autoclass:: Array
.. autoclass:: Scalar
.. autoclass:: ArrayContext
.. autofunction:: tag_axes

Internal typing helpers (do not import)
---------------------------------------

.. currentmodule:: arraycontext.context

This is only here because the documentation tool wants it.

.. class:: SelfType

.. class:: ArrayT

A type variable, with a lower bound of :class:`Array`.
"""


Expand Down Expand Up @@ -109,8 +123,8 @@

from abc import ABC, abstractmethod
from typing import (
Any, Callable, Dict, Optional, Tuple, Union,
TYPE_CHECKING)
Any, Callable, Dict, Optional, Tuple, Union, Mapping,
TYPE_CHECKING, TypeVar)

import numpy as np
from pytools import memoize_method
Expand All @@ -129,6 +143,8 @@
except ImportError:
from typing_extensions import Protocol # type: ignore[misc]

SelfType = TypeVar("SelfType")


class Array(Protocol):
"""A :class:`~typing.Protocol` for the array type supported by
Expand All @@ -150,6 +166,9 @@ def dtype(self) -> "np.dtype[Any]":
...


ArrayT = TypeVar("ArrayT", bound=Array)


class Scalar(Protocol):
"""A :class:`~typing.Protocol` for the scalar type supported by
:class:`ArrayContext`.
Expand Down Expand Up @@ -322,7 +341,7 @@ def thaw(self, array: Array) -> Array:
@abstractmethod
def tag(self,
tags: ToTagSetConvertible,
array: Array) -> Array:
array: ArrayT) -> ArrayT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* with the *tags* applied. *array*
itself is not modified.
Expand All @@ -335,7 +354,7 @@ def tag(self,
@abstractmethod
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: Array) -> Array:
array: ArrayT) -> ArrayT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* in which axis number *iaxis* has
the *tags* applied. *array* itself is not modified.
Expand Down Expand Up @@ -406,7 +425,7 @@ def einsum(self,
return self.tag(tagged, out_ary)

@abstractmethod
def clone(self) -> "ArrayContext":
def clone(self: SelfType) -> SelfType:
"""If possible, return a version of *self* that is semantically
equivalent (i.e. implements all array operations in the same way)
but is a separate object. May return *self* if that is not possible.
Expand Down Expand Up @@ -470,4 +489,23 @@ def permits_advanced_indexing(self) -> bool:

# }}}


# {{{ tagging helpers

def tag_axes(
actx: ArrayContext,
dim_to_tags: Mapping[int, ToTagSetConvertible],
ary: ArrayT) -> ArrayT:
Comment on lines +495 to +498
Copy link
Owner Author

@inducer inducer Jun 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that I've changed the argument order here, compared to the original in inducer/meshmode#284. The rationale for this is that the array might be a complicated expression, and so I find it easiest if it comes last (because you don't have to remember that you still "owe" otherwise simpler arguments of tag_axes. Also for consistency with tag_axis.

"""
Return a copy of *ary* with the axes in *dim_to_tags* tagged with their
corresponding tags. Equivalent to repeated application of
:meth:`ArrayContext.tag_axis`.
"""
for iaxis, tags in dim_to_tags.items():
ary = actx.tag_axis(iaxis, tags, ary)

return ary

# }}}

# vim: foldmethod=marker
26 changes: 24 additions & 2 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
FirstAxisIsElementsTag,
PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext,
EagerJAXArrayContext,
ArrayContainer,
to_numpy)
to_numpy, tag_axes)
from arraycontext import ( # noqa: F401
pytest_generate_tests_for_array_contexts,
)
Expand Down Expand Up @@ -1442,7 +1443,7 @@ def _twice(x):
# }}}


# {{{
# {{{ test_taggable_cl_array_tags

def test_taggable_cl_array_tags(actx_factory):
actx = actx_factory()
Expand Down Expand Up @@ -1497,6 +1498,27 @@ def test_to_numpy_on_frozen_arrays(actx_factory):
np.testing.assert_allclose(to_numpy(u, actx), 1)


def test_tagging(actx_factory):
actx = actx_factory()

if isinstance(actx, EagerJAXArrayContext):
pytest.skip("Eager JAX has no tagging support")

from pytools.tag import Tag

class ExampleTag(Tag):
pass

ary = tag_axes(actx, {0: ExampleTag()},
actx.tag(
ExampleTag(),
actx.zeros((20, 20), dtype=np.float64)))

assert ary.tags_of_type(ExampleTag)
assert ary.axes[0].tags_of_type(ExampleTag)
assert not ary.axes[1].tags_of_type(ExampleTag)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down