From 8407c640b2764f3ca88b989bb7c344d88b98174c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 14:45:16 +0100 Subject: [PATCH 01/25] feat: add wrapperstore --- src/zarr/storage/__init__.py | 2 + src/zarr/storage/wrapper.py | 144 +++++++++++++++++++++++++++++++ tests/test_store/test_wrapper.py | 46 ++++++++++ 3 files changed, 192 insertions(+) create mode 100644 src/zarr/storage/wrapper.py create mode 100644 tests/test_store/test_wrapper.py diff --git a/src/zarr/storage/__init__.py b/src/zarr/storage/__init__.py index 6703aa2723..17b11f54a6 100644 --- a/src/zarr/storage/__init__.py +++ b/src/zarr/storage/__init__.py @@ -3,6 +3,7 @@ from zarr.storage.logging import LoggingStore from zarr.storage.memory import MemoryStore from zarr.storage.remote import RemoteStore +from zarr.storage.wrapper import WrapperStore from zarr.storage.zip import ZipStore __all__ = [ @@ -12,6 +13,7 @@ "RemoteStore", "StoreLike", "StorePath", + "WrapperStore", "ZipStore", "make_store_path", ] diff --git a/src/zarr/storage/wrapper.py b/src/zarr/storage/wrapper.py new file mode 100644 index 0000000000..344ba41065 --- /dev/null +++ b/src/zarr/storage/wrapper.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Iterable + from types import TracebackType + from typing import Any, Self + + from zarr.abc.store import ByteRangeRequest + from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.common import AccessModeLiteral, BytesLike + +from zarr.abc.store import AccessMode, Store + +T_Wrapped = TypeVar("T_Wrapped", bound=Store) + + +class WrapperStore(Store, Generic[T_Wrapped]): + """ + A store class that wraps an existing ``Store`` instance. + By default all of the store methods are delegated to the wrapped store instance, which is + accessible via the ``._wrapped`` attribute of this class. + + Use this class to modify or extend the behavior of the other store classes. + """ + + _wrapped: T_Wrapped + + def __init__(self, wrapped: T_Wrapped) -> None: + self._wrapped = wrapped + + @classmethod + async def open( + cls: type[Self], wrapped_class: type[T_Wrapped], *args: Any, **kwargs: Any + ) -> Self: + wrapped = wrapped_class(*args, **kwargs) + await wrapped._open() + return cls(wrapped=wrapped) + + def __enter__(self) -> Self: + return type(self)(self._wrapped.__enter__()) + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self._wrapped.__exit__(exc_type, exc_value, traceback) + + async def _open(self) -> None: + await self._wrapped._open() + + async def _ensure_open(self) -> None: + await self._wrapped._ensure_open() + + async def empty(self) -> bool: + return await self._wrapped.empty() + + async def clear(self) -> None: + return await self._wrapped.clear() + + def with_mode(self, mode: AccessModeLiteral) -> Self: + return type(self)(wrapped=self._wrapped.with_mode(mode=mode)) + + @property + def mode(self) -> AccessMode: + return self._wrapped._mode + + def _check_writable(self) -> None: + return self._wrapped._check_writable() + + def __eq__(self, value: object) -> bool: + return type(self) is type(value) and self._wrapped.__eq__(value) + + async def get( + self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + ) -> Buffer | None: + return await self._wrapped.get(key, prototype, byte_range) + + async def get_partial_values( + self, + prototype: BufferPrototype, + key_ranges: Iterable[tuple[str, ByteRangeRequest]], + ) -> list[Buffer | None]: + return await self._wrapped.get_partial_values(prototype, key_ranges) + + async def exists(self, key: str) -> bool: + return await self._wrapped.exists(key) + + async def set(self, key: str, value: Buffer) -> None: + await self._wrapped.set(key, value) + + async def set_if_not_exists(self, key: str, value: Buffer) -> None: + return await self._wrapped.set_if_not_exists(key, value) + + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + await self._wrapped._set_many(values) + + @property + def supports_writes(self) -> bool: + return self._wrapped.supports_writes + + @property + def supports_deletes(self) -> bool: + return self._wrapped.supports_deletes + + async def delete(self, key: str) -> None: + await self._wrapped.delete(key) + + @property + def supports_partial_writes(self) -> bool: + return self._wrapped.supports_partial_writes + + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, BytesLike]] + ) -> None: + return await self._wrapped.set_partial_values(key_start_values) + + @property + def supports_listing(self) -> bool: + return self._wrapped.supports_listing + + def list(self) -> AsyncGenerator[str]: + return self._wrapped.list() + + def list_prefix(self, prefix: str) -> AsyncGenerator[str]: + return self._wrapped.list_prefix(prefix) + + def list_dir(self, prefix: str) -> AsyncGenerator[str]: + return self._wrapped.list_dir(prefix) + + async def delete_dir(self, prefix: str) -> None: + return await self._wrapped.delete_dir(prefix) + + def close(self) -> None: + self._wrapped.close() + + async def _get_many( + self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]] + ) -> AsyncGenerator[tuple[str, Buffer | None], None]: + async for req in self._wrapped._get_many(requests): + yield req diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py new file mode 100644 index 0000000000..3289db4c5e --- /dev/null +++ b/tests/test_store/test_wrapper.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from zarr.core.buffer.cpu import Buffer, buffer_prototype +from zarr.storage.wrapper import WrapperStore + +if TYPE_CHECKING: + from zarr.abc.store import Store + from zarr.core.buffer.core import BufferPrototype + + +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True) +async def test_wrapped_set(store: Store, capsys: pytest.CaptureFixture[str]) -> None: + # define a class that prints when it sets + class NoisySetter(WrapperStore): + async def set(self, key: str, value: Buffer) -> None: + print(f"setting {key}") + await super().set(key, value) + + key = "foo" + value = Buffer.from_bytes(b"bar") + store_wrapped = NoisySetter(store) + await store_wrapped.set(key, value) + captured = capsys.readouterr() + assert f"setting {key}" in captured.out + assert await store_wrapped.get(key, buffer_prototype) == value + + +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True) +async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> None: + # define a class that prints when it sets + class NoisySetter(WrapperStore): + def get(self, key: str, prototype: BufferPrototype) -> None: + print(f"getting {key}") + return super().get(key, prototype=prototype) + + key = "foo" + value = Buffer.from_bytes(b"bar") + store_wrapped = NoisySetter(store) + await store_wrapped.set(key, value) + assert await store_wrapped.get(key, buffer_prototype) == value + captured = capsys.readouterr() + assert f"getting {key}" in captured.out From 5e9ffb8152a607773cea0b509b5b71a6f5237696 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 14:45:53 +0100 Subject: [PATCH 02/25] feat: add latencystore --- src/zarr/testing/store.py | 56 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 3aece0f4a9..40d4aa4095 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,9 +1,20 @@ +from __future__ import annotations + +import asyncio import pickle -from typing import Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Generic, TypeVar, cast + +from zarr.storage.wrapper import WrapperStore + +if TYPE_CHECKING: + from typing import Any + + from zarr.abc.store import ByteRangeRequest + from zarr.core.buffer.core import BufferPrototype import pytest -from zarr.abc.store import AccessMode, Store +from zarr.abc.store import AccessMode, ByteRangeRequest, Store from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.common import AccessModeLiteral from zarr.core.sync import _collect_aiterator @@ -352,3 +363,44 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new + + +class LatencyStore(WrapperStore[Store]): + """ + A wrapper class that takes any store class in its constructor and + adds latency to the `set` and `get` methods. This can be used for + performance testing. + """ + + get_latency: float + set_latency: float + + def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None: + self.get_latency = float(get_latency) + self.set_latency = float(set_latency) + self._wrapped = cls + + async def set(self, key: str, value: Buffer) -> None: + await asyncio.sleep(self.set_latency) + await self._wrapped.set(key, value) + + async def get( + self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + ) -> Buffer | None: + """ + Add latency to the get method. + + Adds a sleep of `self.get_latency` seconds before calling the wrapped method. + + Parameters + ---------- + key : str + prototype : BufferPrototype + byte_range : ByteRangeRequest, optional + + Returns + ------- + buffer : Buffer or None + """ + await asyncio.sleep(self.get_latency) + return await self._wrapped.get(key, prototype=prototype, byte_range=byte_range) From 5d7abf47e9d1acaa53c4f339a4ac65a9cce24ad3 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 15:02:13 +0100 Subject: [PATCH 03/25] rename noisysetter -> noisygetter --- tests/test_store/test_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py index 3289db4c5e..1caf9c9ae4 100644 --- a/tests/test_store/test_wrapper.py +++ b/tests/test_store/test_wrapper.py @@ -32,14 +32,14 @@ async def set(self, key: str, value: Buffer) -> None: @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True) async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> None: # define a class that prints when it sets - class NoisySetter(WrapperStore): + class NoisyGetter(WrapperStore): def get(self, key: str, prototype: BufferPrototype) -> None: print(f"getting {key}") return super().get(key, prototype=prototype) key = "foo" value = Buffer.from_bytes(b"bar") - store_wrapped = NoisySetter(store) + store_wrapped = NoisyGetter(store) await store_wrapped.set(key, value) assert await store_wrapped.get(key, buffer_prototype) == value captured = capsys.readouterr() From c4863512fe0a9a5a9c3a7842d1c07309951649e5 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 15:28:14 +0100 Subject: [PATCH 04/25] rename _wrapped to _store --- src/zarr/storage/wrapper.py | 76 ++++++++++++++++++------------------- src/zarr/testing/store.py | 6 +-- 2 files changed, 40 insertions(+), 42 deletions(-) diff --git a/src/zarr/storage/wrapper.py b/src/zarr/storage/wrapper.py index 344ba41065..a571568e03 100644 --- a/src/zarr/storage/wrapper.py +++ b/src/zarr/storage/wrapper.py @@ -13,10 +13,10 @@ from zarr.abc.store import AccessMode, Store -T_Wrapped = TypeVar("T_Wrapped", bound=Store) +T_Store = TypeVar("T_Store", bound=Store) -class WrapperStore(Store, Generic[T_Wrapped]): +class WrapperStore(Store, Generic[T_Store]): """ A store class that wraps an existing ``Store`` instance. By default all of the store methods are delegated to the wrapped store instance, which is @@ -25,21 +25,19 @@ class WrapperStore(Store, Generic[T_Wrapped]): Use this class to modify or extend the behavior of the other store classes. """ - _wrapped: T_Wrapped + _store: T_Store - def __init__(self, wrapped: T_Wrapped) -> None: - self._wrapped = wrapped + def __init__(self, store: T_Store) -> None: + self._store = store @classmethod - async def open( - cls: type[Self], wrapped_class: type[T_Wrapped], *args: Any, **kwargs: Any - ) -> Self: - wrapped = wrapped_class(*args, **kwargs) - await wrapped._open() - return cls(wrapped=wrapped) + async def open(cls: type[Self], store_cls: type[T_Store], *args: Any, **kwargs: Any) -> Self: + store = store_cls(*args, **kwargs) + await store._open() + return cls(store=store) def __enter__(self) -> Self: - return type(self)(self._wrapped.__enter__()) + return type(self)(self._store.__enter__()) def __exit__( self, @@ -47,98 +45,98 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - return self._wrapped.__exit__(exc_type, exc_value, traceback) + return self._store.__exit__(exc_type, exc_value, traceback) async def _open(self) -> None: - await self._wrapped._open() + await self._store._open() async def _ensure_open(self) -> None: - await self._wrapped._ensure_open() + await self._store._ensure_open() async def empty(self) -> bool: - return await self._wrapped.empty() + return await self._store.empty() async def clear(self) -> None: - return await self._wrapped.clear() + return await self._store.clear() def with_mode(self, mode: AccessModeLiteral) -> Self: - return type(self)(wrapped=self._wrapped.with_mode(mode=mode)) + return type(self)(store=self._store.with_mode(mode=mode)) @property def mode(self) -> AccessMode: - return self._wrapped._mode + return self._store._mode def _check_writable(self) -> None: - return self._wrapped._check_writable() + return self._store._check_writable() def __eq__(self, value: object) -> bool: - return type(self) is type(value) and self._wrapped.__eq__(value) + return type(self) is type(value) and self._store.__eq__(value) async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: - return await self._wrapped.get(key, prototype, byte_range) + return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, prototype: BufferPrototype, key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: - return await self._wrapped.get_partial_values(prototype, key_ranges) + return await self._store.get_partial_values(prototype, key_ranges) async def exists(self, key: str) -> bool: - return await self._wrapped.exists(key) + return await self._store.exists(key) async def set(self, key: str, value: Buffer) -> None: - await self._wrapped.set(key, value) + await self._store.set(key, value) async def set_if_not_exists(self, key: str, value: Buffer) -> None: - return await self._wrapped.set_if_not_exists(key, value) + return await self._store.set_if_not_exists(key, value) async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: - await self._wrapped._set_many(values) + await self._store._set_many(values) @property def supports_writes(self) -> bool: - return self._wrapped.supports_writes + return self._store.supports_writes @property def supports_deletes(self) -> bool: - return self._wrapped.supports_deletes + return self._store.supports_deletes async def delete(self, key: str) -> None: - await self._wrapped.delete(key) + await self._store.delete(key) @property def supports_partial_writes(self) -> bool: - return self._wrapped.supports_partial_writes + return self._store.supports_partial_writes async def set_partial_values( self, key_start_values: Iterable[tuple[str, int, BytesLike]] ) -> None: - return await self._wrapped.set_partial_values(key_start_values) + return await self._store.set_partial_values(key_start_values) @property def supports_listing(self) -> bool: - return self._wrapped.supports_listing + return self._store.supports_listing def list(self) -> AsyncGenerator[str]: - return self._wrapped.list() + return self._store.list() def list_prefix(self, prefix: str) -> AsyncGenerator[str]: - return self._wrapped.list_prefix(prefix) + return self._store.list_prefix(prefix) def list_dir(self, prefix: str) -> AsyncGenerator[str]: - return self._wrapped.list_dir(prefix) + return self._store.list_dir(prefix) async def delete_dir(self, prefix: str) -> None: - return await self._wrapped.delete_dir(prefix) + return await self._store.delete_dir(prefix) def close(self) -> None: - self._wrapped.close() + self._store.close() async def _get_many( self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: - async for req in self._wrapped._get_many(requests): + async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 40d4aa4095..aedb537d75 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -378,11 +378,11 @@ class LatencyStore(WrapperStore[Store]): def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None: self.get_latency = float(get_latency) self.set_latency = float(set_latency) - self._wrapped = cls + self._store = cls async def set(self, key: str, value: Buffer) -> None: await asyncio.sleep(self.set_latency) - await self._wrapped.set(key, value) + await self._store.set(key, value) async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None @@ -403,4 +403,4 @@ async def get( buffer : Buffer or None """ await asyncio.sleep(self.get_latency) - return await self._wrapped.get(key, prototype=prototype, byte_range=byte_range) + return await self._store.get(key, prototype=prototype, byte_range=byte_range) From f97b27c7a01efc1dd298510817f17a8b96d62092 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 15:28:47 +0100 Subject: [PATCH 05/25] loggingstore inherits from wrapperstore --- src/zarr/storage/logging.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/zarr/storage/logging.py b/src/zarr/storage/logging.py index d3e55c0687..925c835bba 100644 --- a/src/zarr/storage/logging.py +++ b/src/zarr/storage/logging.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Self from zarr.abc.store import AccessMode, ByteRangeRequest, Store +from zarr.storage.wrapper import WrapperStore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable @@ -15,8 +16,10 @@ from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.common import AccessModeLiteral + counter: defaultdict[str, int] + -class LoggingStore(Store): +class LoggingStore(WrapperStore[Store]): """ Store wrapper that logs all calls to the wrapped store. @@ -35,7 +38,6 @@ class LoggingStore(Store): Counter of number of times each method has been called """ - _store: Store counter: defaultdict[str, int] def __init__( @@ -44,11 +46,10 @@ def __init__( log_level: str = "DEBUG", log_handler: logging.Handler | None = None, ) -> None: - self._store = store + super().__init__(store) self.counter = defaultdict(int) self.log_level = log_level self.log_handler = log_handler - self._configure_logger(log_level, log_handler) def _configure_logger( From ffca71030d7e92cef550e48a8d74d3907d340eb9 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 20:54:02 +0100 Subject: [PATCH 06/25] initial commit --- scratch.py | 28 ++++++++++++++++++++++++++++ src/zarr/core/group.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 scratch.py diff --git a/scratch.py b/scratch.py new file mode 100644 index 0000000000..a2e0da854c --- /dev/null +++ b/scratch.py @@ -0,0 +1,28 @@ +import asyncio +from time import time + +import pytest + +import zarr +from zarr.storage import MemoryStore +from zarr.testing.store import LatencyStore + + +@pytest.mark.parametrize("num_members", [4, 8, 16]) +def test_collect_members(num_members: int) -> None: + local_store = MemoryStore(mode="a") + local_latency_store = LatencyStore(local_store, get_latency=0.1, set_latency=0.0) + + root_group_raw = zarr.open_group(store=local_store) + root_group_latency = zarr.open_group(store=local_latency_store) + for i in range(num_members): + root_group_raw.create_group(f"group_{i}") + + async def get_members_async() -> None: + members2 = await root_group_latency._async_group._members2(max_depth=0, current_depth=0) + print(members2) + + start = time() + asyncio.run(get_members_async()) + elapsed = time() - start + print(elapsed) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9a54b346b0..aa8f4e9fa5 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1846,3 +1846,35 @@ def array( ) ) ) + + +async def members_v3( + node: AsyncGroup, max_depth: int | None, current_depth: int, prototype: BufferPrototype +) -> Any: + node_path = node.store_path.path + metadata_keys = ("zarr.json",) + + members_flat: dict[ + str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup + ] = {"": node} + + group_queue: asyncio.Queue[AsyncGroup] = asyncio.Queue() + key_queue: asyncio.Queue[str] = asyncio.Queue() + + keys = [key async for key in node.store_path.store.list_dir(node_path)] + keys_filtered = tuple(filter(lambda v: v not in metadata_keys, keys)) + doc_keys = [] + + for key in keys_filtered: + for metadata_key in metadata_keys: + doc_keys.append("/".join([key, metadata_key])) + + # optimistically fetch extant metadata documents + blobs = asyncio.gather( + *(node.store.get(key, prototype=prototype) for key in doc_keys), + ) + # insert resolved metadata_documents into members_flat + + # repeat for groups + + return objects From d33cb7d5d2f9d16ffbd7d47821df67f6cfc24993 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 8 Nov 2024 23:21:36 +0100 Subject: [PATCH 07/25] working members traversal --- scratch.py | 31 +++++++++++++++++-------- src/zarr/core/group.py | 52 ++++++++++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/scratch.py b/scratch.py index a2e0da854c..e0a3053c86 100644 --- a/scratch.py +++ b/scratch.py @@ -1,28 +1,41 @@ import asyncio from time import time +from typing import Literal import pytest import zarr +from zarr.core.group import members_v3 from zarr.storage import MemoryStore from zarr.testing.store import LatencyStore @pytest.mark.parametrize("num_members", [4, 8, 16]) -def test_collect_members(num_members: int) -> None: +@pytest.mark.parametrize("method", ["default", "fast_members"]) +def test_collect_members(num_members: int, method: Literal["fast_members", "default"]) -> None: local_store = MemoryStore(mode="a") local_latency_store = LatencyStore(local_store, get_latency=0.1, set_latency=0.0) root_group_raw = zarr.open_group(store=local_store) root_group_latency = zarr.open_group(store=local_latency_store) + for i in range(num_members): - root_group_raw.create_group(f"group_{i}") + subgroup = root_group_raw.create_group(f"group_outer_{i}") + for j in range(num_members): + subgroup.create_group(f"group_inner_{j}") + + if method == "fast_members": - async def get_members_async() -> None: - members2 = await root_group_latency._async_group._members2(max_depth=0, current_depth=0) - print(members2) + async def amain() -> None: + res = await members_v3(local_latency_store, path="") + print(res) - start = time() - asyncio.run(get_members_async()) - elapsed = time() - start - print(elapsed) + start = time() + asyncio.run(amain()) + elapsed = time() - start + print(elapsed) + else: + start = time() + root_group_latency.members(max_depth=None) + elapsed = time() - start + print(elapsed) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index aa8f4e9fa5..c2d448ed36 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1849,32 +1849,56 @@ def array( async def members_v3( - node: AsyncGroup, max_depth: int | None, current_depth: int, prototype: BufferPrototype + store: Store, + path: str, ) -> Any: - node_path = node.store_path.path metadata_keys = ("zarr.json",) - members_flat: dict[ - str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup - ] = {"": node} + members_flat: tuple[tuple[str, ArrayV3Metadata | GroupMetadata], ...] = () - group_queue: asyncio.Queue[AsyncGroup] = asyncio.Queue() - key_queue: asyncio.Queue[str] = asyncio.Queue() - - keys = [key async for key in node.store_path.store.list_dir(node_path)] + keys = [key async for key in store.list_dir(path)] keys_filtered = tuple(filter(lambda v: v not in metadata_keys, keys)) doc_keys = [] for key in keys_filtered: for metadata_key in metadata_keys: - doc_keys.append("/".join([key, metadata_key])) + doc_keys.append("/".join([path, key, metadata_key]).lstrip("/")) # optimistically fetch extant metadata documents - blobs = asyncio.gather( - *(node.store.get(key, prototype=prototype) for key in doc_keys), + blobs = await asyncio.gather( + *(store.get(key, prototype=default_buffer_prototype()) for key in doc_keys) ) + + to_recurse = [] + # insert resolved metadata_documents into members_flat + for key, blob in zip(doc_keys, blobs, strict=False): + key_body = "/".join(key.split("/")[:-1]) + + if blob is not None: + resolved_metadata = resolve_metadata_v3(blob.to_bytes()) + members_flat += ((key_body, resolved_metadata),) + if isinstance(resolved_metadata, GroupMetadata): + to_recurse.append(members_v3(store, key_body)) + + # for r in to_recurse: + # members_flat += await r - # repeat for groups + subgroups = await asyncio.gather(*to_recurse) + members_flat += tuple(subgroup for subgroup in subgroups) - return objects + # recurse for groups + + return members_flat + + +def resolve_metadata_v3(blob: str | bytes | bytearray) -> ArrayV3Metadata | GroupMetadata: + zarr_json = json.loads(blob) + if "node_type" not in zarr_json: + raise ValueError("missing node_type in metadata document") + if zarr_json["node_type"] == "array": + return ArrayV3Metadata.from_dict(zarr_json) + elif zarr_json["node_type"] == "group": + return GroupMetadata.from_dict(zarr_json) + else: + raise ValueError("invalid node_type in metadata document") From 8f8797733c15d25ec6afc957b0ff729ccbb83464 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 18 Nov 2024 11:22:38 +0100 Subject: [PATCH 08/25] bolt concurrent members implementation onto async group --- src/zarr/core/array.py | 24 +++---- src/zarr/core/group.py | 139 ++++++++++++++++++++++++++++++++++++----- tests/test_group.py | 62 ++++++++++++++++++ 3 files changed, 195 insertions(+), 30 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 1646959cb5..e559de9dc0 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -786,7 +786,7 @@ def path(self) -> str: return self.store_path.path @property - def name(self) -> str | None: + def name(self) -> str: """Array name following h5py convention. Returns @@ -794,16 +794,14 @@ def name(self) -> str | None: str The name of the array. """ - if self.path: - # follow h5py convention: add leading slash - name = self.path - if name[0] != "/": - name = "/" + name - return name - return None + # follow h5py convention: add leading slash + name = self.path + if not name.startswith('/'): + name = "/" + name + return name @property - def basename(self) -> str | None: + def basename(self) -> str: """Final component of name. Returns @@ -811,9 +809,7 @@ def basename(self) -> str | None: str The basename or final component of the array name. """ - if self.name is not None: - return self.name.split("/")[-1] - return None + return self.name.split("/")[-1] @property def cdata_shape(self) -> ChunkCoords: @@ -1436,12 +1432,12 @@ def path(self) -> str: return self._async_array.path @property - def name(self) -> str | None: + def name(self) -> str: """Array name following h5py convention.""" return self._async_array.name @property - def basename(self) -> str | None: + def basename(self) -> str: """Final component of name.""" return self._async_array.basename diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index c2d448ed36..ed9f2e7f0d 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -19,7 +19,7 @@ from zarr.abc.store import Store, set_or_delete from zarr.core.array import Array, AsyncArray, _build_parents from zarr.core.attributes import Attributes -from zarr.core.buffer import default_buffer_prototype +from zarr.core.buffer import default_buffer_prototype, Buffer from zarr.core.common import ( JSON, ZARR_JSON, @@ -1151,10 +1151,10 @@ async def members( """ if max_depth is not None and max_depth < 0: raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead") - async for item in self._members(max_depth=max_depth, current_depth=0): + async for item in self._members(max_depth=max_depth): yield item - async def _members( + async def _members_old( self, max_depth: int | None, current_depth: int ) -> AsyncGenerator[ tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], @@ -1162,7 +1162,7 @@ async def _members( ]: if self.metadata.consolidated_metadata is not None: # we should be able to do members without any additional I/O - members = self._members_consolidated(max_depth, current_depth) + members = self._members_consolidated(max_depth) for member in members: yield member return @@ -1202,8 +1202,7 @@ async def _members( # implies an AsyncGroup, not an AsyncArray assert isinstance(obj, AsyncGroup) async for child_key, val in obj._members( - max_depth=max_depth, current_depth=current_depth + 1 - ): + max_depth=max_depth): yield f"{key}/{child_key}", val except KeyError: # keyerror is raised when `key` names an object (in the object storage sense), @@ -1216,12 +1215,14 @@ async def _members( ) def _members_consolidated( - self, max_depth: int | None, current_depth: int, prefix: str = "" + self, max_depth: int | None, prefix: str = "" ) -> Generator[ tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None, ]: consolidated_metadata = self.metadata.consolidated_metadata + + do_recursion = max_depth is None or max_depth > 0 # we kind of just want the top-level keys. if consolidated_metadata is not None: @@ -1232,10 +1233,43 @@ def _members_consolidated( key = f"{prefix}/{key}".lstrip("/") yield key, obj - if ((max_depth is None) or (current_depth < max_depth)) and isinstance( + if do_recursion and isinstance( obj, AsyncGroup ): - yield from obj._members_consolidated(max_depth, current_depth + 1, prefix=key) + if max_depth is None: + new_depth = None + else: + new_depth = max_depth - 1 + yield from obj._members_consolidated(new_depth, prefix=key) + + async def _members( + self, + max_depth: int | None) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]: + skip_keys: tuple[str, ...] + if self.metadata.zarr_format == 2: + skip_keys = ('.zattrs', '.zgroup','.zarray', '.zmetadata') + elif self.metadata.zarr_format == 3: + skip_keys = ('zarr.json',) + else: + raise ValueError(f"Unknown Zarr format: {self.metadata.zarr_format}") + + if self.metadata.consolidated_metadata is not None: + members = self._members_consolidated(max_depth=max_depth) + for member in members: + yield member + return + + if not self.store_path.store.supports_listing: + msg = ( + f"The store associated with this group ({type(self.store_path.store)}) " + "does not support listing, " + "specifically via the `list_dir` method. " + "This function requires a store that supports listing." + ) + + raise ValueError(msg) + async for member in iter_members_deep(self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys): + yield member async def keys(self) -> AsyncGenerator[str, None]: async for key, _ in self.members(): @@ -1848,10 +1882,13 @@ def array( ) -async def members_v3( +async def members_recursive( store: Store, path: str, ) -> Any: + """ + Recursively fetch all members of a group. + """ metadata_keys = ("zarr.json",) members_flat: tuple[tuple[str, ArrayV3Metadata | GroupMetadata], ...] = () @@ -1879,18 +1916,88 @@ async def members_v3( resolved_metadata = resolve_metadata_v3(blob.to_bytes()) members_flat += ((key_body, resolved_metadata),) if isinstance(resolved_metadata, GroupMetadata): - to_recurse.append(members_v3(store, key_body)) - - # for r in to_recurse: - # members_flat += await r + to_recurse.append( + members_recursive(store, key_body)) subgroups = await asyncio.gather(*to_recurse) members_flat += tuple(subgroup for subgroup in subgroups) - # recurse for groups - return members_flat +async def iter_members( + node: AsyncGroup, + skip_keys: tuple[str, ...] +) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]: + """ + Iterate over the arrays and groups contained in a group. + """ + + # retrieve keys from storage + keys = [key async for key in node.store.list_dir(node.path)] + keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) + + node_tasks = tuple(asyncio.create_task( + node.getitem(key), name=key) for key in keys_filtered) + + for fetched_node_coro in asyncio.as_completed(node_tasks): + try: + fetched_node = await fetched_node_coro + except KeyError as e: + # keyerror is raised when `key` names an object (in the object storage sense), + # as opposed to a prefix, in the store under the prefix associated with this group + # in which case `key` cannot be the name of a sub-array or sub-group. + warnings.warn( + f"Object at {e.args[0]} is not recognized as a component of a Zarr hierarchy.", + UserWarning, + stacklevel=1, + ) + continue + match fetched_node: + case AsyncArray() | AsyncGroup(): + yield fetched_node.basename, fetched_node + case _: + raise ValueError(f"Unexpected type: {type(fetched_node)}") + +async def iter_members_deep( + group: AsyncGroup, + *, + prefix: str, + max_depth: int | None, + skip_keys: tuple[str, ...] +) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]: + """ + Iterate over the arrays and groups contained in a group, and optionally the + arrays and groups contained in those groups. + """ + + to_recurse = [] + do_recursion = max_depth is None or max_depth > 0 + if max_depth is None: + new_depth = None + else: + new_depth = max_depth - 1 + + async for name, node in iter_members(group, skip_keys=skip_keys): + yield f'{prefix}/{name}'.lstrip('/'), node + if isinstance(node, AsyncGroup) and do_recursion: + to_recurse.append(iter_members_deep( + node, + max_depth=new_depth, + prefix=f'{prefix}/{name}', + skip_keys=skip_keys)) + + for subgroup in to_recurse: + async for name, node in subgroup: + yield name, node + + +def resolve_metadata_v2(blobs: tuple[str | bytes | bytearray, str | bytes | bytearray]) -> ArrayV2Metadata | GroupMetadata: + zarr_metadata = json.loads(blobs[0]) + attrs = json.loads(blobs[1]) + if 'shape' in zarr_metadata: + return ArrayV2Metadata.from_dict(zarr_metadata | {'attrs': attrs}) + else: + return GroupMetadata.from_dict(zarr_metadata | {'attrs': attrs}) def resolve_metadata_v3(blob: str | bytes | bytearray) -> ArrayV3Metadata | GroupMetadata: zarr_json = json.loads(blob) diff --git a/tests/test_group.py b/tests/test_group.py index 6bacca4889..04fe9497c0 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -207,6 +207,68 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad with pytest.raises(ValueError, match="max_depth"): members_observed = group.members(max_depth=-1) +def test_group_members_2(store: Store, zarr_format: ZarrFormat) -> None: + """ + Test that `Group.members` returns correct values, i.e. the arrays and groups + (explicit and implicit) contained in that group. + """ + # group/ + # subgroup/ + # subsubgroup/ + # subsubsubgroup + # subarray + + path = "group" + group = Group.from_store( + store=store, + zarr_format=zarr_format, + ) + members_expected: dict[str, Array | Group] = {} + + members_expected["subgroup"] = group.create_group("subgroup") + # make a sub-sub-subgroup, to ensure that the children calculation doesn't go + # too deep in the hierarchy + subsubgroup = members_expected["subgroup"].create_group("subsubgroup") + subsubsubgroup = subsubgroup.create_group("subsubsubgroup") + + members_expected["subarray"] = group.create_array( + "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True + ) + + # add an extra object to the domain of the group. + # the list of children should ignore this object. + sync( + store.set( + f"{path}/extra_object-1", + default_buffer_prototype().buffer.from_bytes(b"000000"), + ) + ) + # add an extra object under a directory-like prefix in the domain of the group. + # this creates a directory with a random key in it + # this should not show up as a member + sync( + store.set( + f"{path}/extra_directory/extra_object-2", + default_buffer_prototype().buffer.from_bytes(b"000000"), + ) + ) + + # this warning shows up when extra objects show up in the hierarchy + warn_context = pytest.warns( + UserWarning, match=r"Object at .* is not recognized as a component of a Zarr hierarchy." + ) + + with warn_context: + members_observed = group.members() + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) + + # partial + with warn_context: + members_observed = group.members(max_depth=1) + members_expected["subgroup/subsubgroup"] = subsubgroup + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) def test_group(store: Store, zarr_format: ZarrFormat) -> None: """ From 87e0b83ad6c05373bc1e3174e0b93d614d45efc2 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 18 Nov 2024 12:57:08 +0100 Subject: [PATCH 09/25] update scratch file --- scratch.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scratch.py b/scratch.py index e0a3053c86..8edb1c3fca 100644 --- a/scratch.py +++ b/scratch.py @@ -5,14 +5,14 @@ import pytest import zarr -from zarr.core.group import members_v3 +from zarr.core.group import iter_members, members_recursive, iter_members_deep from zarr.storage import MemoryStore from zarr.testing.store import LatencyStore -@pytest.mark.parametrize("num_members", [4, 8, 16]) +@pytest.mark.parametrize("num_members", [10, 100, 1000]) @pytest.mark.parametrize("method", ["default", "fast_members"]) -def test_collect_members(num_members: int, method: Literal["fast_members", "default"]) -> None: +def test_collect_members(num_members: int, method: Literal["fast_members", "default", "fast_members_2"]) -> None: local_store = MemoryStore(mode="a") local_latency_store = LatencyStore(local_store, get_latency=0.1, set_latency=0.0) @@ -21,14 +21,10 @@ def test_collect_members(num_members: int, method: Literal["fast_members", "defa for i in range(num_members): subgroup = root_group_raw.create_group(f"group_outer_{i}") - for j in range(num_members): - subgroup.create_group(f"group_inner_{j}") if method == "fast_members": - async def amain() -> None: - res = await members_v3(local_latency_store, path="") - print(res) + res = [x async for x in iter_members(root_group_latency._async_group)] start = time() asyncio.run(amain()) From 502ad5e343ae1b4f9af1717146c57d450164facd Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 20 Nov 2024 22:30:16 +0100 Subject: [PATCH 10/25] use metadata / node builders for v3 node creation --- src/zarr/core/group.py | 198 +++++++++++++++++------------------------ 1 file changed, 84 insertions(+), 114 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index ed9f2e7f0d..47c0aa7a2a 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -19,7 +19,7 @@ from zarr.abc.store import Store, set_or_delete from zarr.core.array import Array, AsyncArray, _build_parents from zarr.core.attributes import Attributes -from zarr.core.buffer import default_buffer_prototype, Buffer +from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.common import ( JSON, ZARR_JSON, @@ -645,12 +645,10 @@ async def getitem( raise KeyError(key) else: zarr_json = json.loads(zarr_json_bytes.to_bytes()) - if zarr_json["node_type"] == "group": - return type(self).from_dict(store_path, zarr_json) - elif zarr_json["node_type"] == "array": - return AsyncArray.from_dict(store_path, zarr_json) - else: - raise ValueError(f"unexpected node_type: {zarr_json['node_type']}") + metadata = build_metadata_v3(zarr_json) + node = build_node_v3(metadata, store_path) + return node + elif self.metadata.zarr_format == 2: # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs? # This guarantees that we will always make at least one extra request to the store @@ -1154,66 +1152,6 @@ async def members( async for item in self._members(max_depth=max_depth): yield item - async def _members_old( - self, max_depth: int | None, current_depth: int - ) -> AsyncGenerator[ - tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], - None, - ]: - if self.metadata.consolidated_metadata is not None: - # we should be able to do members without any additional I/O - members = self._members_consolidated(max_depth) - for member in members: - yield member - return - - if not self.store_path.store.supports_listing: - msg = ( - f"The store associated with this group ({type(self.store_path.store)}) " - "does not support listing, " - "specifically via the `list_dir` method. " - "This function requires a store that supports listing." - ) - - raise ValueError(msg) - # would be nice to make these special keys accessible programmatically, - # and scoped to specific zarr versions - # especially true for `.zmetadata` which is configurable - _skip_keys = ("zarr.json", ".zgroup", ".zattrs", ".zmetadata") - - # hmm lots of I/O and logic interleaved here. - # We *could* have an async gen over self.metadata.consolidated_metadata.metadata.keys() - # and plug in here. `getitem` will skip I/O. - # Kinda a shame to have all the asyncio task overhead though, when it isn't needed. - - async for key in self.store_path.store.list_dir(self.store_path.path): - if key in _skip_keys: - continue - try: - obj = await self.getitem(key) - yield (key, obj) - - if ( - ((max_depth is None) or (current_depth < max_depth)) - and hasattr(obj.metadata, "node_type") - and obj.metadata.node_type == "group" - ): - # the assert is just for mypy to know that `obj.metadata.node_type` - # implies an AsyncGroup, not an AsyncArray - assert isinstance(obj, AsyncGroup) - async for child_key, val in obj._members( - max_depth=max_depth): - yield f"{key}/{child_key}", val - except KeyError: - # keyerror is raised when `key` names an object (in the object storage sense), - # as opposed to a prefix, in the store under the prefix associated with this group - # in which case `key` cannot be the name of a sub-array or sub-group. - warnings.warn( - f"Object at {key} is not recognized as a component of a Zarr hierarchy.", - UserWarning, - stacklevel=1, - ) - def _members_consolidated( self, max_depth: int | None, prefix: str = "" ) -> Generator[ @@ -1221,7 +1159,7 @@ def _members_consolidated( None, ]: consolidated_metadata = self.metadata.consolidated_metadata - + do_recursion = max_depth is None or max_depth > 0 # we kind of just want the top-level keys. @@ -1233,23 +1171,23 @@ def _members_consolidated( key = f"{prefix}/{key}".lstrip("/") yield key, obj - if do_recursion and isinstance( - obj, AsyncGroup - ): + if do_recursion and isinstance(obj, AsyncGroup): if max_depth is None: - new_depth = None + new_depth = None else: new_depth = max_depth - 1 yield from obj._members_consolidated(new_depth, prefix=key) - + async def _members( - self, - max_depth: int | None) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]: + self, max_depth: int | None + ) -> AsyncGenerator[ + tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None + ]: skip_keys: tuple[str, ...] if self.metadata.zarr_format == 2: - skip_keys = ('.zattrs', '.zgroup','.zarray', '.zmetadata') + skip_keys = (".zattrs", ".zgroup", ".zarray", ".zmetadata") elif self.metadata.zarr_format == 3: - skip_keys = ('zarr.json',) + skip_keys = ("zarr.json",) else: raise ValueError(f"Unknown Zarr format: {self.metadata.zarr_format}") @@ -1268,7 +1206,9 @@ async def _members( ) raise ValueError(msg) - async for member in iter_members_deep(self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys): + async for member in iter_members_deep( + self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys + ): yield member async def keys(self) -> AsyncGenerator[str, None]: @@ -1913,31 +1853,31 @@ async def members_recursive( key_body = "/".join(key.split("/")[:-1]) if blob is not None: - resolved_metadata = resolve_metadata_v3(blob.to_bytes()) + resolved_metadata = build_metadata_v3(blob.to_bytes()) members_flat += ((key_body, resolved_metadata),) if isinstance(resolved_metadata, GroupMetadata): - to_recurse.append( - members_recursive(store, key_body)) + to_recurse.append(members_recursive(store, key_body)) subgroups = await asyncio.gather(*to_recurse) members_flat += tuple(subgroup for subgroup in subgroups) return members_flat + async def iter_members( - node: AsyncGroup, - skip_keys: tuple[str, ...] -) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]: + node: AsyncGroup, skip_keys: tuple[str, ...] +) -> AsyncGenerator[ + tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None +]: """ Iterate over the arrays and groups contained in a group. """ - + # retrieve keys from storage keys = [key async for key in node.store.list_dir(node.path)] keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) - node_tasks = tuple(asyncio.create_task( - node.getitem(key), name=key) for key in keys_filtered) + node_tasks = tuple(asyncio.create_task(node.getitem(key), name=key) for key in keys_filtered) for fetched_node_coro in asyncio.as_completed(node_tasks): try: @@ -1958,15 +1898,14 @@ async def iter_members( case _: raise ValueError(f"Unexpected type: {type(fetched_node)}") + async def iter_members_deep( - group: AsyncGroup, - *, - prefix: str, - max_depth: int | None, - skip_keys: tuple[str, ...] -) -> AsyncGenerator[tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None]: + group: AsyncGroup, *, prefix: str, max_depth: int | None, skip_keys: tuple[str, ...] +) -> AsyncGenerator[ + tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None +]: """ - Iterate over the arrays and groups contained in a group, and optionally the + Iterate over the arrays and groups contained in a group, and optionally the arrays and groups contained in those groups. """ @@ -1978,34 +1917,65 @@ async def iter_members_deep( new_depth = max_depth - 1 async for name, node in iter_members(group, skip_keys=skip_keys): - yield f'{prefix}/{name}'.lstrip('/'), node + yield f"{prefix}/{name}".lstrip("/"), node if isinstance(node, AsyncGroup) and do_recursion: - to_recurse.append(iter_members_deep( - node, - max_depth=new_depth, - prefix=f'{prefix}/{name}', - skip_keys=skip_keys)) + to_recurse.append( + iter_members_deep( + node, max_depth=new_depth, prefix=f"{prefix}/{name}", skip_keys=skip_keys + ) + ) for subgroup in to_recurse: async for name, node in subgroup: yield name, node - -def resolve_metadata_v2(blobs: tuple[str | bytes | bytearray, str | bytes | bytearray]) -> ArrayV2Metadata | GroupMetadata: + +def resolve_metadata_v2( + blobs: tuple[str | bytes | bytearray, str | bytes | bytearray], +) -> ArrayV2Metadata | GroupMetadata: zarr_metadata = json.loads(blobs[0]) attrs = json.loads(blobs[1]) - if 'shape' in zarr_metadata: - return ArrayV2Metadata.from_dict(zarr_metadata | {'attrs': attrs}) + if "shape" in zarr_metadata: + return ArrayV2Metadata.from_dict(zarr_metadata | {"attrs": attrs}) else: - return GroupMetadata.from_dict(zarr_metadata | {'attrs': attrs}) + return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs}) + -def resolve_metadata_v3(blob: str | bytes | bytearray) -> ArrayV3Metadata | GroupMetadata: - zarr_json = json.loads(blob) +def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata: + """ + Take a dict and convert it into the correct metadata type. + """ if "node_type" not in zarr_json: - raise ValueError("missing node_type in metadata document") - if zarr_json["node_type"] == "array": - return ArrayV3Metadata.from_dict(zarr_json) - elif zarr_json["node_type"] == "group": - return GroupMetadata.from_dict(zarr_json) - else: - raise ValueError("invalid node_type in metadata document") + raise KeyError("missing `node_type` key in metadata document.") + match zarr_json: + case {"node_type": "array"}: + return ArrayV3Metadata.from_dict(zarr_json) + case {"node_type": "group"}: + return GroupMetadata.from_dict(zarr_json) + case _: + raise ValueError("invalid value for `node_type` key in metadata document") + + +def build_metadata_v2( + zarr_json: dict[str, Any], attrs_json: dict[str, Any] +) -> ArrayV2Metadata | GroupMetadata: + match zarr_json: + case {"shape": _}: + return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json}) + case _: + return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) + + +def build_node_v3( + metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath +) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Take a metadata object and return a node (AsyncArray or AsyncGroup). + """ + match metadata: + case ArrayV3Metadata(): + return AsyncArray(metadata, store_path=store_path) + case GroupMetadata(): + return AsyncGroup(metadata, store_path=store_path) + case _: + raise ValueError(f"Unexpected metadata type: {type(metadata)}") From d10d8057ed5fb69c9d94bb9b8a4296d6cc79fe6b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 12:41:57 +0100 Subject: [PATCH 11/25] fix key/name handling in recursion --- src/zarr/api/asynchronous.py | 2 - src/zarr/core/group.py | 76 ++++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 3f36614cc2..5e4855edb2 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -186,14 +186,12 @@ async def consolidate_metadata( group.store_path.store._check_writable() members_metadata = {k: v.metadata async for k, v in group.members(max_depth=None)} - # While consolidating, we want to be explicit about when child groups # are empty by inserting an empty dict for consolidated_metadata.metadata for k, v in members_metadata.items(): if isinstance(v, GroupMetadata) and v.consolidated_metadata is None: v = dataclasses.replace(v, consolidated_metadata=ConsolidatedMetadata(metadata={})) members_metadata[k] = v - ConsolidatedMetadata._flat_to_nested(members_metadata) consolidated_metadata = ConsolidatedMetadata(metadata=members_metadata) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 476ee2d6b9..ba602b77db 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -650,6 +650,7 @@ async def getitem( """ store_path = self.store_path / key logger.debug("key=%s, store_path=%s", key, store_path) + metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata # Consolidated metadata lets us avoid some I/O operations so try that first. if self.metadata.consolidated_metadata is not None: @@ -666,8 +667,8 @@ async def getitem( raise KeyError(key) else: zarr_json = json.loads(zarr_json_bytes.to_bytes()) - metadata = build_metadata_v3(zarr_json) - return build_node_v3(metadata, store_path) + metadata = _build_metadata_v3(zarr_json) + return _build_node_v3(metadata, store_path) elif self.metadata.zarr_format == 2: # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs? @@ -683,21 +684,16 @@ async def getitem( # unpack the zarray, if this is None then we must be opening a group zarray = json.loads(zarray_bytes.to_bytes()) if zarray_bytes else None + zgroup = json.loads(zgroup_bytes.to_bytes()) if zgroup_bytes else None # unpack the zattrs, this can be None if no attrs were written zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {} if zarray is not None: - # TODO: update this once the V2 array support is part of the primary array class - zarr_json = {**zarray, "attributes": zattrs} - return AsyncArray.from_dict(store_path, zarr_json) + metadata = _build_metadata_v2(zarray, zattrs) + return _build_node_v2(metadata=metadata, store_path=store_path) else: - zgroup = ( - json.loads(zgroup_bytes.to_bytes()) - if zgroup_bytes is not None - else {"zarr_format": self.metadata.zarr_format} - ) - zarr_json = {**zgroup, "attributes": zattrs} - return type(self).from_dict(store_path, zarr_json) + metadata = _build_metadata_v2(zgroup, zattrs) + return _build_node_v2(metadata=metadata, store_path=store_path) else: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") @@ -1332,9 +1328,7 @@ async def _members( ) raise ValueError(msg) - async for member in iter_members_deep( - self, max_depth=max_depth, prefix=self.basename, skip_keys=skip_keys - ): + async for member in _iter_members_deep(self, max_depth=max_depth, skip_keys=skip_keys): yield member async def keys(self) -> AsyncGenerator[str, None]: @@ -2633,7 +2627,7 @@ async def members_recursive( key_body = "/".join(key.split("/")[:-1]) if blob is not None: - resolved_metadata = build_metadata_v3(json.loads(blob.to_bytes())) + resolved_metadata = _build_metadata_v3(json.loads(blob.to_bytes())) members_flat += ((key_body, resolved_metadata),) if isinstance(resolved_metadata, GroupMetadata): to_recurse.append(members_recursive(store, key_body)) @@ -2679,8 +2673,8 @@ async def iter_members( raise ValueError(f"Unexpected type: {type(fetched_node)}") -async def iter_members_deep( - group: AsyncGroup, *, prefix: str, max_depth: int | None, skip_keys: tuple[str, ...] +async def _iter_members_deep( + group: AsyncGroup, *, max_depth: int | None, skip_keys: tuple[str, ...] ) -> AsyncGenerator[ tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None ]: @@ -2689,28 +2683,25 @@ async def iter_members_deep( arrays and groups contained in those groups. """ - to_recurse = [] + to_recurse = {} do_recursion = max_depth is None or max_depth > 0 + if max_depth is None: new_depth = None else: new_depth = max_depth - 1 - async for name, node in iter_members(group, skip_keys=skip_keys): - yield f"{prefix}/{name}".lstrip("/"), node + yield name, node if isinstance(node, AsyncGroup) and do_recursion: - to_recurse.append( - iter_members_deep( - node, max_depth=new_depth, prefix=f"{prefix}/{name}", skip_keys=skip_keys - ) - ) + to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys) - for subgroup in to_recurse: - async for name, node in subgroup: - yield name, node + for prefix, subgroup_iter in to_recurse.items(): + async for name, node in subgroup_iter: + key = f"{prefix}/{name}".lstrip("/") + yield key, node -def resolve_metadata_v2( +def _resolve_metadata_v2( blobs: tuple[str | bytes | bytearray, str | bytes | bytearray], ) -> ArrayV2Metadata | GroupMetadata: zarr_metadata = json.loads(blobs[0]) @@ -2721,7 +2712,7 @@ def resolve_metadata_v2( return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs}) -def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata: +def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata: """ Take a dict and convert it into the correct metadata type. """ @@ -2736,9 +2727,12 @@ def build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetad raise ValueError("invalid value for `node_type` key in metadata document") -def build_metadata_v2( +def _build_metadata_v2( zarr_json: dict[str, Any], attrs_json: dict[str, Any] ) -> ArrayV2Metadata | GroupMetadata: + """ + Take a dict and convert it into the correct metadata type. + """ match zarr_json: case {"shape": _}: return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json}) @@ -2746,7 +2740,7 @@ def build_metadata_v2( return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) -def build_node_v3( +def _build_node_v3( metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath ) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: """ @@ -2759,3 +2753,19 @@ def build_node_v3( return AsyncGroup(metadata, store_path=store_path) case _: raise ValueError(f"Unexpected metadata type: {type(metadata)}") + + +def _build_node_v2( + metadata: ArrayV2Metadata | GroupMetadata, store_path: StorePath +) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: + """ + Take a metadata object and return a node (AsyncArray or AsyncGroup). + """ + + match metadata: + case ArrayV2Metadata(): + return AsyncArray(metadata, store_path=store_path) + case GroupMetadata(): + return AsyncGroup(metadata, store_path=store_path) + case _: + raise ValueError(f"Unexpected metadata type: {type(metadata)}") From 4c624e1770bd4d258fcae951f3b75e0edca5ee5b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 12:42:08 +0100 Subject: [PATCH 12/25] add latency-based test --- tests/test_group.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_group.py b/tests/test_group.py index 3135f5b736..23d637b1d0 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -3,6 +3,7 @@ import contextlib import operator import pickle +import time import warnings from typing import TYPE_CHECKING, Any, Literal @@ -22,6 +23,7 @@ from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage.common import make_store_path +from zarr.testing.store import LatencyStore from .conftest import parse_store @@ -1484,3 +1486,23 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None del g1["0"] with pytest.raises(KeyError): g1["0/0"] + + +@pytest.mark.parametrize('store', ['memory'], indirect=True) +def test_group_members_performance(store: MemoryStore) -> None: + """ + Test that the performance of Group.members is robust to asynchronous latency + """ + get_latency = 0.1 + latency_store = LatencyStore(store, get_latency=get_latency) + + group = zarr.group(store=latency_store) + num_groups = 100 + # Create some groups + for i in range(num_groups): + group.create_group(f"group{i}") + + start= time.time() + members = group.members() + elapsed = start = time.time() + assert elapsed < 2 * get_latency \ No newline at end of file From f23ee851728f19300a81e5020ff7067adef37aa4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 13:48:38 +0100 Subject: [PATCH 13/25] add latency-based concurrency tests for group.members --- tests/test_group.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index 23d637b1d0..0892a535d6 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1488,21 +1488,27 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None g1["0/0"] -@pytest.mark.parametrize('store', ['memory'], indirect=True) +@pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: MemoryStore) -> None: """ Test that the performance of Group.members is robust to asynchronous latency """ get_latency = 0.1 - latency_store = LatencyStore(store, get_latency=get_latency) - group = zarr.group(store=latency_store) - num_groups = 100 + # use the input store to create some groups + group_create = zarr.group(store=store) + num_groups = 10 + # Create some groups for i in range(num_groups): - group.create_group(f"group{i}") + group_create.create_group(f"group{i}") - start= time.time() - members = group.members() - elapsed = start = time.time() - assert elapsed < 2 * get_latency \ No newline at end of file + latency_store = LatencyStore(store, get_latency=get_latency) + # create a group with some latency on get operations + group_read = zarr.group(store=latency_store) + + # check how long it takes to iterate over the groups + start = time.time() + _ = group_read.members() + elapsed = time.time() - start + assert elapsed < (1.1 * get_latency) + 0.001 From cba42f3f1ba75d5a0e4a1d4e730145efccb52ce5 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 13:50:40 +0100 Subject: [PATCH 14/25] improve comments for test --- tests/test_group.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_group.py b/tests/test_group.py index 0892a535d6..a035dabaa5 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1508,7 +1508,11 @@ def test_group_members_performance(store: MemoryStore) -> None: group_read = zarr.group(store=latency_store) # check how long it takes to iterate over the groups + # if .members is sensitive to IO latency, + # this should take (num_groups * get_latency) seconds + # otherwise, it should take only marginally more than get_latency seconds start = time.time() _ = group_read.members() elapsed = time.time() - start + assert elapsed < (1.1 * get_latency) + 0.001 From 9691102451059d8821ab4099874c5db59259f521 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 15:15:31 +0100 Subject: [PATCH 15/25] add concurrency limit --- src/zarr/core/group.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index ba602b77db..844e4626c4 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -692,6 +692,8 @@ async def getitem( metadata = _build_metadata_v2(zarray, zattrs) return _build_node_v2(metadata=metadata, store_path=store_path) else: + # this is just for mypy + assert zgroup is not None metadata = _build_metadata_v2(zgroup, zattrs) return _build_node_v2(metadata=metadata, store_path=store_path) else: @@ -1328,7 +1330,11 @@ async def _members( ) raise ValueError(msg) - async for member in _iter_members_deep(self, max_depth=max_depth, skip_keys=skip_keys): + # enforce a concurrency limit by passing a semaphore to all the recursive functions + semaphore = asyncio.Semaphore(config.get("async.concurrency")) + async for member in _iter_members_deep( + self, max_depth=max_depth, skip_keys=skip_keys, semaphore=semaphore + ): yield member async def keys(self) -> AsyncGenerator[str, None]: @@ -2638,8 +2644,20 @@ async def members_recursive( return members_flat +async def _getitem_semaphore( + node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None +) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: + if semaphore is not None: + async with semaphore: + return await node.getitem(key) + else: + return await node.getitem(key) + + async def iter_members( - node: AsyncGroup, skip_keys: tuple[str, ...] + node: AsyncGroup, + skip_keys: tuple[str, ...], + semaphore: asyncio.Semaphore | None, ) -> AsyncGenerator[ tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None ]: @@ -2651,7 +2669,10 @@ async def iter_members( keys = [key async for key in node.store.list_dir(node.path)] keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) - node_tasks = tuple(asyncio.create_task(node.getitem(key), name=key) for key in keys_filtered) + node_tasks = tuple( + asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key) + for key in keys_filtered + ) for fetched_node_coro in asyncio.as_completed(node_tasks): try: @@ -2674,7 +2695,11 @@ async def iter_members( async def _iter_members_deep( - group: AsyncGroup, *, max_depth: int | None, skip_keys: tuple[str, ...] + group: AsyncGroup, + *, + max_depth: int | None, + skip_keys: tuple[str, ...], + semaphore: asyncio.Semaphore | None = None, ) -> AsyncGenerator[ tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None ]: @@ -2690,10 +2715,12 @@ async def _iter_members_deep( new_depth = None else: new_depth = max_depth - 1 - async for name, node in iter_members(group, skip_keys=skip_keys): + async for name, node in iter_members(group, skip_keys=skip_keys, semaphore=semaphore): yield name, node if isinstance(node, AsyncGroup) and do_recursion: - to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys) + to_recurse[name] = _iter_members_deep( + node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore + ) for prefix, subgroup_iter in to_recurse.items(): async for name, node in subgroup_iter: From d79037983da2d8de089bdbc0f4cac8bfcf392817 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 15:24:33 +0100 Subject: [PATCH 16/25] add test for concurrency limiting --- scratch.py | 37 ------------------------------------- tests/test_group.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 37 deletions(-) delete mode 100644 scratch.py diff --git a/scratch.py b/scratch.py deleted file mode 100644 index 8edb1c3fca..0000000000 --- a/scratch.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio -from time import time -from typing import Literal - -import pytest - -import zarr -from zarr.core.group import iter_members, members_recursive, iter_members_deep -from zarr.storage import MemoryStore -from zarr.testing.store import LatencyStore - - -@pytest.mark.parametrize("num_members", [10, 100, 1000]) -@pytest.mark.parametrize("method", ["default", "fast_members"]) -def test_collect_members(num_members: int, method: Literal["fast_members", "default", "fast_members_2"]) -> None: - local_store = MemoryStore(mode="a") - local_latency_store = LatencyStore(local_store, get_latency=0.1, set_latency=0.0) - - root_group_raw = zarr.open_group(store=local_store) - root_group_latency = zarr.open_group(store=local_latency_store) - - for i in range(num_members): - subgroup = root_group_raw.create_group(f"group_outer_{i}") - - if method == "fast_members": - async def amain() -> None: - res = [x async for x in iter_members(root_group_latency._async_group)] - - start = time() - asyncio.run(amain()) - elapsed = time() - start - print(elapsed) - else: - start = time() - root_group_latency.members(max_depth=None) - elapsed = time() - start - print(elapsed) diff --git a/tests/test_group.py b/tests/test_group.py index a035dabaa5..c66c2c00ce 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1516,3 +1516,36 @@ def test_group_members_performance(store: MemoryStore) -> None: elapsed = time.time() - start assert elapsed < (1.1 * get_latency) + 0.001 + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_group_members_concurrency_limit(store: MemoryStore) -> None: + """ + Test that the performance of Group.members is robust to asynchronous latency + """ + get_latency = 0.02 + + # use the input store to create some groups + group_create = zarr.group(store=store) + num_groups = 10 + + # Create some groups + for i in range(num_groups): + group_create.create_group(f"group{i}") + + latency_store = LatencyStore(store, get_latency=get_latency) + # create a group with some latency on get operations + group_read = zarr.group(store=latency_store) + + # check how long it takes to iterate over the groups + # if .members is sensitive to IO latency, + # this should take (num_groups * get_latency) seconds + # otherwise, it should take only marginally more than get_latency seconds + from zarr.core.config import config + + with config.set({"async.concurrency": 1}): + start = time.time() + _ = group_read.members() + elapsed = time.time() - start + + assert elapsed > num_groups * get_latency From aadbece8973e791844547139b5f03e5a4ec4f095 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 15:32:21 +0100 Subject: [PATCH 17/25] docstrings --- tests/test_group.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index c66c2c00ce..1e0e1ab092 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1491,7 +1491,7 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: MemoryStore) -> None: """ - Test that the performance of Group.members is robust to asynchronous latency + Test that the execution time of Group.members does not scale with asynchronous latency """ get_latency = 0.1 @@ -1521,7 +1521,8 @@ def test_group_members_performance(store: MemoryStore) -> None: @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_concurrency_limit(store: MemoryStore) -> None: """ - Test that the performance of Group.members is robust to asynchronous latency + Test that the execution time of Group.members can be constrained by the async concurrency + configuration setting. """ get_latency = 0.02 From e19238da756bb1158ac68fb563566bc5c47d9776 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 15:59:15 +0100 Subject: [PATCH 18/25] remove function that was only calling itself --- src/zarr/core/group.py | 43 ------------------------------------------ 1 file changed, 43 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 844e4626c4..805dabcde4 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2601,49 +2601,6 @@ def array( ) -async def members_recursive( - store: Store, - path: str, -) -> Any: - """ - Recursively fetch all members of a group. - """ - metadata_keys = ("zarr.json",) - - members_flat: tuple[tuple[str, ArrayV3Metadata | GroupMetadata], ...] = () - - keys = [key async for key in store.list_dir(path)] - keys_filtered = tuple(filter(lambda v: v not in metadata_keys, keys)) - doc_keys = [] - - for key in keys_filtered: - doc_keys.extend( - [f"{path.lstrip('/')}/{key}/{metadata_key}" for metadata_key in metadata_keys] - ) - - # optimistically fetch extant metadata documents - blobs = await asyncio.gather( - *(store.get(key, prototype=default_buffer_prototype()) for key in doc_keys) - ) - - to_recurse = [] - - # insert resolved metadata_documents into members_flat - for key, blob in zip(doc_keys, blobs, strict=False): - key_body = "/".join(key.split("/")[:-1]) - - if blob is not None: - resolved_metadata = _build_metadata_v3(json.loads(blob.to_bytes())) - members_flat += ((key_body, resolved_metadata),) - if isinstance(resolved_metadata, GroupMetadata): - to_recurse.append(members_recursive(store, key_body)) - - subgroups = await asyncio.gather(*to_recurse) - members_flat += tuple(subgroup for subgroup in subgroups) - - return members_flat - - async def _getitem_semaphore( node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: From db74205f72fda179ba10d9465fff268a3f2ca8d3 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 Nov 2024 16:06:51 +0100 Subject: [PATCH 19/25] docstrings --- src/zarr/core/group.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 805dabcde4..ca3b9b55d1 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2604,6 +2604,12 @@ def array( async def _getitem_semaphore( node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: + """ + Combine node.getitem with an optional semaphore. If the semaphore parameter is an + asyncio.Semaphore instance, then the getitem operation is performed inside an async context + manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked + without a context manager. + """ if semaphore is not None: async with semaphore: return await node.getitem(key) @@ -2611,7 +2617,7 @@ async def _getitem_semaphore( return await node.getitem(key) -async def iter_members( +async def _iter_members( node: AsyncGroup, skip_keys: tuple[str, ...], semaphore: asyncio.Semaphore | None, @@ -2620,6 +2626,19 @@ async def iter_members( ]: """ Iterate over the arrays and groups contained in a group. + + Parameters + ---------- + node : AsyncGroup + The group to traverse. + skip_keys : tuple[str, ...] + A tuple of keys to skip when iterating over the possible members of the group. + semaphore : asyncio.Semaphore | None + An optional semaphore to use for concurrency control. + + Yields + ------ + tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup] """ # retrieve keys from storage @@ -2663,6 +2682,21 @@ async def _iter_members_deep( """ Iterate over the arrays and groups contained in a group, and optionally the arrays and groups contained in those groups. + + Parameters + ---------- + group : AsyncGroup + The group to traverse. + max_depth : int | None + The maximum depth of recursion. + skip_keys : tuple[str, ...] + A tuple of keys to skip when iterating over the possible members of the group. + semaphore : asyncio.Semaphore | None + An optional semaphore to use for concurrency control. + + Yields + ------ + tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup] """ to_recurse = {} @@ -2672,7 +2706,7 @@ async def _iter_members_deep( new_depth = None else: new_depth = max_depth - 1 - async for name, node in iter_members(group, skip_keys=skip_keys, semaphore=semaphore): + async for name, node in _iter_members(group, skip_keys=skip_keys, semaphore=semaphore): yield name, node if isinstance(node, AsyncGroup) and do_recursion: to_recurse[name] = _iter_members_deep( From 46ba0cb1fbba64c20d298d99badace8bfa52f96c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 27 Nov 2024 20:23:39 +0100 Subject: [PATCH 20/25] relax timing requirement for concurrency test --- tests/test_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_group.py b/tests/test_group.py index 1e0e1ab092..894c5673aa 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1515,7 +1515,7 @@ def test_group_members_performance(store: MemoryStore) -> None: _ = group_read.members() elapsed = time.time() - start - assert elapsed < (1.1 * get_latency) + 0.001 + assert elapsed < (1.1 * get_latency) + 0.01 @pytest.mark.parametrize("store", ["memory"], indirect=True) From a48efa8bc9eb705d4b6b538fbd2f741dddaed6ad Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Thu, 12 Dec 2024 11:04:11 +0100 Subject: [PATCH 21/25] Update src/zarr/core/group.py Co-authored-by: Deepak Cherian --- src/zarr/core/group.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index e897c9d439..c8707fb68d 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -693,7 +693,8 @@ async def getitem( return _build_node_v2(metadata=metadata, store_path=store_path) else: # this is just for mypy - assert zgroup is not None + if TYPE_CHECKING: + assert zgroup is not None metadata = _build_metadata_v2(zgroup, zattrs) return _build_node_v2(metadata=metadata, store_path=store_path) else: From ebfc200609c6eb16067c42eafa97cbc7713832b4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 18 Dec 2024 15:44:05 +0100 Subject: [PATCH 22/25] exists_ok -> overwrite --- tests/test_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_group.py b/tests/test_group.py index f7019a4494..e07fb1be7a 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -237,7 +237,7 @@ def test_group_members_2(store: Store, zarr_format: ZarrFormat) -> None: _ = subsubgroup.create_group("subsubsubgroup") members_expected["subarray"] = group.create_array( - "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True + "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), overwrite=True ) # add an extra object to the domain of the group. From 39ec6b51174b624b2b0db10999318ab15b4da09c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 18 Dec 2024 16:08:09 +0100 Subject: [PATCH 23/25] simplify group_members_perf test, just require that the duration is less than the number of groups * latency --- tests/test_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_group.py b/tests/test_group.py index e07fb1be7a..c9fe25fbca 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1515,7 +1515,7 @@ def test_group_members_performance(store: MemoryStore) -> None: _ = group_read.members() elapsed = time.time() - start - assert elapsed < (1.1 * get_latency) + 0.01 + assert elapsed < (num_groups * get_latency) @pytest.mark.parametrize("store", ["memory"], indirect=True) From f25bda302d2b1b1968fc5559b75b467111efef5f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 18 Dec 2024 16:39:53 +0100 Subject: [PATCH 24/25] update test docstring --- tests/test_group.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_group.py b/tests/test_group.py index c9fe25fbca..55b900cfa8 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1491,7 +1491,8 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: MemoryStore) -> None: """ - Test that the execution time of Group.members does not scale with asynchronous latency + Test that the execution time of Group.members is less than the number of members times the + latency for accessing each member. """ get_latency = 0.1 From 08440e5f7a87db000fc5daeeb6052279757ea052 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 14:45:30 +0100 Subject: [PATCH 25/25] remove vestigial test --- tests/test_group.py | 64 --------------------------------------------- 1 file changed, 64 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index 198e260f6d..c2a5f751f3 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -211,70 +211,6 @@ def test_group_members(store: Store, zarr_format: ZarrFormat, consolidated_metad members_observed = group.members(max_depth=-1) -def test_group_members_2(store: Store, zarr_format: ZarrFormat) -> None: - """ - Test that `Group.members` returns correct values, i.e. the arrays and groups - (explicit and implicit) contained in that group. - """ - # group/ - # subgroup/ - # subsubgroup/ - # subsubsubgroup - # subarray - - path = "group" - group = Group.from_store( - store=store, - zarr_format=zarr_format, - ) - members_expected: dict[str, Array | Group] = {} - - members_expected["subgroup"] = group.create_group("subgroup") - # make a sub-sub-subgroup, to ensure that the children calculation doesn't go - # too deep in the hierarchy - subsubgroup = members_expected["subgroup"].create_group("subsubgroup") - _ = subsubgroup.create_group("subsubsubgroup") - - members_expected["subarray"] = group.create_array( - "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), overwrite=True - ) - - # add an extra object to the domain of the group. - # the list of children should ignore this object. - sync( - store.set( - f"{path}/extra_object-1", - default_buffer_prototype().buffer.from_bytes(b"000000"), - ) - ) - # add an extra object under a directory-like prefix in the domain of the group. - # this creates a directory with a random key in it - # this should not show up as a member - sync( - store.set( - f"{path}/extra_directory/extra_object-2", - default_buffer_prototype().buffer.from_bytes(b"000000"), - ) - ) - - # this warning shows up when extra objects show up in the hierarchy - warn_context = pytest.warns( - UserWarning, match=r"Object at .* is not recognized as a component of a Zarr hierarchy." - ) - - with warn_context: - members_observed = group.members() - # members are not guaranteed to be ordered, so sort before comparing - assert sorted(dict(members_observed)) == sorted(members_expected) - - # partial - with warn_context: - members_observed = group.members(max_depth=1) - members_expected["subgroup/subsubgroup"] = subsubgroup - # members are not guaranteed to be ordered, so sort before comparing - assert sorted(dict(members_observed)) == sorted(members_expected) - - def test_group(store: Store, zarr_format: ZarrFormat) -> None: """ Test basic Group routines.