From 7f9dbf5b9417cdae1f7b7e6618c793aaf06c1a7b Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Fri, 12 Sep 2025 15:40:52 +0200
Subject: [PATCH 1/3] refactor read_zarr

---
 src/spatialdata/_io/io_points.py |   5 +
 src/spatialdata/_io/io_raster.py |  60 +++++-----
 src/spatialdata/_io/io_shapes.py |   7 +-
 src/spatialdata/_io/io_table.py  |  14 ++-
 src/spatialdata/_io/io_zarr.py   | 198 ++++++++++++-------------
 5 files changed, 131 insertions(+), 153 deletions(-)

diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py
index a251b042..bc52c94b 100644
--- a/src/spatialdata/_io/io_points.py
+++ b/src/spatialdata/_io/io_points.py
@@ -46,6 +46,11 @@ def _read_points(
     return points
 
 
+class PointsReader:
+    def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> DaskDataFrame:
+        return _read_points(store)
+
+
 def write_points(
     points: DaskDataFrame,
     group: zarr.Group,
diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index 80aa07e3..f0e8b6b4 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -33,33 +33,6 @@
 )
 
 
-def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]:
-    """Get nodes with Multiscales spec from a list of nodes.
-
-    The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check
-    the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have
-    the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific
-    metadata though.
-
-    Parameters
-    ----------
-    image_nodes
-        List of nodes returned from the ome-zarr-py Reader.
-    nodes
-        List to append the nodes with the multiscales spec to.
-
-    Returns
-    -------
-    List of nodes with the multiscales spec.
-    """
-    if len(image_nodes):
-        for node in image_nodes:
-            # Labels are now also Multiscales in newer version of ome-zarr-py
-            if np.any([isinstance(spec, Multiscales) for spec in node.specs]):
-                nodes.append(node)
-    return nodes
-
-
 def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree:
     assert isinstance(store, str | Path)
     assert raster_type in ["image", "labels"]
@@ -127,6 +100,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"])
         msi = DataTree.from_dict(multiscale_image)
         _set_transformations(msi, transformations)
         return compute_coordinates(msi)
+
     data = node.load(Multiscales).array(resolution=datasets[0])
     si = DataArray(
         data,
@@ -138,6 +112,38 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"])
     return compute_coordinates(si)
 
 
+def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]:
+    """Get nodes with Multiscales spec from a list of nodes.
+
+    The nodes with the Multiscales spec are the ones used for reading image and label data. Only the Multiscales
+    spec has to be checked now, whereas previously the Label spec had to be checked as well: in newer versions of
+    ome-zarr-py labels can still carry the Label spec, but it no longer contains the multiscales used to read the
+    data, only label-specific metadata.
+
+    Parameters
+    ----------
+    image_nodes
+        List of nodes returned from the ome-zarr-py Reader.
+    nodes
+        List to append the nodes with the multiscales spec to.
+
+    Returns
+    -------
+    List of nodes with the multiscales spec.
+ """ + if len(image_nodes): + for node in image_nodes: + # Labels are now also Multiscales in newer version of ome-zarr-py + if np.any([isinstance(spec, Multiscales) for spec in node.specs]): + nodes.append(node) + return nodes + + +class MultiscaleReader: + def __call__(self, path: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: + return _read_multiscale(path, raster_type) + + def _write_raster( raster_type: Literal["image", "labels"], raster_data: DataArray | DataTree, diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 6e14baa4..15efdc47 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -30,7 +30,7 @@ def _read_shapes( - store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg] + store: str | Path | MutableMapping[str, object] | zarr.Group, ) -> GeoDataFrame: """Read shapes from a zarr store.""" assert isinstance(store, str | Path) @@ -67,6 +67,11 @@ def _read_shapes( return geo_df +class ShapesReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> GeoDataFrame: + return _read_shapes(store) + + def write_shapes( shapes: GeoDataFrame, group: zarr.Group, diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 6a9d191f..9e71c4a3 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -21,7 +21,7 @@ def _read_table( group: zarr.Group, tables: dict[str, AnnData], on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, -) -> dict[str, AnnData]: +) -> None: """ Read in tables in the tables Zarr.group of a SpatialData Zarr store. @@ -85,7 +85,17 @@ def _read_table( count += 1 logger.debug(f"Found {count} elements in {group}") - return tables + + +class TablesReader: + def __call__( + self, + path: str, + group: zarr.Group, + container: dict[str, AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], + ) -> None: + return _read_table(path, group, container, on_bad_files) def write_table( diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 79a827a5..0f60bd1a 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -2,7 +2,7 @@ import warnings from json import JSONDecodeError from pathlib import Path -from typing import Literal +from typing import Literal, cast import zarr.storage from anndata import AnnData @@ -15,11 +15,55 @@ _resolve_zarr_store, handle_read_errors, ) -from spatialdata._io.io_points import _read_points -from spatialdata._io.io_raster import _read_multiscale -from spatialdata._io.io_shapes import _read_shapes -from spatialdata._io.io_table import _read_table +from spatialdata._io.io_points import PointsReader +from spatialdata._io.io_raster import MultiscaleReader +from spatialdata._io.io_shapes import ShapesReader +from spatialdata._io.io_table import TablesReader from spatialdata._logging import logger +from spatialdata.models import SpatialElement + +ReadClasses = MultiscaleReader | PointsReader | ShapesReader | TablesReader + + +def _read_zarr_group_spatialdata_element( + root_group: zarr.Group, + root_store_path: str, + selector: set[str], + read_func: ReadClasses, + group_name: Literal["images", "labels", "shapes", "points", "tables"], + element_type: Literal["image", "labels", "shapes", "points", "tables"], + element_container: dict[str, SpatialElement | AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, 
BadFileHandleMethod.WARN], +) -> None: + with handle_read_errors( + on_bad_files, + location=group_name, + exc_types=JSONDecodeError, + ): + if group_name in selector and group_name in root_group: + group = root_group[group_name] + if isinstance(read_func, TablesReader) and element_type == "tables": + read_func(root_store_path, group, element_container, on_bad_files=on_bad_files) + else: + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=(KeyError, ArrayNotFoundError, OSError, ArrowInvalid, JSONDecodeError), + ): + if isinstance(read_func, MultiscaleReader) and element_type in ("image", "labels"): + element = read_func(elem_group_path, cast(Literal["image", "labels"], element_type)) + if isinstance(read_func, PointsReader | ShapesReader): + element = read_func(elem_group_path) + element_container[subgroup_name] = element + count += 1 + logger.debug(f"Found {count} elements in {group}") def read_zarr( @@ -59,130 +103,38 @@ def read_zarr( root_group = zarr.open_group(resolved_store, mode="r") root_store_path = root_group.store.root - images = {} - labels = {} - points = {} + images: dict[str, SpatialElement] = {} + labels: dict[str, SpatialElement] = {} + points: dict[str, SpatialElement] = {} tables: dict[str, AnnData] = {} - shapes = {} + shapes: dict[str, SpatialElement] = {} selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") - # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. - # related to images / labels. 
-    with handle_read_errors(
-        on_bad_files,
-        location="images",
-        exc_types=JSONDecodeError,
-    ):
-        if "images" in selector and "images" in root_group:
-            group = root_group["images"]
-            count = 0
-            for subgroup_name in group:
-                if Path(subgroup_name).name.startswith("."):
-                    # skip hidden files like .zgroup or .zmetadata
-                    continue
-                elem_group = group[subgroup_name]
-                elem_group_path = os.path.join(root_store_path, elem_group.path)
-                with handle_read_errors(
-                    on_bad_files,
-                    location=f"{group.path}/{subgroup_name}",
-                    exc_types=(
-                        KeyError,
-                        ArrayNotFoundError,
-                        OSError,
-                    ),
-                ):
-                    element = _read_multiscale(elem_group_path, raster_type="image")
-                    images[subgroup_name] = element
-                    count += 1
-            logger.debug(f"Found {count} elements in {group}")
-
-    # read multiscale labels
-    with handle_read_errors(
-        on_bad_files,
-        location="labels",
-        exc_types=JSONDecodeError,
-    ):
-        if "labels" in selector and "labels" in root_group:
-            group = root_group["labels"]
-            count = 0
-            for subgroup_name in group:
-                if Path(subgroup_name).name.startswith("."):
-                    # skip hidden files like .zgroup or .zmetadata
-                    continue
-                elem_group = group[subgroup_name]
-                elem_group_path = root_store_path / elem_group.path
-                with handle_read_errors(
-                    on_bad_files,
-                    location=f"{group.path}/{subgroup_name}",
-                    exc_types=(
-                        KeyError,
-                        ArrayNotFoundError,
-                        OSError,
-                    ),
-                ):
-                    labels[subgroup_name] = _read_multiscale(elem_group_path, raster_type="labels")
-                    count += 1
-            logger.debug(f"Found {count} elements in {group}")
-    # now read rest of the data
-    with handle_read_errors(
-        on_bad_files,
-        location="points",
-        exc_types=JSONDecodeError,
-    ):
-        if "points" in selector and "points" in root_group:
-            group = root_group["points"]
-            count = 0
-            for subgroup_name in group:
-                elem_group = group[subgroup_name]
-                if Path(subgroup_name).name.startswith("."):
-                    # skip hidden files like .zgroup or .zmetadata
-                    continue
-                elem_group_path = os.path.join(root_store_path, elem_group.path)
-                with handle_read_errors(
-                    on_bad_files,
-                    location=f"{group.path}/{subgroup_name}",
-                    exc_types=(KeyError, ArrowInvalid, JSONDecodeError),
-                ):
-                    points[subgroup_name] = _read_points(elem_group_path)
-                    count += 1
-            logger.debug(f"Found {count} elements in {group}")
-
-    with handle_read_errors(
-        on_bad_files,
-        location="shapes",
-        exc_types=JSONDecodeError,
-    ):
-        if "shapes" in selector and "shapes" in root_group:
-            group = root_group["shapes"]
-            count = 0
-            for subgroup_name in group:
-                if Path(subgroup_name).name.startswith("."):
-                    # skip hidden files like .zgroup or .zmetadata
-                    continue
-                elem_group = group[subgroup_name]
-                elem_group_path = os.path.join(root_store_path, elem_group.path)
-                with handle_read_errors(
-                    on_bad_files,
-                    location=f"{group.path}/{subgroup_name}",
-                    exc_types=(
-                        JSONDecodeError,
-                        KeyError,
-                        ArrayNotFoundError,
-                    ),
-                ):
-                    shapes[subgroup_name] = _read_shapes(elem_group_path)
-                    count += 1
-            logger.debug(f"Found {count} elements in {group}")
-    if "tables" in selector and "tables" in root_group:
-        with handle_read_errors(
-            on_bad_files,
-            location="tables",
-            exc_types=JSONDecodeError,
-        ):
-            group = root_group["tables"]
-            tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files)
+    group_readers: dict[
+        Literal["images", "labels", "shapes", "points", "tables"],
+        tuple[
+            ReadClasses, Literal["image", "labels", "shapes", "points", "tables"], dict[str, SpatialElement | AnnData]
+        ],
+    ] = {
+        "images": (MultiscaleReader(), "image", images),
+        "labels": (MultiscaleReader(), "labels", labels),
+        "points": (PointsReader(), "points", points),
+        "shapes": (ShapesReader(), "shapes", shapes),
+        "tables": (TablesReader(), "tables", tables),
+    }
+    for group_name, (reader, raster_type, container) in group_readers.items():
+        _read_zarr_group_spatialdata_element(
+            root_group,
+            root_store_path,
+            selector,
+            reader,
+            group_name,
+            raster_type,
+            container,
+            on_bad_files,
+        )
 
     # read attrs metadata
     attrs = root_group.attrs.asdict()

From 953f36dd583e0ee9b0d4f3671fcd107c424b1635 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Fri, 12 Sep 2025 15:51:24 +0200
Subject: [PATCH 2/3] remove unnecessary checks

---
 src/spatialdata/_io/io_zarr.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py
index 0f60bd1a..47a144f4 100644
--- a/src/spatialdata/_io/io_zarr.py
+++ b/src/spatialdata/_io/io_zarr.py
@@ -42,7 +42,7 @@ def _read_zarr_group_spatialdata_element(
     ):
         if group_name in selector and group_name in root_group:
             group = root_group[group_name]
-            if isinstance(read_func, TablesReader) and element_type == "tables":
+            if isinstance(read_func, TablesReader):
                 read_func(root_store_path, group, element_container, on_bad_files=on_bad_files)
             else:
                 count = 0
@@ -57,7 +57,7 @@ def _read_zarr_group_spatialdata_element(
                         location=f"{group.path}/{subgroup_name}",
                         exc_types=(KeyError, ArrayNotFoundError, OSError, ArrowInvalid, JSONDecodeError),
                     ):
-                        if isinstance(read_func, MultiscaleReader) and element_type in ("image", "labels"):
+                        if isinstance(read_func, MultiscaleReader):
                             element = read_func(elem_group_path, cast(Literal["image", "labels"], element_type))
                         if isinstance(read_func, PointsReader | ShapesReader):
                             element = read_func(elem_group_path)

From 8c1f45a476474dc7032134f8d39a12ea57c8be77 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Wed, 17 Sep 2025 15:04:33 +0200
Subject: [PATCH 3/3] emit warning when old spatialdata storage version is detected

---
 src/spatialdata/_core/spatialdata.py | 5 ++---
 src/spatialdata/_io/io_zarr.py       | 7 +++++++
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index 47235501..3d6a9ed0 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1698,15 +1698,14 @@ def write_metadata(
             check_valid_name(element_name)
             if element_name not in self:
                 raise ValueError(f"Element with name {element_name} not found in SpatialData object.")
+        if write_attrs:
+            self.write_attrs(sdata_format=sdata_format)
 
         self.write_transformations(element_name)
         self.write_channel_names(element_name)
         # TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
         # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.
 
-        if write_attrs:
-            self.write_attrs(sdata_format=sdata_format)
-
         if self.has_consolidated_metadata():
             consolidate_metadata = True
         if consolidate_metadata:
diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py
index b0e1d32b..5367dcf0 100644
--- a/src/spatialdata/_io/io_zarr.py
+++ b/src/spatialdata/_io/io_zarr.py
@@ -133,6 +133,13 @@ def read_zarr(
     resolved_store = _resolve_zarr_store(store)
     root_group = zarr.open_group(resolved_store, mode="r")
    sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"]
+    if sdata_version == "0.1":
+        warnings.warn(
+            "SpatialData is not stored in the most current format. "
+            "If you want to use Zarr v3, please write the store to a new location.",
+            UserWarning,
+            stacklevel=2,
+        )
     root_store_path = root_group.store.root
 
     images: dict[str, SpatialElement] = {}
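
Reviewer note: from the caller's perspective the refactor is behavior-preserving; the five copy-pasted per-group blocks become one table-driven loop over reader callables. A minimal smoke test of the consolidated read path, where the store path and the selected element groups are illustrative, and on_bad_files="warn" assumes the string value of BadFileHandleMethod.WARN is accepted:

    from spatialdata import read_zarr

    # Read only two element groups; the generic loop dispatches MultiscaleReader
    # for "images" and PointsReader for "points".
    sdata = read_zarr("data.zarr", selection=("images", "points"))

    # Corrupted elements are skipped with a warning instead of raising, which
    # exercises the handle_read_errors contexts in _read_zarr_group_spatialdata_element.
    sdata_lenient = read_zarr("data.zarr", on_bad_files="warn")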
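
Similarly, a sketch of the storage-version warning added in PATCH 3/3, assuming old_store.zarr was written with spatialdata_attrs version "0.1" (the path is illustrative):

    import warnings

    from spatialdata import read_zarr

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        read_zarr("old_store.zarr")

    # The store is still read; the UserWarning only signals that it predates the
    # current storage format and should be rewritten to use Zarr v3.
    assert any("not stored in the most current format" in str(w.message) for w in caught)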