From 7c9e3bd1eb1fc1579b1a55eaca76f8133cd80276 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 10 Jun 2022 16:23:29 -0500 Subject: [PATCH 1/2] Make typing in ArrayContext more precise using TypeVars --- arraycontext/context.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index cc157a4f..72bf8c7d 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -80,6 +80,19 @@ .. autoclass:: Array .. autoclass:: Scalar .. autoclass:: ArrayContext + +Internal typing helpers (do not import) +--------------------------------------- + +.. currentmodule:: arraycontext.context + +This is only here because the documentation tool wants it. + +.. class:: SelfType + +.. class:: ArrayT + + A type variable, with a lower bound of :class:`Array`. """ @@ -110,7 +123,7 @@ from abc import ABC, abstractmethod from typing import ( Any, Callable, Dict, Optional, Tuple, Union, - TYPE_CHECKING) + TYPE_CHECKING, TypeVar) import numpy as np from pytools import memoize_method @@ -129,6 +142,8 @@ except ImportError: from typing_extensions import Protocol # type: ignore[misc] +SelfType = TypeVar("SelfType") + class Array(Protocol): """A :class:`~typing.Protocol` for the array type supported by @@ -150,6 +165,9 @@ def dtype(self) -> "np.dtype[Any]": ... +ArrayT = TypeVar("ArrayT", bound=Array) + + class Scalar(Protocol): """A :class:`~typing.Protocol` for the scalar type supported by :class:`ArrayContext`. @@ -322,7 +340,7 @@ def thaw(self, array: Array) -> Array: @abstractmethod def tag(self, tags: ToTagSetConvertible, - array: Array) -> Array: + array: ArrayT) -> ArrayT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* with the *tags* applied. *array* itself is not modified. @@ -335,7 +353,7 @@ def tag(self, @abstractmethod def tag_axis(self, iaxis: int, tags: ToTagSetConvertible, - array: Array) -> Array: + array: ArrayT) -> ArrayT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* in which axis number *iaxis* has the *tags* applied. *array* itself is not modified. @@ -406,7 +424,7 @@ def einsum(self, return self.tag(tagged, out_ary) @abstractmethod - def clone(self) -> "ArrayContext": + def clone(self: SelfType) -> SelfType: """If possible, return a version of *self* that is semantically equivalent (i.e. implements all array operations in the same way) but is a separate object. May return *self* if that is not possible. From 9f84329bd1a492fe19a2a3b86a831471d5fe4cf6 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 10 Jun 2022 16:24:23 -0500 Subject: [PATCH 2/2] Add tag_axes Co-authored-by: Kaushik Kulkarni --- arraycontext/__init__.py | 4 ++-- arraycontext/context.py | 22 +++++++++++++++++++++- test/test_arraycontext.py | 26 ++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 3665cf6d..bfbc14f7 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,7 +29,7 @@ """ import sys -from .context import ArrayContext, Array, Scalar +from .context import ArrayContext, Array, Scalar, tag_axes from .transform_metadata import (CommonSubexpressionTag, ElementwiseMapKernelTag) @@ -78,7 +78,7 @@ __all__ = ( - "ArrayContext", "Scalar", "Array", + "ArrayContext", "Scalar", "Array", "tag_axes", "CommonSubexpressionTag", "ElementwiseMapKernelTag", diff --git a/arraycontext/context.py b/arraycontext/context.py index 72bf8c7d..b6278e20 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -80,6 +80,7 @@ .. autoclass:: Array .. autoclass:: Scalar .. autoclass:: ArrayContext +.. autofunction:: tag_axes Internal typing helpers (do not import) --------------------------------------- @@ -122,7 +123,7 @@ from abc import ABC, abstractmethod from typing import ( - Any, Callable, Dict, Optional, Tuple, Union, + Any, Callable, Dict, Optional, Tuple, Union, Mapping, TYPE_CHECKING, TypeVar) import numpy as np @@ -488,4 +489,23 @@ def permits_advanced_indexing(self) -> bool: # }}} + +# {{{ tagging helpers + +def tag_axes( + actx: ArrayContext, + dim_to_tags: Mapping[int, ToTagSetConvertible], + ary: ArrayT) -> ArrayT: + """ + Return a copy of *ary* with the axes in *dim_to_tags* tagged with their + corresponding tags. Equivalent to repeated application of + :meth:`ArrayContext.tag_axis`. + """ + for iaxis, tags in dim_to_tags.items(): + ary = actx.tag_axis(iaxis, tags, ary) + + return ary + +# }}} + # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 08489872..154af2f9 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -35,8 +35,9 @@ FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, + EagerJAXArrayContext, ArrayContainer, - to_numpy) + to_numpy, tag_axes) from arraycontext import ( # noqa: F401 pytest_generate_tests_for_array_contexts, ) @@ -1442,7 +1443,7 @@ def _twice(x): # }}} -# {{{ +# {{{ test_taggable_cl_array_tags def test_taggable_cl_array_tags(actx_factory): actx = actx_factory() @@ -1497,6 +1498,27 @@ def test_to_numpy_on_frozen_arrays(actx_factory): np.testing.assert_allclose(to_numpy(u, actx), 1) +def test_tagging(actx_factory): + actx = actx_factory() + + if isinstance(actx, EagerJAXArrayContext): + pytest.skip("Eager JAX has no tagging support") + + from pytools.tag import Tag + + class ExampleTag(Tag): + pass + + ary = tag_axes(actx, {0: ExampleTag()}, + actx.tag( + ExampleTag(), + actx.zeros((20, 20), dtype=np.float64))) + + assert ary.tags_of_type(ExampleTag) + assert ary.axes[0].tags_of_type(ExampleTag) + assert not ary.axes[1].tags_of_type(ExampleTag) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: