diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 14566dfed2..449816209b 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,31 +1,74 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from typing import Protocol, runtime_checkable +from typing import Any, NamedTuple, Protocol, runtime_checkable + +from typing_extensions import Self from zarr.buffer import Buffer, BufferPrototype -from zarr.common import BytesLike, OpenMode +from zarr.common import AccessModeLiteral, BytesLike + + +class AccessMode(NamedTuple): + readonly: bool + overwrite: bool + create: bool + update: bool + + @classmethod + def from_literal(cls, mode: AccessModeLiteral) -> Self: + if mode in ("r", "r+", "a", "w", "w-"): + return cls( + readonly=mode == "r", + overwrite=mode == "w", + create=mode in ("a", "w", "w-"), + update=mode in ("r+", "a"), + ) + raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'") class Store(ABC): - _mode: OpenMode + _mode: AccessMode + _is_open: bool + + def __init__(self, mode: AccessModeLiteral = "r", *args: Any, **kwargs: Any): + self._is_open = False + self._mode = AccessMode.from_literal(mode) + + @classmethod + async def open(cls, *args: Any, **kwargs: Any) -> Self: + store = cls(*args, **kwargs) + await store._open() + return store + + async def _open(self) -> None: + if self._is_open: + raise ValueError("store is already open") + if not await self.empty(): + if self.mode.update or self.mode.readonly: + pass + elif self.mode.overwrite: + await self.clear() + else: + raise FileExistsError("Store already exists") + self._is_open = True + + async def _ensure_open(self) -> None: + if not self._is_open: + await self._open() - def __init__(self, mode: OpenMode = "r"): - if mode not in ("r", "r+", "w", "w-", "a"): - raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'") - self._mode = mode + @abstractmethod + async def empty(self) -> bool: ... + + @abstractmethod + async def clear(self) -> None: ... @property - def mode(self) -> OpenMode: + def mode(self) -> AccessMode: """Access mode of the store.""" return self._mode - @property - def writeable(self) -> bool: - """Is the store writeable?""" - return self.mode in ("a", "w", "w-") - def _check_writable(self) -> None: - if not self.writeable: + if self.mode.readonly: raise ValueError("store mode does not support writing") @abstractmethod @@ -173,8 +216,9 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ ... - def close(self) -> None: # noqa: B027 + def close(self) -> None: """Close the store.""" + self._is_open = False pass diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index fa63ab46a8..5d2e54baa3 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -12,7 +12,7 @@ from zarr.array import Array, AsyncArray from zarr.buffer import NDArrayLike from zarr.chunk_key_encodings import ChunkKeyEncoding -from zarr.common import JSON, ChunkCoords, MemoryOrder, OpenMode, ZarrFormat +from zarr.common import JSON, AccessModeLiteral, ChunkCoords, MemoryOrder, ZarrFormat from zarr.group import AsyncGroup from zarr.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.store import ( @@ -158,7 +158,7 @@ async def load( async def open( *, store: StoreLike | None = None, - mode: OpenMode | None = None, # type and value changed + mode: AccessModeLiteral | None = None, # type and value changed zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: str | None = None, @@ -189,15 +189,15 @@ async def open( Return type depends on what exists in the given store. """ zarr_format = _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format) - store_path = make_store_path(store, mode=mode) + store_path = await make_store_path(store, mode=mode) if path is not None: store_path = store_path / path try: - return await open_array(store=store_path, zarr_format=zarr_format, **kwargs) + return await open_array(store=store_path, zarr_format=zarr_format, mode=mode, **kwargs) except KeyError: - return await open_group(store=store_path, zarr_format=zarr_format, **kwargs) + return await open_group(store=store_path, zarr_format=zarr_format, mode=mode, **kwargs) async def open_consolidated(*args: Any, **kwargs: Any) -> AsyncGroup: @@ -267,7 +267,7 @@ async def save_array( or _default_zarr_version() ) - store_path = make_store_path(store, mode="w") + store_path = await make_store_path(store, mode="w") if path is not None: store_path = store_path / path new = await AsyncArray.create( @@ -421,7 +421,7 @@ async def group( or _default_zarr_version() ) - store_path = make_store_path(store) + store_path = await make_store_path(store) if path is not None: store_path = store_path / path @@ -451,7 +451,7 @@ async def group( async def open_group( *, # Note: this is a change from v2 store: StoreLike | None = None, - mode: OpenMode | None = None, # not used + mode: AccessModeLiteral | None = None, # not used cache_attrs: bool | None = None, # not used, default changed synchronizer: Any = None, # not used path: str | None = None, @@ -512,7 +512,7 @@ async def open_group( if storage_options is not None: warnings.warn("storage_options is not yet implemented", RuntimeWarning, stacklevel=2) - store_path = make_store_path(store, mode=mode) + store_path = await make_store_path(store, mode=mode) if path is not None: store_path = store_path / path @@ -682,8 +682,8 @@ async def create( if meta_array is not None: warnings.warn("meta_array is not yet implemented", RuntimeWarning, stacklevel=2) - mode = cast(OpenMode, "r" if read_only else "w") - store_path = make_store_path(store, mode=mode) + mode = kwargs.pop("mode", cast(AccessModeLiteral, "r" if read_only else "w")) + store_path = await make_store_path(store, mode=mode) if path is not None: store_path = store_path / path @@ -854,7 +854,7 @@ async def open_array( The opened array. """ - store_path = make_store_path(store) + store_path = await make_store_path(store) if path is not None: store_path = store_path / path @@ -862,14 +862,16 @@ async def open_array( try: return await AsyncArray.open(store_path, zarr_format=zarr_format) - except KeyError as e: - if store_path.store.writeable: - pass - else: - raise e - - # if array was not found, create it - return await create(store=store, path=path, zarr_format=zarr_format, **kwargs) + except FileNotFoundError as e: + if store_path.store.mode.create: + return await create( + store=store_path, + path=path, + zarr_format=zarr_format, + overwrite=store_path.store.mode.overwrite, + **kwargs, + ) + raise e async def open_like(a: ArrayLike, path: str, **kwargs: Any) -> AsyncArray: diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 57b9d5630f..eef87aab7e 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -5,7 +5,7 @@ import zarr.api.asynchronous as async_api from zarr.array import Array, AsyncArray from zarr.buffer import NDArrayLike -from zarr.common import JSON, ChunkCoords, OpenMode, ZarrFormat +from zarr.common import JSON, AccessModeLiteral, ChunkCoords, ZarrFormat from zarr.group import Group from zarr.store import StoreLike from zarr.sync import sync @@ -36,7 +36,7 @@ def load( def open( *, store: StoreLike | None = None, - mode: OpenMode | None = None, # type and value changed + mode: AccessModeLiteral | None = None, # type and value changed zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: str | None = None, @@ -161,7 +161,7 @@ def group( def open_group( *, # Note: this is a change from v2 store: StoreLike | None = None, - mode: OpenMode | None = None, # not used in async api + mode: AccessModeLiteral | None = None, # not used in async api cache_attrs: bool | None = None, # default changed, not used in async api synchronizer: Any = None, # not used in async api path: str | None = None, diff --git a/src/zarr/array.py b/src/zarr/array.py index e366321b15..6572127450 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -142,7 +142,7 @@ async def create( exists_ok: bool = False, data: npt.ArrayLike | None = None, ) -> AsyncArray: - store_path = make_store_path(store) + store_path = await make_store_path(store) if chunk_shape is None: if chunks is None: @@ -334,18 +334,18 @@ async def open( store: StoreLike, zarr_format: ZarrFormat | None = 3, ) -> AsyncArray: - store_path = make_store_path(store) + store_path = await make_store_path(store) if zarr_format == 2: zarray_bytes, zattrs_bytes = await gather( (store_path / ZARRAY_JSON).get(), (store_path / ZATTRS_JSON).get() ) if zarray_bytes is None: - raise KeyError(store_path) # filenotfounderror? + raise FileNotFoundError(store_path) elif zarr_format == 3: zarr_json_bytes = await (store_path / ZARR_JSON).get() if zarr_json_bytes is None: - raise KeyError(store_path) # filenotfounderror? + raise FileNotFoundError(store_path) elif zarr_format is None: zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather( (store_path / ZARR_JSON).get(), @@ -357,7 +357,7 @@ async def open( # alternatively, we could warn and favor v3 raise ValueError("Both zarr.json and .zarray objects exist") if zarr_json_bytes is None and zarray_bytes is None: - raise KeyError(store_path) # filenotfounderror? + raise FileNotFoundError(store_path) # set zarr_format based on which keys were found if zarr_json_bytes is not None: zarr_format = 3 @@ -412,7 +412,7 @@ def attrs(self) -> dict[str, JSON]: @property def read_only(self) -> bool: - return bool(not self.store_path.store.writeable) + return self.store_path.store.mode.readonly @property def path(self) -> str: diff --git a/src/zarr/common.py b/src/zarr/common.py index 342db1412d..aaa30cfcb8 100644 --- a/src/zarr/common.py +++ b/src/zarr/common.py @@ -33,7 +33,7 @@ ZarrFormat = Literal[2, 3] JSON = None | str | int | float | Enum | dict[str, "JSON"] | list["JSON"] | tuple["JSON", ...] MemoryOrder = Literal["C", "F"] -OpenMode = Literal["r", "r+", "a", "w", "w-"] +AccessModeLiteral = Literal["r", "r+", "a", "w", "w-"] def product(tup: ChunkCoords) -> int: diff --git a/src/zarr/group.py b/src/zarr/group.py index 5361eb1345..9d2ad68422 100644 --- a/src/zarr/group.py +++ b/src/zarr/group.py @@ -129,7 +129,7 @@ async def create( exists_ok: bool = False, zarr_format: ZarrFormat = 3, ) -> AsyncGroup: - store_path = make_store_path(store) + store_path = await make_store_path(store) if not exists_ok: await ensure_no_existing_node(store_path, zarr_format=zarr_format) attributes = attributes or {} @@ -146,7 +146,7 @@ async def open( store: StoreLike, zarr_format: Literal[2, 3, None] = 3, ) -> AsyncGroup: - store_path = make_store_path(store) + store_path = await make_store_path(store) if zarr_format == 2: zgroup_bytes, zattrs_bytes = await asyncio.gather( @@ -169,7 +169,7 @@ async def open( # alternatively, we could warn and favor v3 raise ValueError("Both zarr.json and .zgroup objects exist") if zarr_json_bytes is None and zgroup_bytes is None: - raise KeyError(store_path) # filenotfounderror? + raise FileNotFoundError(store_path) # set zarr_format based on which keys were found if zarr_json_bytes is not None: zarr_format = 3 diff --git a/src/zarr/store/core.py b/src/zarr/store/core.py index 85f85aabde..e483c8f3b6 100644 --- a/src/zarr/store/core.py +++ b/src/zarr/store/core.py @@ -4,9 +4,9 @@ from pathlib import Path from typing import Any, Literal -from zarr.abc.store import Store +from zarr.abc.store import AccessMode, Store from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype -from zarr.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, OpenMode, ZarrFormat +from zarr.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, AccessModeLiteral, ZarrFormat from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError from zarr.store.local import LocalStore from zarr.store.memory import MemoryStore @@ -68,23 +68,26 @@ def __eq__(self, other: Any) -> bool: StoreLike = Store | StorePath | Path | str -def make_store_path(store_like: StoreLike | None, *, mode: OpenMode | None = None) -> StorePath: +async def make_store_path( + store_like: StoreLike | None, *, mode: AccessModeLiteral | None = None +) -> StorePath: if isinstance(store_like, StorePath): if mode is not None: - assert mode == store_like.store.mode + assert AccessMode.from_literal(mode) == store_like.store.mode return store_like elif isinstance(store_like, Store): if mode is not None: - assert mode == store_like.mode + assert AccessMode.from_literal(mode) == store_like.mode + await store_like._ensure_open() return StorePath(store_like) elif store_like is None: if mode is None: mode = "w" # exception to the default mode = 'r' - return StorePath(MemoryStore(mode=mode)) + return StorePath(await MemoryStore.open(mode=mode)) elif isinstance(store_like, Path): - return StorePath(LocalStore(store_like, mode=mode or "r")) + return StorePath(await LocalStore.open(root=store_like, mode=mode or "r")) elif isinstance(store_like, str): - return StorePath(LocalStore(Path(store_like), mode=mode or "r")) + return StorePath(await LocalStore.open(root=Path(store_like), mode=mode or "r")) raise TypeError diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 5915559900..25fd9fc13a 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -1,13 +1,14 @@ from __future__ import annotations import io +import os import shutil from collections.abc import AsyncGenerator from pathlib import Path from zarr.abc.store import Store from zarr.buffer import Buffer, BufferPrototype -from zarr.common import OpenMode, concurrent_map, to_thread +from zarr.common import AccessModeLiteral, concurrent_map, to_thread def _get( @@ -71,14 +72,25 @@ class LocalStore(Store): root: Path - def __init__(self, root: Path | str, *, mode: OpenMode = "r"): + def __init__(self, root: Path | str, *, mode: AccessModeLiteral = "r"): super().__init__(mode=mode) if isinstance(root, str): root = Path(root) assert isinstance(root, Path) - self.root = root + async def clear(self) -> None: + self._check_writable() + shutil.rmtree(self.root) + self.root.mkdir() + + async def empty(self) -> bool: + try: + subpaths = os.listdir(self.root) + return not subpaths + except FileNotFoundError: + return True + def __str__(self) -> str: return f"file://{self.root}" @@ -94,6 +106,8 @@ async def get( prototype: BufferPrototype, byte_range: tuple[int | None, int | None] | None = None, ) -> Buffer | None: + if not self._is_open: + await self._open() assert isinstance(key, str) path = self.root / key @@ -126,6 +140,8 @@ async def get_partial_values( return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit async def set(self, key: str, value: Buffer) -> None: + if not self._is_open: + await self._open() self._check_writable() assert isinstance(key, str) if not isinstance(value, Buffer): diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 7f3c575719..dd3e52e703 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -4,7 +4,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer, BufferPrototype -from zarr.common import OpenMode, concurrent_map +from zarr.common import AccessModeLiteral, concurrent_map from zarr.store.utils import _normalize_interval_index @@ -18,11 +18,20 @@ class MemoryStore(Store): _store_dict: MutableMapping[str, Buffer] def __init__( - self, store_dict: MutableMapping[str, Buffer] | None = None, *, mode: OpenMode = "r" + self, + store_dict: MutableMapping[str, Buffer] | None = None, + *, + mode: AccessModeLiteral = "r", ): super().__init__(mode=mode) self._store_dict = store_dict or {} + async def empty(self) -> bool: + return not self._store_dict + + async def clear(self) -> None: + self._store_dict.clear() + def __str__(self) -> str: return f"memory://{id(self._store_dict)}" @@ -35,6 +44,8 @@ async def get( prototype: BufferPrototype, byte_range: tuple[int | None, int | None] | None = None, ) -> Buffer | None: + if not self._is_open: + await self._open() assert isinstance(key, str) try: value = self._store_dict[key] @@ -59,6 +70,8 @@ async def exists(self, key: str) -> bool: return key in self._store_dict async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + if not self._is_open: + await self._open() self._check_writable() assert isinstance(key, str) if not isinstance(value, Buffer): diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 18ad3fa0bf..c742d9e567 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -6,7 +6,8 @@ import fsspec from zarr.abc.store import Store -from zarr.common import OpenMode +from zarr.buffer import BufferPrototype +from zarr.common import AccessModeLiteral from zarr.store.core import _dereference_path if TYPE_CHECKING: @@ -31,7 +32,7 @@ class RemoteStore(Store): def __init__( self, url: UPath | str, - mode: OpenMode = "r", + mode: AccessModeLiteral = "r", allowed_exceptions: tuple[type[Exception], ...] = ( FileNotFoundError, IsADirectoryError, @@ -49,7 +50,6 @@ def __init__( storage_options: passed on to fsspec to make the filesystem instance. If url is a UPath, this must not be used. """ - super().__init__(mode=mode) if isinstance(url, str): self._url = url.rstrip("/") @@ -74,6 +74,17 @@ def __init__( if not self._fs.async_impl: raise TypeError("FileSystem needs to support async operations") + async def clear(self) -> None: + try: + for subpath in await self._fs._find(self.path, withdirs=True): + if subpath != self.path: + await self._fs._rm(subpath, recursive=True) + except FileNotFoundError: + pass + + async def empty(self) -> bool: + return not await self._fs._find(self.path, withdirs=True) + def __str__(self) -> str: return f"{self._url}" @@ -86,6 +97,8 @@ async def get( prototype: BufferPrototype, byte_range: tuple[int | None, int | None] | None = None, ) -> Buffer | None: + if not self._is_open: + await self._open() path = _dereference_path(self.path, key) try: @@ -121,6 +134,8 @@ async def set( value: Buffer, byte_range: tuple[int, int] | None = None, ) -> None: + if not self._is_open: + await self._open() self._check_writable() path = _dereference_path(self.path, key) # write data diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index a4e154bbc9..72dadac1df 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -2,7 +2,7 @@ import pytest -from zarr.abc.store import Store +from zarr.abc.store import AccessMode, Store from zarr.buffer import Buffer, default_buffer_prototype from zarr.store.utils import _normalize_interval_index from zarr.testing.utils import assert_bytes_equal @@ -32,33 +32,28 @@ def get(self, store: S, key: str) -> Buffer: @pytest.fixture(scope="function") def store_kwargs(self) -> dict[str, Any]: - return {"mode": "w"} + return {"mode": "r+"} @pytest.fixture(scope="function") - def store(self, store_kwargs: dict[str, Any]) -> Store: - return self.store_cls(**store_kwargs) + async def store(self, store_kwargs: dict[str, Any]) -> Store: + return await self.store_cls.open(**store_kwargs) def test_store_type(self, store: S) -> None: assert isinstance(store, Store) assert isinstance(store, self.store_cls) def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None: - assert store.mode == "w", store.mode - assert store.writeable + assert store.mode == AccessMode.from_literal("r+") + assert not store.mode.readonly with pytest.raises(AttributeError): - store.mode = "w" # type: ignore[misc] - - # read-only - kwargs = {**store_kwargs, "mode": "r"} - read_store = self.store_cls(**kwargs) - assert read_store.mode == "r", read_store.mode - assert not read_store.writeable + store.mode = AccessMode.from_literal("w") # type: ignore[misc] async def test_not_writable_store_raises(self, store_kwargs: dict[str, Any]) -> None: kwargs = {**store_kwargs, "mode": "r"} - store = self.store_cls(**kwargs) - assert not store.writeable + store = await self.store_cls.open(**kwargs) + assert store.mode == AccessMode.from_literal("r") + assert store.mode.readonly # set with pytest.raises(ValueError): @@ -102,7 +97,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: """ Ensure that data can be written to the store using the store.set method. """ - assert store.writeable + assert not store.mode.readonly data_buf = Buffer.from_bytes(data) await store.set(key, data_buf) observed = self.get(store, key) @@ -157,6 +152,16 @@ async def test_delete(self, store: S) -> None: await store.delete("foo/zarr.json") assert not await store.exists("foo/zarr.json") + async def test_empty(self, store: S) -> None: + assert await store.empty() + self.set(store, "key", Buffer.from_bytes(bytes("something", encoding="utf-8"))) + assert not await store.empty() + + async def test_clear(self, store: S) -> None: + self.set(store, "key", Buffer.from_bytes(bytes("something", encoding="utf-8"))) + await store.clear() + assert await store.empty() + async def test_list(self, store: S) -> None: assert [k async for k in store.list()] == [] await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 8b75d9f2f8..e8080e6034 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -22,15 +22,15 @@ from zarr.store.remote import RemoteStore -def parse_store( +async def parse_store( store: Literal["local", "memory", "remote"], path: str ) -> LocalStore | MemoryStore | RemoteStore: if store == "local": - return LocalStore(path, mode="w") + return await LocalStore.open(path, mode="w") if store == "memory": - return MemoryStore(mode="w") + return await MemoryStore.open(mode="w") if store == "remote": - return RemoteStore(url=path, mode="w") + return await RemoteStore.open(url=path, mode="w") raise AssertionError @@ -41,31 +41,31 @@ def path_type(request: pytest.FixtureRequest) -> Any: # todo: harmonize this with local_store fixture @pytest.fixture -def store_path(tmpdir: LEGACY_PATH) -> StorePath: - store = LocalStore(str(tmpdir), mode="w") +async def store_path(tmpdir: LEGACY_PATH) -> StorePath: + store = await LocalStore.open(str(tmpdir), mode="w") p = StorePath(store) return p @pytest.fixture(scope="function") -def local_store(tmpdir: LEGACY_PATH) -> LocalStore: - return LocalStore(str(tmpdir), mode="w") +async def local_store(tmpdir: LEGACY_PATH) -> LocalStore: + return await LocalStore.open(str(tmpdir), mode="w") @pytest.fixture(scope="function") -def remote_store(url: str) -> RemoteStore: - return RemoteStore(url, mode="w") +async def remote_store(url: str) -> RemoteStore: + return await RemoteStore.open(url, mode="w") @pytest.fixture(scope="function") -def memory_store() -> MemoryStore: - return MemoryStore(mode="w") +async def memory_store() -> MemoryStore: + return await MemoryStore.open(mode="w") @pytest.fixture(scope="function") -def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> Store: +async def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> Store: param = request.param - return parse_store(param, str(tmpdir)) + return await parse_store(param, str(tmpdir)) @dataclass @@ -79,7 +79,7 @@ class AsyncGroupRequest: async def async_group(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> AsyncGroup: param: AsyncGroupRequest = request.param - store = parse_store(param.store, str(tmpdir)) + store = await parse_store(param.store, str(tmpdir)) agroup = await AsyncGroup.create( store, attributes=param.attributes, diff --git a/tests/v3/test_api.py b/tests/v3/test_api.py index 67e2904a83..002678be88 100644 --- a/tests/v3/test_api.py +++ b/tests/v3/test_api.py @@ -1,6 +1,7 @@ import numpy as np import pytest from numpy.testing import assert_array_equal +from pytest_asyncio import fixture import zarr from zarr import Array, Group @@ -29,7 +30,7 @@ def test_create_array(memory_store: Store) -> None: assert z.chunks == (40,) -def test_open_array(memory_store: Store) -> None: +async def test_open_array(memory_store: Store) -> None: store = memory_store # open array, create if doesn't exist @@ -44,18 +45,19 @@ def test_open_array(memory_store: Store) -> None: assert z.shape == (200,) # open array, read-only - ro_store = type(store)(store_dict=store._store_dict, mode="r") + store_cls = type(store) + ro_store = await store_cls.open(store_dict=store._store_dict, mode="r") z = open(store=ro_store) assert isinstance(z, Array) assert z.shape == (200,) assert z.read_only # path not found - with pytest.raises(ValueError): + with pytest.raises(FileNotFoundError): open(store="doesnotexist", mode="r") -def test_open_group(memory_store: Store) -> None: +async def test_open_group(memory_store: Store) -> None: store = memory_store # open group, create if doesn't exist @@ -70,7 +72,8 @@ def test_open_group(memory_store: Store) -> None: # assert "foo" not in g # open group, read-only - ro_store = type(store)(store_dict=store._store_dict, mode="r") + store_cls = type(store) + ro_store = await store_cls.open(store_dict=store._store_dict, mode="r") g = open_group(store=ro_store) assert isinstance(g, Group) # assert g.read_only @@ -88,6 +91,55 @@ def test_save_errors() -> None: save("data/group.zarr") +@fixture +def tmppath(tmpdir): + return str(tmpdir / "example.zarr") + + +def test_open_with_mode_r(tmppath) -> None: + # 'r' means read only (must exist) + with pytest.raises(FileNotFoundError): + zarr.open(store=tmppath, mode="r") + zarr.ones(store=tmppath, shape=(3, 3)) + z2 = zarr.open(store=tmppath, mode="r") + assert (z2[:] == 1).all() + with pytest.raises(ValueError): + z2[:] = 3 + + +def test_open_with_mode_r_plus(tmppath) -> None: + # 'r+' means read/write (must exist) + with pytest.raises(FileNotFoundError): + zarr.open(store=tmppath, mode="r+") + zarr.ones(store=tmppath, shape=(3, 3)) + z2 = zarr.open(store=tmppath, mode="r+") + assert (z2[:] == 1).all() + z2[:] = 3 + + +def test_open_with_mode_a(tmppath) -> None: + # 'a' means read/write (create if doesn't exist) + zarr.open(store=tmppath, mode="a", shape=(3, 3))[...] = 1 + z2 = zarr.open(store=tmppath, mode="a") + assert (z2[:] == 1).all() + z2[:] = 3 + + +def test_open_with_mode_w(tmppath) -> None: + # 'w' means create (overwrite if exists); + zarr.open(store=tmppath, mode="w", shape=(3, 3))[...] = 3 + z2 = zarr.open(store=tmppath, mode="w", shape=(3, 3)) + assert not (z2[:] == 3).all() + z2[:] = 3 + + +def test_open_with_mode_w_minus(tmppath) -> None: + # 'w-' means create (fail if exists) + zarr.open(store=tmppath, mode="w-", shape=(3, 3))[...] = 1 + with pytest.raises(FileExistsError): + zarr.open(store=tmppath, mode="w-") + + # def test_lazy_loader(): # foo = np.arange(100) # bar = np.arange(100, 0, -1) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index f942eb6033..daa5979b27 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -19,10 +19,10 @@ @pytest.fixture(params=["local", "memory"]) -def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> LocalStore | MemoryStore: - result = parse_store(request.param, str(tmpdir)) +async def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> LocalStore | MemoryStore: + result = await parse_store(request.param, str(tmpdir)) if not isinstance(result, MemoryStore | LocalStore): - raise TypeError("Wrong store class returned by test fixture!") + raise TypeError("Wrong store class returned by test fixture! got " + result + " instead") return result @@ -430,7 +430,7 @@ async def test_asyncgroup_create( ) assert agroup.metadata == GroupMetadata(zarr_format=zarr_format, attributes=attributes) - assert agroup.store_path == make_store_path(store) + assert agroup.store_path == await make_store_path(store) if not exists_ok: with pytest.raises(ContainsGroupError): diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index c84c091089..925c258b16 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -26,8 +26,8 @@ @pytest.fixture -def store() -> Iterator[Store]: - yield StorePath(MemoryStore(mode="w")) +async def store() -> Iterator[Store]: + yield StorePath(await MemoryStore.open(mode="w")) def zarr_array_from_numpy_array( @@ -47,9 +47,11 @@ def zarr_array_from_numpy_array( class CountingDict(MemoryStore): - def __init__(self): - super().__init__(mode="w") - self.counter = Counter() + @classmethod + async def open(cls): + store = await super().open(mode="w") + store.counter = Counter() + return store async def get(self, key, prototype: BufferPrototype, byte_range=None): key_suffix = "/".join(key.split("/")[1:]) @@ -1679,7 +1681,7 @@ def test_numpy_int_indexing(store: StorePath): ), ], ) -def test_accessed_chunks(shape, chunks, ops): +async def test_accessed_chunks(shape, chunks, ops): # Test that only the required chunks are accessed during basic selection operations # shape: array shape # chunks: chunk size @@ -1688,7 +1690,7 @@ def test_accessed_chunks(shape, chunks, ops): import itertools # Use a counting dict as the backing store so we can track the items access - store = CountingDict() + store = await CountingDict.open() z = zarr_array_from_numpy_array(StorePath(store), np.zeros(shape), chunk_shape=chunks) for ii, (optype, slices) in enumerate(ops): diff --git a/tests/v3/test_store/test_core.py b/tests/v3/test_store/test_core.py index b573b0fef5..1d277cf502 100644 --- a/tests/v3/test_store/test_core.py +++ b/tests/v3/test_store/test_core.py @@ -7,30 +7,30 @@ from zarr.store.memory import MemoryStore -def test_make_store_path(tmpdir) -> None: +async def test_make_store_path(tmpdir) -> None: # None - store_path = make_store_path(None) + store_path = await make_store_path(None) assert isinstance(store_path.store, MemoryStore) # str - store_path = make_store_path(str(tmpdir)) + store_path = await make_store_path(str(tmpdir)) assert isinstance(store_path.store, LocalStore) assert Path(store_path.store.root) == Path(tmpdir) # Path - store_path = make_store_path(Path(tmpdir)) + store_path = await make_store_path(Path(tmpdir)) assert isinstance(store_path.store, LocalStore) assert Path(store_path.store.root) == Path(tmpdir) # Store - store_path = make_store_path(store_path.store) + store_path = await make_store_path(store_path.store) assert isinstance(store_path.store, LocalStore) assert Path(store_path.store.root) == Path(tmpdir) # StorePath - store_path = make_store_path(store_path) + store_path = await make_store_path(store_path) assert isinstance(store_path.store, LocalStore) assert Path(store_path.store.root) == Path(tmpdir) with pytest.raises(TypeError): - make_store_path(1) + await make_store_path(1) diff --git a/tests/v3/test_store/test_local.py b/tests/v3/test_store/test_local.py index 191a137d46..6b7f91b87d 100644 --- a/tests/v3/test_store/test_local.py +++ b/tests/v3/test_store/test_local.py @@ -21,7 +21,7 @@ def set(self, store: LocalStore, key: str, value: Buffer) -> None: @pytest.fixture def store_kwargs(self, tmpdir) -> dict[str, str]: - return {"root": str(tmpdir), "mode": "w"} + return {"root": str(tmpdir), "mode": "r+"} def test_store_repr(self, store: LocalStore) -> None: assert str(store) == f"file://{store.root!s}" diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index dd3cad7d7e..5b8f1ef875 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -20,7 +20,7 @@ def get(self, store: MemoryStore, key: str) -> Buffer: def store_kwargs( self, request: pytest.FixtureRequest ) -> dict[str, str | None | dict[str, Buffer]]: - return {"store_dict": request.param, "mode": "w"} + return {"store_dict": request.param, "mode": "r+"} @pytest.fixture(scope="function") def store(self, store_kwargs: str | None | dict[str, Buffer]) -> MemoryStore: diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index 0dc399be42..3a43ab43e9 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -82,7 +82,9 @@ async def alist(it): async def test_basic(): - store = RemoteStore(f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False) + store = await RemoteStore.open( + f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False + ) assert not await alist(store.list()) assert not await store.exists("foo") data = b"hello" @@ -102,7 +104,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore]): def store_kwargs(self, request) -> dict[str, str | bool]: url = f"s3://{test_bucket_name}" anon = False - mode = "w" + mode = "r+" if request.param == "use_upath": return {"mode": mode, "url": UPath(url, endpoint_url=endpoint_url, anon=anon)} elif request.param == "use_str": diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index 41555bbd26..7a7d728067 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -9,8 +9,8 @@ @pytest.fixture -def store() -> Iterator[Store]: - yield StorePath(MemoryStore(mode="w")) +async def store() -> Iterator[Store]: + yield StorePath(await MemoryStore.open(mode="w")) def test_simple(store: Store):