Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8,156 changes: 2,638 additions & 5,518 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

14 changes: 0 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ jobs:
pip install ruff
ruff check

mypy:
name: Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: "Main Script"
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0

build_py_project_in_conda_env
python -m pip install mypy pytest
./run-mypy.sh

basedpyright:
runs-on: ubuntu-latest
steps:
Expand Down
14 changes: 0 additions & 14 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,6 @@ Ruff:
except:
- tags

Mypy:
script: |
EXTRA_INSTALL="mypy pytest"

curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0

build_py_project_in_venv
./run-mypy.sh
tags:
- python3
except:
- tags

Downstream:
parallel:
matrix:
Expand Down
2 changes: 2 additions & 0 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def __sub__(self, other: Self | ScalarLike) -> Self: ...
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
def __mul__(self, other: Self | ScalarLike) -> Self: ...
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
def __pow__(self, other: Self | ScalarLike) -> Self: ...
def __rpow__(self, other: Self | ScalarLike) -> Self: ...
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...

Expand Down
17 changes: 13 additions & 4 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"""

from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal
from warnings import warn

import numpy as np
Expand All @@ -51,7 +51,8 @@

if TYPE_CHECKING:
import loopy as lp
import pyopencl
import pyopencl as cl
import pyopencl.array as cl_array


# {{{ PyOpenCLArrayContext
Expand Down Expand Up @@ -81,9 +82,17 @@ class PyOpenCLArrayContext(ArrayContext):
.. automethod:: transform_loopy_program
"""

context: cl.Context
queue: cl.CommandQueue
allocator: cl_array.Allocator | None

_force_device_scalars: Literal[True]
_passed_force_device_scalars: bool
_wait_event_queue_length: int

def __init__(self,
queue: pyopencl.CommandQueue,
allocator: pyopencl.tools.AllocatorBase | None = None,
queue: cl.CommandQueue,
allocator: cl_array.Allocator | None = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool | None = None) -> None:
r"""
Expand Down
4 changes: 2 additions & 2 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
rec_multimap_array_container,
rec_multimap_reduce_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.context import Array as actx_Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
Expand Down Expand Up @@ -206,7 +206,7 @@ def _any(ary):
_any,
a)

def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> actx_Array:
actx = self._array_context
queue = actx.queue

Expand Down
46 changes: 31 additions & 15 deletions arraycontext/impl/pyopencl/taggable_cl_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any
from typing import Any, Literal

import numpy as np
from numpy.typing import DTypeLike

import pyopencl as cl
import pyopencl.array as cla
from pytools import memoize
from pytools.tag import Tag, Taggable, ToTagSetConvertible
Expand Down Expand Up @@ -165,13 +167,20 @@ def to_tagged_cl_array(ary: cla.Array,
# }}}


_EMPTY_TAG_SET: frozenset[Tag] = frozenset()


# {{{ creation

def empty(queue, shape, dtype=float, *,
axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = frozenset(),
order: str = "C",
allocator=None) -> TaggableCLArray:
def empty(
queue: cl.CommandQueue,
shape: tuple[int, ...] | int,
dtype: DTypeLike = float,
*, axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = _EMPTY_TAG_SET,
order: Literal["C"] | Literal["F"] = "C",
allocator: cla.Allocator | None = None,
) -> TaggableCLArray:
if dtype is not None:
dtype = np.dtype(dtype)

Expand All @@ -181,11 +190,15 @@ def empty(queue, shape, dtype=float, *,
order=order, allocator=allocator)


def zeros(queue, shape, dtype=float, *,
axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = frozenset(),
order: str = "C",
allocator=None) -> TaggableCLArray:
def zeros(
queue: cl.CommandQueue,
shape: tuple[int, ...] | int,
dtype: DTypeLike = float,
*, axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = _EMPTY_TAG_SET,
order: Literal["C"] | Literal["F"] = "C",
allocator: cla.Allocator | None = None,
) -> TaggableCLArray:
result = empty(
queue, shape, dtype=dtype, axes=axes, tags=tags,
order=order, allocator=allocator)
Expand All @@ -194,10 +207,13 @@ def zeros(queue, shape, dtype=float, *,
return result


def to_device(queue, ary, *,
axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = frozenset(),
allocator=None):
def to_device(
queue: cl.CommandQueue,
ary: np.ndarray[Any],
*, axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = _EMPTY_TAG_SET,
allocator: cla.Allocator | None = None,
) -> TaggableCLArray:
return to_tagged_cl_array(
cla.to_device(queue, ary, allocator=allocator),
axes=axes, tags=tags)
Expand Down
18 changes: 15 additions & 3 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@

if TYPE_CHECKING:
import loopy as lp
import pyopencl as cl
import pyopencl.array as cl_array
import pytato

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
Expand Down Expand Up @@ -240,8 +242,8 @@ def get_target(self):
class ProfileEvent:
"""Holds a profile event that has not been collected by the profiler yet."""

start_cl_event: cl._cl.Event
stop_cl_event: cl._cl.Event
start_cl_event: cl.Event
stop_cl_event: cl.Event
t_unit_name: str


Expand All @@ -265,8 +267,18 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):

.. automethod:: compile
"""
context: cl.Context
queue: cl.CommandQueue
allocator: cl_array.Allocator
using_svm: bool | None
profile_kernels: bool

_force_svm_arg_limit: int | None

def __init__(
self, queue: cl.CommandQueue, allocator=None, *,
self, queue: cl.CommandQueue,
allocator: cl_array.Allocator | None = None,
*,
use_memory_pool: bool | None = None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
profile_kernels: bool = False,
Expand Down
9 changes: 9 additions & 0 deletions doc/other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ Program creation for :mod:`loopy`
---------------------------------

.. automodule:: arraycontext.loopy

References
----------

.. currentmodule:: cl_array

.. class:: Allocator

See :class:`pyopencl.array.Allocator`.
29 changes: 10 additions & 19 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,20 @@ pytato = [
"pytato>=2021.1",
]
test = [
"mypy",
"basedpyright",
"pytest",
"ruff",
]

[tool.hatch.build.targets.sdist]
exclude = [
"/.git*",
"/doc/_build",
"/.editorconfig",
"/run-*.sh",
"/.basedpyright",
]

[project.urls]
Documentation = "https://documen.tician.de/arraycontext"
Homepage = "https://github.com/inducer/arraycontext"
Expand Down Expand Up @@ -110,24 +119,6 @@ required-imports = ["from __future__ import annotations"]
# from @dataclass_array_container.
"test/test_utils.py" = ["I002"]

[tool.mypy]
python_version = "3.10"
warn_unused_ignores = true
# TODO: enable this
# check_untyped_defs = true

[[tool.mypy.overrides]]
module = [
"islpy.*",
"loopy.*",
"meshmode.*",
"pymbolic",
"pymbolic.*",
"pyopencl.*",
"jax.*",
]
ignore_missing_imports = true

[tool.typos.default]
extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"
Expand Down
3 changes: 0 additions & 3 deletions run-mypy.sh

This file was deleted.

Loading