diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 47235501..3d6a9ed0 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1698,15 +1698,14 @@ def write_metadata( check_valid_name(element_name) if element_name not in self: raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + if write_attrs: + self.write_attrs(sdata_format=sdata_format) self.write_transformations(element_name) self.write_channel_names(element_name) # TODO: write .uns['spatialdata_attrs'] metadata for AnnData. # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. - if write_attrs: - self.write_attrs(sdata_format=sdata_format) - if self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index a251b042..bc52c94b 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -46,6 +46,11 @@ def _read_points( return points +class PointsReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> DaskDataFrame: + return _read_points(store) + + def write_points( points: DaskDataFrame, group: zarr.Group, diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index af6578cd..fe7f6b36 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -35,33 +35,6 @@ ) -def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: - """Get nodes with Multiscales spec from a list of nodes. - - The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check - the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have - the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific - metadata though. - - Parameters - ---------- - image_nodes - List of nodes returned from the ome-zarr-py Reader. - nodes - List to append the nodes with the multiscales spec to. - - Returns - ------- - List of nodes with the multiscales spec. - """ - if len(image_nodes): - for node in image_nodes: - # Labels are now also Multiscales in newer version of ome-zarr-py - if np.any([isinstance(spec, Multiscales) for spec in node.specs]): - nodes.append(node) - return nodes - - def _read_multiscale( store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format ) -> DataArray | DataTree: @@ -134,6 +107,7 @@ def _read_multiscale( msi = DataTree.from_dict(multiscale_image) _set_transformations(msi, transformations) return compute_coordinates(msi) + data = node.load(Multiscales).array(resolution=datasets[0]) si = DataArray( data, @@ -145,6 +119,40 @@ def _read_multiscale( return compute_coordinates(si) +def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: + """Get nodes with Multiscales spec from a list of nodes. + + The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check + the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have + the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific + metadata though. + + Parameters + ---------- + image_nodes + List of nodes returned from the ome-zarr-py Reader. + nodes + List to append the nodes with the multiscales spec to. + + Returns + ------- + List of nodes with the multiscales spec. + """ + if len(image_nodes): + for node in image_nodes: + # Labels are now also Multiscales in newer version of ome-zarr-py + if np.any([isinstance(spec, Multiscales) for spec in node.specs]): + nodes.append(node) + return nodes + + +class MultiscaleReader: + def __call__( + self, path: str | Path, raster_type: Literal["image", "labels"], reader_format: Format + ) -> DataArray | DataTree: + return _read_multiscale(path, raster_type, reader_format) + + def _write_raster( raster_type: Literal["image", "labels"], raster_data: DataArray | DataTree, diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 6e14baa4..15efdc47 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -30,7 +30,7 @@ def _read_shapes( - store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg] + store: str | Path | MutableMapping[str, object] | zarr.Group, ) -> GeoDataFrame: """Read shapes from a zarr store.""" assert isinstance(store, str | Path) @@ -67,6 +67,11 @@ def _read_shapes( return geo_df +class ShapesReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> GeoDataFrame: + return _read_shapes(store) + + def write_shapes( shapes: GeoDataFrame, group: zarr.Group, diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 6a9d191f..9e71c4a3 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -21,7 +21,7 @@ def _read_table( group: zarr.Group, tables: dict[str, AnnData], on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, -) -> dict[str, AnnData]: +) -> None: """ Read in tables in the tables Zarr.group of a SpatialData Zarr store. @@ -85,7 +85,17 @@ def _read_table( count += 1 logger.debug(f"Found {count} elements in {group}") - return tables + + +class TablesReader: + def __call__( + self, + path: str, + group: zarr.Group, + container: dict[str, AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], + ) -> None: + return _read_table(path, group, container, on_bad_files) def write_table( diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 73b0ef6e..5367dcf0 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -2,7 +2,7 @@ import warnings from json import JSONDecodeError from pathlib import Path -from typing import Literal +from typing import Literal, cast import zarr.storage from anndata import AnnData @@ -16,11 +16,59 @@ _resolve_zarr_store, handle_read_errors, ) -from spatialdata._io.io_points import _read_points -from spatialdata._io.io_raster import _read_multiscale -from spatialdata._io.io_shapes import _read_shapes -from spatialdata._io.io_table import _read_table +from spatialdata._io.io_points import PointsReader +from spatialdata._io.io_raster import MultiscaleReader +from spatialdata._io.io_shapes import ShapesReader +from spatialdata._io.io_table import TablesReader from spatialdata._logging import logger +from spatialdata.models import SpatialElement + +ReadClasses = MultiscaleReader | PointsReader | ShapesReader | TablesReader + + +def _read_zarr_group_spatialdata_element( + root_group: zarr.Group, + root_store_path: str, + sdata_version: Literal["0.1", "0.2"], + selector: set[str], + read_func: ReadClasses, + group_name: Literal["images", "labels", "shapes", "points", "tables"], + element_type: Literal["image", "labels", "shapes", "points", "tables"], + element_container: dict[str, SpatialElement | AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], +) -> None: + with handle_read_errors( + on_bad_files, + location=group_name, + exc_types=JSONDecodeError, + ): + if group_name in selector and group_name in root_group: + group = root_group[group_name] + if isinstance(read_func, TablesReader): + read_func(root_store_path, group, element_container, on_bad_files=on_bad_files) + else: + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=(KeyError, ArrayNotFoundError, OSError, ArrowInvalid, JSONDecodeError), + ): + if isinstance(read_func, MultiscaleReader): + reader_format = get_raster_format_for_read(elem_group, sdata_version) + element = read_func( + elem_group_path, cast(Literal["image", "labels"], element_type), reader_format + ) + if isinstance(read_func, PointsReader | ShapesReader): + element = read_func(elem_group_path) + element_container[subgroup_name] = element + count += 1 + logger.debug(f"Found {count} elements in {group}") def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format: @@ -85,136 +133,48 @@ def read_zarr( resolved_store = _resolve_zarr_store(store) root_group = zarr.open_group(resolved_store, mode="r") sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] + if sdata_version == "0.1": + warnings.warn( + "SpatialData is not stored in the most current format. If you want to use Zarr v3" + ", please write the store to a new location.", + UserWarning, + stacklevel=2, + ) root_store_path = root_group.store.root - images = {} - labels = {} - points = {} + images: dict[str, SpatialElement] = {} + labels: dict[str, SpatialElement] = {} + points: dict[str, SpatialElement] = {} tables: dict[str, AnnData] = {} - shapes = {} + shapes: dict[str, SpatialElement] = {} selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") - # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. - # related to images / labels. - with handle_read_errors( - on_bad_files, - location="images", - exc_types=JSONDecodeError, - ): - if "images" in selector and "images" in root_group: - group = root_group["images"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - reader_format = get_raster_format_for_read(elem_group, sdata_version) - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - KeyError, - ArrayNotFoundError, - OSError, - ), - ): - element = _read_multiscale(elem_group_path, raster_type="image", reader_format=reader_format) - images[subgroup_name] = element - count += 1 - logger.debug(f"Found {count} elements in {group}") - - # read multiscale labels - with handle_read_errors( - on_bad_files, - location="labels", - exc_types=JSONDecodeError, - ): - if "labels" in selector and "labels" in root_group: - group = root_group["labels"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - reader_format = get_raster_format_for_read(elem_group, sdata_version) - elem_group_path = root_store_path / elem_group.path - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - KeyError, - ArrayNotFoundError, - OSError, - ), - ): - labels[subgroup_name] = _read_multiscale( - elem_group_path, raster_type="labels", reader_format=reader_format - ) - count += 1 - logger.debug(f"Found {count} elements in {group}") - # now read rest of the data - with handle_read_errors( - on_bad_files, - location="points", - exc_types=JSONDecodeError, - ): - if "points" in selector and "points" in root_group: - group = root_group["points"] - count = 0 - for subgroup_name in group: - elem_group = group[subgroup_name] - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=(KeyError, ArrowInvalid, JSONDecodeError), - ): - points[subgroup_name] = _read_points(elem_group_path) - count += 1 - logger.debug(f"Found {count} elements in {group}") - - with handle_read_errors( - on_bad_files, - location="shapes", - exc_types=JSONDecodeError, - ): - if "shapes" in selector and "shapes" in root_group: - group = root_group["shapes"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - JSONDecodeError, - KeyError, - ArrayNotFoundError, - ), - ): - shapes[subgroup_name] = _read_shapes(elem_group_path) - count += 1 - logger.debug(f"Found {count} elements in {group}") - if "tables" in selector and "tables" in root_group: - with handle_read_errors( + group_readers: dict[ + Literal["images", "labels", "shapes", "points", "tables"], + tuple[ + ReadClasses, Literal["image", "labels", "shapes", "points", "tables"], dict[str, SpatialElement | AnnData] + ], + ] = { + "images": (MultiscaleReader(), "image", images), + "labels": (MultiscaleReader(), "labels", labels), + "points": (PointsReader(), "points", points), + "shapes": (ShapesReader(), "shapes", shapes), + "tables": (TablesReader(), "tables", tables), + } + for group_name, (reader, raster_type, container) in group_readers.items(): + _read_zarr_group_spatialdata_element( + root_group, + root_store_path, + sdata_version, + selector, + reader, + group_name, + raster_type, + container, on_bad_files, - location="tables", - exc_types=JSONDecodeError, - ): - group = root_group["tables"] - tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files) + ) # read attrs metadata attrs = root_group.attrs.asdict()