From 688768a217ac2f7b740bec7df561f6e3e50f4334 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 29 Apr 2025 11:30:30 +0200 Subject: [PATCH] make array_type a session-scoped fixture --- src/testing/fast_array_utils/_array_type.py | 2 +- src/testing/fast_array_utils/pytest.py | 49 +++++++++++++++++---- tests/test_test_utils.py | 13 +++++- typings/h5py.pyi | 4 ++ 4 files changed, 57 insertions(+), 11 deletions(-) diff --git a/src/testing/fast_array_utils/_array_type.py b/src/testing/fast_array_utils/_array_type.py index 5dbbfd2..552198d 100644 --- a/src/testing/fast_array_utils/_array_type.py +++ b/src/testing/fast_array_utils/_array_type.py @@ -77,7 +77,7 @@ class Flags(enum.Flag): class ConversionContext: """Conversion context required for h5py.""" - hdf5_file: h5py.File + hdf5_file: h5py.File # TODO(flying-sheep): ReadOnly @dataclass(frozen=True) diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index 908036e..1799a2d 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -7,6 +7,8 @@ from __future__ import annotations import dataclasses +import os +import re from importlib.util import find_spec from typing import TYPE_CHECKING, cast @@ -93,8 +95,8 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator: ] -@pytest.fixture(params=SUPPORTED_TYPE_PARAMS) -def array_type(request: pytest.FixtureRequest, tmp_path: Path) -> Generator[ArrayType, None, None]: +@pytest.fixture(scope="session", params=SUPPORTED_TYPE_PARAMS) +def array_type(request: pytest.FixtureRequest) -> ArrayType: """Fixture for a supported :class:`~testing.fast_array_utils.ArrayType`. Use :class:`testing.fast_array_utils.Flags` to select or skip array types: @@ -131,13 +133,42 @@ def test_something(array_type: ArrayType) -> None: ... """ at = cast("ArrayType", request.param) - f: h5py.File | None = None if at.cls is types.H5Dataset or (at.inner and at.inner.cls is types.H5Dataset): + at = dataclasses.replace(at, conversion_context=CC(request)) + return at + + +try: # get the exception type + pytest.fail("x") +except BaseException as e: # noqa: BLE001 + Failed = type(e) +else: + raise AssertionError + + +class CC(ConversionContext): + def __init__(self, request: pytest.FixtureRequest) -> None: + self._request = request + + @property # This is intentionally not cached and creates a new file on each access + def hdf5_file(self) -> h5py.File: # type: ignore[override] import h5py - f = h5py.File(tmp_path / f"{request.fixturename}.h5", "w") - ctx = ConversionContext(hdf5_file=f) - at = dataclasses.replace(at, conversion_context=ctx) - yield at - if f: - f.close() + try: # If we’re being called in a test or function-scoped fixture, use the test `tmp_path` + return cast("h5py.File", self._request.getfixturevalue("tmp_hdf5_file")) + except Failed: # We’re being called from a session-scoped fixture or so + factory = cast( + "pytest.TempPathFactory", self._request.getfixturevalue("tmp_path_factory") + ) + name = re.sub(r"[^\w_. -()\[\]{}]", "_", os.environ["PYTEST_CURRENT_TEST"]) + f = h5py.File(factory.mktemp(name) / "test.h5", "w") + self._request.addfinalizer(f.close) + return f + + +@pytest.fixture +def tmp_hdf5_file(tmp_path: Path) -> Generator[h5py.File, None, None]: + import h5py + + with h5py.File(tmp_path / "test.h5", "w") as f: + yield f diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 3415a69..d227f6d 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -19,7 +19,7 @@ from numpy.typing import DTypeLike, NDArray from scipy.sparse import coo_array, coo_matrix - from testing.fast_array_utils import ArrayType + from testing.fast_array_utils import Array, ArrayType other_array_type = array_type @@ -78,3 +78,14 @@ def test_array_types(array_type: ArrayType) -> None: assert any( getattr(t, "mod", None) in {"zarr", "h5py"} for t in (array_type, array_type.inner) ) == bool(array_type.flags & Flags.Disk) + + +@pytest.fixture(scope="session") +def session_scoped_array(array_type: ArrayType) -> Array: + return array_type(np.arange(12).reshape(3, 4), dtype=np.float32) + + +def test_session_scoped_array(session_scoped_array: Array) -> None: + """Tests that creating a session-scoped array works.""" + assert session_scoped_array.shape == (3, 4) + assert session_scoped_array.dtype == np.float32 diff --git a/typings/h5py.pyi b/typings/h5py.pyi index c6a211d..7e4809f 100644 --- a/typings/h5py.pyi +++ b/typings/h5py.pyi @@ -17,6 +17,10 @@ class Dataset(HLObject): class Group(HLObject): ... class File(Group, closing[File]): # not actually a subclass of closing + filename: str + mode: Literal["r", "r+"] + libver: Literal["earliest", "latest", "v108", "v110"] + def __init__( self, name: AnyStr | os.PathLike[AnyStr] | IO[bytes],