From 8b704724eb99fd3fe507c531c9faecee39e9bccf Mon Sep 17 00:00:00 2001 From: David Stansby Date: Sat, 18 May 2024 10:33:08 +0100 Subject: [PATCH] Finish typing zarr.metadata --- pyproject.toml | 1 - src/zarr/array.py | 2 ++ src/zarr/chunk_grids.py | 8 +++---- src/zarr/chunk_key_encodings.py | 2 +- src/zarr/common.py | 5 ++-- src/zarr/metadata.py | 41 ++++++++++++++++++--------------- 6 files changed, 32 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05022261fa..4894637645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -241,7 +241,6 @@ module = [ "zarr.array", "zarr.common", "zarr.group", - "zarr.metadata" ] disallow_untyped_defs = false diff --git a/src/zarr/array.py b/src/zarr/array.py index 039f39e98e..86ff262940 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -199,6 +199,8 @@ async def _create_v3( if chunk_key_encoding is None: chunk_key_encoding = ("default", "/") + assert chunk_key_encoding is not None + if isinstance(chunk_key_encoding, tuple): chunk_key_encoding = ( V2ChunkKeyEncoding(separator=chunk_key_encoding[1]) diff --git a/src/zarr/chunk_grids.py b/src/zarr/chunk_grids.py index 45f77cc99c..f6366b8038 100644 --- a/src/zarr/chunk_grids.py +++ b/src/zarr/chunk_grids.py @@ -3,7 +3,7 @@ import itertools from collections.abc import Iterator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from zarr.abc.metadata import Metadata from zarr.common import ( @@ -22,13 +22,13 @@ @dataclass(frozen=True) class ChunkGrid(Metadata): @classmethod - def from_dict(cls, data: dict[str, JSON]) -> ChunkGrid: + def from_dict(cls, data: dict[str, JSON] | ChunkGrid) -> ChunkGrid: if isinstance(data, ChunkGrid): return data name_parsed, _ = parse_named_configuration(data) if name_parsed == "regular": - return RegularChunkGrid.from_dict(data) + return RegularChunkGrid._from_dict(data) raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") def all_chunk_coords(self, array_shape: ChunkCoords) -> Iterator[ChunkCoords]: @@ -45,7 +45,7 @@ def __init__(self, *, chunk_shape: ChunkCoordsLike) -> None: object.__setattr__(self, "chunk_shape", chunk_shape_parsed) @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: + def _from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration(data, "regular") return cls(**configuration_parsed) # type: ignore[arg-type] diff --git a/src/zarr/chunk_key_encodings.py b/src/zarr/chunk_key_encodings.py index 5ecb98ef61..ed6c181764 100644 --- a/src/zarr/chunk_key_encodings.py +++ b/src/zarr/chunk_key_encodings.py @@ -34,7 +34,7 @@ def __init__(self, *, separator: SeparatorLiteral) -> None: object.__setattr__(self, "separator", separator_parsed) @classmethod - def from_dict(cls, data: dict[str, JSON]) -> ChunkKeyEncoding: + def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncoding) -> ChunkKeyEncoding: if isinstance(data, ChunkKeyEncoding): return data diff --git a/src/zarr/common.py b/src/zarr/common.py index 32a6c2fd0c..9d8315abc8 100644 --- a/src/zarr/common.py +++ b/src/zarr/common.py @@ -13,6 +13,7 @@ from collections.abc import Awaitable, Callable, Iterator import numpy as np +import numpy.typing as npt ZARR_JSON = "zarr.json" ZARRAY_JSON = ".zarray" @@ -150,7 +151,7 @@ def parse_named_configuration( return name_parsed, configuration_parsed -def parse_shapelike(data: Any) -> tuple[int, ...]: +def parse_shapelike(data: Iterable[int]) -> tuple[int, ...]: if not isinstance(data, Iterable): raise TypeError(f"Expected an iterable. Got {data} instead.") data_tuple = tuple(data) @@ -164,7 +165,7 @@ def parse_shapelike(data: Any) -> tuple[int, ...]: return data_tuple -def parse_dtype(data: Any) -> np.dtype[Any]: +def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: # todo: real validation return np.dtype(data) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index 8db8c8033e..e2b7b987c0 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -22,6 +22,7 @@ from typing_extensions import Self +import numcodecs.abc from zarr.common import ( JSON, @@ -168,15 +169,15 @@ class ArrayV3Metadata(ArrayMetadata): def __init__( self, *, - shape, - data_type, - chunk_grid, - chunk_key_encoding, - fill_value, - codecs, - attributes, - dimension_names, - ): + shape: Iterable[int], + data_type: npt.DTypeLike, + chunk_grid: dict[str, JSON] | ChunkGrid, + chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding, + fill_value: Any, + codecs: Iterable[Codec | JSON], + attributes: None | dict[str, JSON], + dimension_names: None | Iterable[str], + ) -> None: """ Because the class is a frozen dataclass, we set attributes using object.__setattr__ """ @@ -249,14 +250,14 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) def to_buffer_dict(self) -> dict[str, Buffer]: - def _json_convert(o): + def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]: if isinstance(o, np.dtype): return str(o) if isinstance(o, Enum): return o.name # this serializes numcodecs compressors # todo: implement to_dict for codecs - elif hasattr(o, "get_config"): + elif isinstance(o, numcodecs.abc.Codec): return o.get_config() raise TypeError @@ -271,9 +272,10 @@ def from_dict(cls, data: dict[str, JSON]) -> ArrayV3Metadata: # check that the node_type attribute is correct _ = parse_node_type_array(data.pop("node_type")) - dimension_names = data.pop("dimension_names", None) + data["dimension_names"] = data.pop("dimension_names", None) - return cls(**data, dimension_names=dimension_names) + # TODO: Remove the ignores and use a TypedDict to type `data` + return cls(**data) # type: ignore[arg-type] def to_dict(self) -> dict[str, Any]: out_dict = super().to_dict() @@ -367,7 +369,9 @@ def codec_pipeline(self) -> CodecPipeline: ) def to_buffer_dict(self) -> dict[str, Buffer]: - def _json_convert(o): + def _json_convert( + o: np.dtype[Any], + ) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]: if isinstance(o, np.dtype): if o.fields is None: return o.str @@ -399,7 +403,7 @@ def to_dict(self) -> JSON: zarray_dict["chunks"] = self.chunk_grid.chunk_shape _ = zarray_dict.pop("data_type") - zarray_dict["dtype"] = self.data_type + zarray_dict["dtype"] = self.data_type.str return zarray_dict @@ -422,7 +426,7 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) -def parse_dimension_names(data: Any) -> tuple[str, ...] | None: +def parse_dimension_names(data: None | Iterable[str]) -> tuple[str, ...] | None: if data is None: return data if isinstance(data, Iterable) and all([isinstance(x, str) for x in data]): @@ -432,12 +436,11 @@ def parse_dimension_names(data: Any) -> tuple[str, ...] | None: # todo: real validation -def parse_attributes(data: Any) -> dict[str, JSON]: +def parse_attributes(data: None | dict[str, JSON]) -> dict[str, JSON]: if data is None: return {} - data_json = cast(dict[str, JSON], data) - return data_json + return data # todo: move to its own module and drop _v3 suffix