diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index 1298711d4e..86f9b53477 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -146,7 +146,7 @@ def create_zero_length(cls) -> Self: @classmethod def from_array_like(cls, array_like: ArrayLike) -> Self: - """Create a new buffer of a array-like object + """Create a new buffer of an array-like object Parameters ---------- @@ -159,6 +159,29 @@ def from_array_like(cls, array_like: ArrayLike) -> Self: """ return cls(array_like) + @classmethod + def from_buffer(cls, buffer: Buffer) -> Self: + """Create a new buffer of an existing Buffer + + This is useful if you want to ensure that an existing buffer is + of the correct subclass of Buffer. E.g., MemoryStore uses this + to return a buffer instance of the subclass specified by its + BufferPrototype argument. + + Typically, this only copies data if the data has to be moved between + memory types, such as from host to device memory. + + Parameters + ---------- + buffer + buffer object. + + Returns + ------- + A new buffer representing the content of the input buffer + """ + return cls.from_array_like(buffer.as_array_like()) + @classmethod def from_bytes(cls, bytes_like: BytesLike) -> Self: """Create a new buffer of a bytes-like object (host memory) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 43d65ce836..7b73330b6c 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -39,7 +39,7 @@ async def get( try: value = self._store_dict[key] start, length = _normalize_interval_index(value, byte_range) - return value[start : start + length] + return prototype.buffer.from_buffer(value[start : start + length]) except KeyError: return None diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 15051334e9..50a02dcbcd 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -6,7 +6,7 @@ import fsspec from zarr.abc.store import Store -from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype +from zarr.buffer import Buffer, BufferPrototype from zarr.common import OpenMode from zarr.store.core import _dereference_path @@ -84,7 +84,7 @@ def __repr__(self) -> str: async def get( self, key: str, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype, byte_range: tuple[int | None, int | None] | None = None, ) -> Buffer | None: path = _dereference_path(self.path, key) @@ -99,7 +99,7 @@ async def get( end = length else: end = None - value: Buffer = prototype.buffer.from_bytes( + value = prototype.buffer.from_bytes( await ( self._fs._cat_file(path, start=byte_range[0], end=end) if byte_range diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index e814afef15..77e1b6b688 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -68,7 +68,10 @@ async def get( ) -> Buffer | None: if "json" not in key: assert prototype.buffer is MyBuffer - return await super().get(key, byte_range) + ret = await super().get(key=key, prototype=prototype, byte_range=byte_range) + if ret is not None: + assert isinstance(ret, prototype.buffer) + return ret def test_nd_array_like(xp): diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index 98206d427f..0dc399be42 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -88,7 +88,7 @@ async def test_basic(): data = b"hello" await store.set("foo", Buffer.from_bytes(data)) assert await store.exists("foo") - assert (await store.get("foo")).to_bytes() == data + assert (await store.get("foo", prototype=default_buffer_prototype)).to_bytes() == data out = await store.get_partial_values( prototype=default_buffer_prototype, key_ranges=[("foo", (1, None))] )