Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 104 additions & 48 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -1851,14 +1851,6 @@
}
],
"./arraycontext/context.py": [
{
"code": "reportDeprecated",
"range": {
"startColumn": 69,
"endColumn": 74,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand All @@ -1867,14 +1859,6 @@
"lineCount": 1
}
},
{
"code": "reportDeprecated",
"range": {
"startColumn": 27,
"endColumn": 32,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -17009,6 +16993,14 @@
"lineCount": 1
}
},
{
"code": "reportInvalidCast",
"range": {
"startColumn": 15,
"endColumn": 45,
"lineCount": 1
}
},
{
"code": "reportUnknownParameterType",
"range": {
Expand Down Expand Up @@ -19761,6 +19753,46 @@
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 54,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 54,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 54,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 54,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 54,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportIndexIssue",
"range": {
Expand Down Expand Up @@ -20265,6 +20297,14 @@
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 21,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnknownLambdaType",
"range": {
Expand Down Expand Up @@ -20313,6 +20353,22 @@
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 8,
"endColumn": 53,
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 8,
"endColumn": 32,
"lineCount": 1
}
},
{
"code": "reportUnusedExpression",
"range": {
Expand All @@ -20321,6 +20377,14 @@
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 8,
"endColumn": 27,
"lineCount": 1
}
},
{
"code": "reportUnusedExpression",
"range": {
Expand Down Expand Up @@ -20377,6 +20441,22 @@
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 36,
"endColumn": 55,
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 32,
"endColumn": 51,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
Expand Down Expand Up @@ -20601,6 +20681,14 @@
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 12,
"endColumn": 21,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
Expand Down Expand Up @@ -21063,22 +21151,6 @@
"lineCount": 1
}
},
{
"code": "reportGeneralTypeIssues",
"range": {
"startColumn": 10,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportGeneralTypeIssues",
"range": {
"startColumn": 14,
"endColumn": 35,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand All @@ -21087,22 +21159,6 @@
"lineCount": 1
}
},
{
"code": "reportGeneralTypeIssues",
"range": {
"startColumn": 10,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportGeneralTypeIssues",
"range": {
"startColumn": 14,
"endColumn": 35,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand Down
11 changes: 10 additions & 1 deletion arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
.. autofunction:: with_container_arithmetic

.. autoclass:: BcastUntilActxArray

References
----------

.. class:: TypeT

A type variable with an upper bound of :class:`type`.
"""


Expand Down Expand Up @@ -62,6 +69,8 @@

T = TypeVar("T")

TypeT = TypeVar("TypeT", bound=type)


@enum.unique
class _OpClass(enum.Enum):
Expand Down Expand Up @@ -190,7 +199,7 @@ def with_container_arithmetic(
bcast_numpy_array: bool = False,
_bcast_actx_array_type: bool | None = None,
bcast_container_types: tuple[type, ...] | None = None,
) -> Callable[[type], type]:
) -> Callable[[TypeT], TypeT]:
"""A class decorator that implements built-in operators for array containers
by propagating the operations to the elements of the container.

Expand Down
36 changes: 31 additions & 5 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,30 @@
THE SOFTWARE.
"""


from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Literal,
Protocol,
TypeAlias,
TypeVar,
overload,
)
from warnings import warn

import numpy as np
from typing_extensions import Self

from pymbolic.typing import Integer, Scalar as _Scalar
from pytools import memoize_method


if TYPE_CHECKING:
import numpy as np
from numpy.typing import DTypeLike

import loopy
from pytools.tag import ToTagSetConvertible

Expand Down Expand Up @@ -243,6 +254,21 @@ def __rpow__(self, other: Self | ScalarLike) -> Array: ...
def __truediv__(self, other: Self | ScalarLike) -> Array: ...
def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...

def copy(self) -> Self: ...

@property
def real(self) -> Array: ...
@property
def imag(self) -> Array: ...
def conj(self) -> Array: ...

def astype(self, dtype: DTypeLike) -> Array: ...

def reshape(self,
*shape: int,
order: Literal["C"] | Literal["F"]
) -> Array: ...


# deprecated, use ScalarLike instead
Scalar = _Scalar
Expand Down Expand Up @@ -287,7 +313,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike")


NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike]
NumpyOrContainerOrScalar: TypeAlias = "np.ndarray | ArrayContainer | ScalarLike"

# }}}

Expand Down Expand Up @@ -476,7 +502,7 @@ def tag(self,
@abstractmethod
def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
array: ArrayOrContainerT) -> ArrayOrContainerT:
"""If the array type used by the array context is capable of capturing
metadata, return a version of *array* in which axis number *iaxis* has
the *tags* applied. *array* itself is not modified. When working with
Expand Down Expand Up @@ -623,7 +649,7 @@ def permits_advanced_indexing(self) -> bool:
def tag_axes(
actx: ArrayContext,
dim_to_tags: Mapping[int, ToTagSetConvertible],
ary: ArrayT) -> ArrayT:
ary: ArrayOrContainerT) -> ArrayOrContainerT:
"""
Return a copy of *ary* with the axes in *dim_to_tags* tagged with their
corresponding tags. Equivalent to repeated application of
Expand Down
Loading