diff --git a/docs/requirements.txt b/docs/requirements.txt index 9369548c67..07b189dd79 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -37,3 +37,4 @@ optuna opencv-python-headless onnx>=1.13.0 onnxruntime; python_version <= '3.10' +zarr diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 3bf6af15b0..0011a489f3 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -77,6 +77,11 @@ Mergers :members: :special-members: __call__ +`ZarrAvgMerger` +~~~~~~~~~~~~~~~ +.. autoclass:: ZarrAvgMerger + :members: + :special-members: __call__ Sliding Window Inference Function diff --git a/docs/source/installation.md b/docs/source/installation.md index c3e7297da6..eb7adb06fb 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index bbd361ca79..960380bfb8 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -20,6 +20,6 @@ SlidingWindowInferer, SlidingWindowInfererAdapt, ) -from .merger import AvgMerger, Merger +from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter from .utils import sliding_window_inference diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 63c39aed24..9901951928 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -11,15 +11,24 @@ from __future__ import annotations +import threading from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any +import numpy as np import torch -from monai.utils import ensure_tuple_size +from monai.utils import ensure_tuple_size, optional_import, require_pkg -__all__ = ["Merger", "AvgMerger"] +if TYPE_CHECKING: + import zarr +else: + zarr, _ = optional_import("zarr") + + +__all__ = ["Merger", "AvgMerger", "ZarrAvgMerger"] class Merger(ABC): @@ -97,9 +106,9 @@ def __init__( self, merged_shape: Sequence[int], cropped_shape: Sequence[int] | None = None, - device: torch.device | str = "cpu", value_dtype: torch.dtype = torch.float32, count_dtype: torch.dtype = torch.uint8, + device: torch.device | str = "cpu", ) -> None: super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device) if not self.merged_shape: @@ -152,12 +161,21 @@ def finalize(self) -> torch.Tensor: return self.values + def get_output(self) -> torch.Tensor: + """ + Get the final merged output. + + Returns: + torch.Tensor: merged output. + """ + return self.finalize() + def get_values(self) -> torch.Tensor: """ Get the accumulated values during aggregation or final averaged values after it is finalized. Returns: - Merged (averaged) output tensor. + torch.tensor: aggregated values. Notes: - If called before calling `finalize()`, this method returns the accumulating values. @@ -170,6 +188,195 @@ def get_counts(self) -> torch.Tensor: Get the aggregator tensor for number of samples. Returns: - torch.Tensor: Number of accumulated samples at each location. + torch.Tensor: number of accumulated samples at each location. """ return self.counts + + +@require_pkg(pkg_name="zarr") +class ZarrAvgMerger(Merger): + """Merge patches by taking average of the overlapping area and store the results in zarr array. + + Zarr is a format for the storage of chunked, compressed, N-dimensional arrays. + Zarr data can be stored in any storage system that can be represented as a key-value store, + like POSIX file systems, cloud object storage, zip files, and relational and document databases. + See https://zarr.readthedocs.io/en/stable/ for more details. + It is particularly useful for storing N-dimensional arrays too large to fit into memory. + One specific use case of this class is to merge patches extracted from whole slide images (WSI), + where the merged results do not fit into memory and need to be stored on a file system. + + Args: + merged_shape: the shape of the tensor required to merge the patches. + cropped_shape: the shape of the final merged output tensor. + If not provided, it will be the same as `merged_shape`. + dtype: the dtype for the final merged result. Default is `float32`. + value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`. + count_dtype: the dtype for sample counting tensor. Default is `uint8`. + store: the zarr store to save the final results. Default is "merged.zarr". + value_store: the zarr store to save the value aggregating tensor. Default is a temporary store. + count_store: the zarr store to save the sample counting tensor. Default is a temporary store. + compressor: the compressor for final merged zarr array. Default is "default". + value_compressor: the compressor for value aggregating zarr array. Default is None. + count_compressor: the compressor for sample counting zarr array. Default is None. + chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True. + If True, chunk shape will be guessed from `shape` and `dtype`. + If False, it will be set to `shape`, i.e., single chunk for the whole array. + If an int, the chunk size in each dimension will be given by the value of `chunks`. + """ + + def __init__( + self, + merged_shape: Sequence[int], + cropped_shape: Sequence[int] | None = None, + dtype: np.dtype | str = "float32", + value_dtype: np.dtype | str = "float32", + count_dtype: np.dtype | str = "uint8", + store: zarr.storage.Store | str = "merged.zarr", + value_store: zarr.storage.Store | str | None = None, + count_store: zarr.storage.Store | str | None = None, + compressor: str = "default", + value_compressor: str | None = None, + count_compressor: str | None = None, + chunks: Sequence[int] | bool = True, + thread_locking: bool = True, + ) -> None: + super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape) + if not self.merged_shape: + raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.") + self.output_dtype = dtype + self.value_dtype = value_dtype + self.count_dtype = count_dtype + self.store = store + self.value_store = zarr.storage.TempStore() if value_store is None else value_store + self.count_store = zarr.storage.TempStore() if count_store is None else count_store + self.chunks = chunks + self.compressor = compressor + self.value_compressor = value_compressor + self.count_compressor = count_compressor + self.output = zarr.empty( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.output_dtype, + compressor=self.compressor, + store=self.store, + overwrite=True, + ) + self.values = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.value_dtype, + compressor=self.value_compressor, + store=self.value_store, + overwrite=True, + ) + self.counts = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.count_dtype, + compressor=self.count_compressor, + store=self.count_store, + overwrite=True, + ) + self.lock: threading.Lock | nullcontext + if thread_locking: + # use lock to protect the in-place addition during aggregation + self.lock = threading.Lock() + else: + # use nullcontext to avoid the locking if not needed + self.lock = nullcontext() + + def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: + """ + Aggregate values for merging. + + Args: + values: a tensor of shape BCHW[D], representing the values of inference output. + location: a tuple/list giving the top left location of the patch in the original image. + """ + if self.is_finalized: + raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.") + patch_size = values.shape[2:] + map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size)) + map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True) + with self.lock: + self.values[map_slice] += values.numpy() + self.counts[map_slice] += 1 + + def finalize(self) -> zarr.Array: + """ + Finalize merging by dividing values by counts and return the merged tensor. + + Notes: + To avoid creating a new tensor for the final results (to save memory space), + after this method is called, `get_values()` method will return the "final" averaged values, + and not the accumulating values. Also calling `finalize()` multiple times does not have any effect. + + Returns: + zarr.Array: a zarr array of of merged patches + """ + # guard against multiple calls to finalize + if not self.is_finalized: + # use chunks for division to fit into memory + for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): + self.output[chunk] = self.values[chunk] / self.counts[chunk] + # finalize the shape + self.output.resize(self.cropped_shape) + # set finalize flag to protect performing in-place division again + self.is_finalized = True + + return self.output + + def get_output(self) -> zarr.Array: + """ + Get the final merged output. + + Returns: + zarr.Array: Merged (averaged) output tensor. + """ + return self.output + + def get_values(self) -> zarr.Array: + """ + Get the accumulated values during aggregation + + Returns: + zarr.Array: aggregated values. + + """ + return self.values + + def get_counts(self) -> zarr.Array: + """ + Get the aggregator tensor for number of samples. + + Returns: + zarr.Array: Number of accumulated samples at each location. + """ + return self.counts + + +def iterate_over_chunks(chunks, cdata_shape, slice_tuple=()): + """ + Iterate over chunks of a given shape. + + Args: + chunks: the chunk shape + cdata_shape: the shape of the data in chunks + slice_tuple: the slice tuple to be used for indexing + + Raises: + ValueError: When the length of chunks and cdata_shape are not the same. + + Yields: + slices of the data + """ + if len(chunks) != len(cdata_shape): + raise ValueError("chunks and cdata_shape must have the same length") + if len(chunks) == 1: + for i in range(cdata_shape[0]): + yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),) + else: + for i in range(cdata_shape[0]): + yield from iterate_over_chunks( + chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),) + ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f733ac723..78e3b7381a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -52,3 +52,4 @@ onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 +zarr diff --git a/setup.cfg b/setup.cfg index c7dcf384b8..c218b133ee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,6 +79,7 @@ all = optuna onnx>=1.13.0 onnxruntime; python_version <= '3.10' + zarr nibabel = nibabel ninja = @@ -142,6 +143,8 @@ optuna = onnx = onnx>=1.13.0 onnxruntime; python_version <= '3.10' +zarr = + zarr # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded diff --git a/tests/min_tests.py b/tests/min_tests.py index 2fc22452d0..f553dc4a50 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -202,6 +202,7 @@ def run_testsuit(): "test_metrics_reloaded", "test_spatial_combine_transforms", "test_bundle_workflow", + "test_zarr_avg_merger", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_download_url_yandex.py b/tests/test_download_url_yandex.py index d0946f9f70..a08105a93f 100644 --- a/tests/test_download_url_yandex.py +++ b/tests/test_download_url_yandex.py @@ -29,6 +29,7 @@ class TestDownloadUrlYandex(unittest.TestCase): + @unittest.skip("data source unstable") def test_verify(self): with tempfile.TemporaryDirectory() as tempdir: download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, "model.pt")) diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py new file mode 100644 index 0000000000..cbc713b442 --- /dev/null +++ b/tests/test_zarr_avg_merger.py @@ -0,0 +1,321 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized +from torch.nn.functional import pad + +from monai.inferers import ZarrAvgMerger +from monai.utils import optional_import +from tests.utils import assert_allclose + +np.seterr(divide="ignore", invalid="ignore") +zarr, has_zarr = optional_import("zarr") +numcodecs, has_numcodecs = optional_import("numcodecs") + +TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) +TENSOR_4x4_WITH_NAN = TENSOR_4x4.clone() +TENSOR_4x4_WITH_NAN[..., 2:, 2:] = float("nan") + +# no-overlapping 2x2 +TEST_CASE_0_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# overlapping 2x2 +TEST_CASE_1_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4[..., 2:4, 2:4], (2, 2)), + ], + TENSOR_4x4, +] + +# overlapping 3x3 (non-divisible) +TEST_CASE_2_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., :3, :3], (0, 0)), + (TENSOR_4x4[..., :3, 1:], (0, 1)), + (TENSOR_4x4[..., 1:, :3], (1, 0)), + (TENSOR_4x4[..., 1:, 1:], (1, 1)), + ], + TENSOR_4x4, +] + +# overlapping 2x2 with NaN values +TEST_CASE_3_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4_WITH_NAN.shape), + [ + (TENSOR_4x4_WITH_NAN[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4_WITH_NAN[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4_WITH_NAN[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 2:4], (2, 2)), + ], + TENSOR_4x4_WITH_NAN, +] + +# non-overlapping 2x2 with missing patch +TEST_CASE_4_DEFAULT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape), + [(TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), (TENSOR_4x4[..., 2:, :2], (2, 0))], + TENSOR_4x4_WITH_NAN, +] + +# with value_dtype set to half precision +TEST_CASE_5_VALUE_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, value_dtype=np.float16), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with count_dtype set to int32 +TEST_CASE_6_COUNT_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, count_dtype=np.int32), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with both value_dtype, count_dtype set to double precision +TEST_CASE_7_COUNT_VALUE_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, value_dtype=np.float64, count_dtype=np.float64), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with both value_dtype, count_dtype set to double precision +TEST_CASE_8_DTYPE = [ + dict(merged_shape=TENSOR_4x4.shape, dtype=np.float64), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +# shape larger than what is covered by patches +TEST_CASE_9_LARGER_SHAPE = [ + dict(merged_shape=(2, 3, 4, 6)), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + pad(TENSOR_4x4, (0, 2), value=float("nan")), +] + + +# explicit directory store +TEST_CASE_10_DIRECTORY_STORE = [ + dict(merged_shape=TENSOR_4x4.shape, store=zarr.storage.DirectoryStore("test.zarr")), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# memory store for all arrays +TEST_CASE_11_MEMORY_STORE = [ + dict( + merged_shape=TENSOR_4x4.shape, + store=zarr.storage.MemoryStore(), + value_store=zarr.storage.MemoryStore(), + count_store=zarr.storage.MemoryStore(), + ), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +# explicit chunk size +TEST_CASE_12_CHUNKS = [ + dict(merged_shape=TENSOR_4x4.shape, chunks=(1, 1, 2, 2)), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +# test for LZ4 compressor +TEST_CASE_13_COMPRESSOR_LZ4 = [ + dict(merged_shape=TENSOR_4x4.shape, compressor="LZ4"), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test for pickle compressor +TEST_CASE_14_COMPRESSOR_PICKLE = [ + dict(merged_shape=TENSOR_4x4.shape, compressor="Pickle"), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test for LZMA compressor +TEST_CASE_15_COMPRESSOR_LZMA = [ + dict(merged_shape=TENSOR_4x4.shape, compressor="LZMA"), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +# test with thread locking +TEST_CASE_16_WITH_LOCK = [ + dict(merged_shape=TENSOR_4x4.shape, thread_locking=True), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test without thread locking +TEST_CASE_17_WITHOUT_LOCK = [ + dict(merged_shape=TENSOR_4x4.shape, thread_locking=False), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +@unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)") +class ZarrAvgMergerTests(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0_DEFAULT_DTYPE, + TEST_CASE_1_DEFAULT_DTYPE, + TEST_CASE_2_DEFAULT_DTYPE, + TEST_CASE_3_DEFAULT_DTYPE, + TEST_CASE_4_DEFAULT_DTYPE, + TEST_CASE_5_VALUE_DTYPE, + TEST_CASE_6_COUNT_DTYPE, + TEST_CASE_7_COUNT_VALUE_DTYPE, + TEST_CASE_8_DTYPE, + TEST_CASE_9_LARGER_SHAPE, + TEST_CASE_10_DIRECTORY_STORE, + TEST_CASE_11_MEMORY_STORE, + TEST_CASE_12_CHUNKS, + TEST_CASE_13_COMPRESSOR_LZ4, + TEST_CASE_14_COMPRESSOR_PICKLE, + TEST_CASE_15_COMPRESSOR_LZMA, + TEST_CASE_16_WITH_LOCK, + TEST_CASE_17_WITHOUT_LOCK, + ] + ) + def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + if "compressor" in arguments: + if arguments["compressor"] != "default": + arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]() + if "value_compressor" in arguments: + if arguments["value_compressor"] != "default": + arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]() + if "count_compressor" in arguments: + if arguments["count_compressor"] != "default": + arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]() + merger = ZarrAvgMerger(**arguments) + for pl in patch_locations: + merger.aggregate(pl[0], pl[1]) + output = merger.finalize() + if "value_dtype" in arguments: + self.assertTrue(merger.get_values().dtype, arguments["value_dtype"]) + if "count_dtype" in arguments: + self.assertTrue(merger.get_counts().dtype, arguments["count_dtype"]) + # check for multiple call of finalize + self.assertIs(output, merger.finalize()) + # check if the result is matching the expectation + assert_allclose(output[:], expected.numpy()) + + def test_zarr_avg_merger_finalized_error(self): + with self.assertRaises(ValueError): + merger = ZarrAvgMerger(merged_shape=(1, 3, 2, 3)) + merger.finalize() + merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3)) + + def test_zarr_avg_merge_none_merged_shape_error(self): + with self.assertRaises(ValueError): + ZarrAvgMerger(merged_shape=None) + + +if __name__ == "__main__": + unittest.main()