diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index a1ab8e8a56..31e9b5d632 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -21,7 +21,14 @@ import numpy as np import torch -from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq +from monai.utils import ( + deprecated_arg, + ensure_tuple_size, + get_package_version, + optional_import, + require_pkg, + version_geq, +) if TYPE_CHECKING: import zarr @@ -218,15 +225,41 @@ class ZarrAvgMerger(Merger): store: the zarr store to save the final results. Default is "merged.zarr". value_store: the zarr store to save the value aggregating tensor. Default is a temporary store. count_store: the zarr store to save the sample counting tensor. Default is a temporary store. - compressor: the compressor for final merged zarr array. Default is "default". + compressor: the compressor for final merged zarr array. Default is None. + Deprecated since 1.5.0 and will be removed in 1.7.0. Use codecs instead. value_compressor: the compressor for value aggregating zarr array. Default is None. + Deprecated since 1.5.0 and will be removed in 1.7.0. Use value_codecs instead. count_compressor: the compressor for sample counting zarr array. Default is None. + Deprecated since 1.5.0 and will be removed in 1.7.0. Use count_codecs instead. + codecs: the codecs for final merged zarr array. Default is None. + For zarr v3, this is a list of codec configurations. See zarr documentation for details. + value_codecs: the codecs for value aggregating zarr array. Default is None. + For zarr v3, this is a list of codec configurations. See zarr documentation for details. + count_codecs: the codecs for sample counting zarr array. Default is None. + For zarr v3, this is a list of codec configurations. See zarr documentation for details. chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True. If True, chunk shape will be guessed from `shape` and `dtype`. If False, it will be set to `shape`, i.e., single chunk for the whole array. If an int, the chunk size in each dimension will be given by the value of `chunks`. """ + @deprecated_arg( + name="compressor", since="1.5.0", removed="1.7.0", new_name="codecs", msg_suffix="Please use 'codecs' instead." + ) + @deprecated_arg( + name="value_compressor", + since="1.5.0", + removed="1.7.0", + new_name="value_codecs", + msg_suffix="Please use 'value_codecs' instead.", + ) + @deprecated_arg( + name="count_compressor", + since="1.5.0", + removed="1.7.0", + new_name="count_codecs", + msg_suffix="Please use 'count_codecs' instead.", + ) def __init__( self, merged_shape: Sequence[int], @@ -240,6 +273,9 @@ def __init__( compressor: str | None = None, value_compressor: str | None = None, count_compressor: str | None = None, + codecs: list | None = None, + value_codecs: list | None = None, + count_codecs: list | None = None, chunks: Sequence[int] | bool = True, thread_locking: bool = True, ) -> None: @@ -251,7 +287,11 @@ def __init__( self.count_dtype = count_dtype self.store = store self.tmpdir: TemporaryDirectory | None - if version_geq(get_package_version("zarr"), "3.0.0"): + + # Handle zarr v3 vs older versions + is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0") + + if is_zarr_v3: if value_store is None: self.tmpdir = TemporaryDirectory() self.value_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore @@ -266,34 +306,119 @@ def __init__( self.tmpdir = None self.value_store = zarr.storage.TempStore() if value_store is None else value_store # type: ignore self.count_store = zarr.storage.TempStore() if count_store is None else count_store # type: ignore + self.chunks = chunks - self.compressor = compressor - self.value_compressor = value_compressor - self.count_compressor = count_compressor - self.output = zarr.empty( - shape=self.merged_shape, - chunks=self.chunks, - dtype=self.output_dtype, - compressor=self.compressor, - store=self.store, - overwrite=True, - ) - self.values = zarr.zeros( - shape=self.merged_shape, - chunks=self.chunks, - dtype=self.value_dtype, - compressor=self.value_compressor, - store=self.value_store, - overwrite=True, - ) - self.counts = zarr.zeros( - shape=self.merged_shape, - chunks=self.chunks, - dtype=self.count_dtype, - compressor=self.count_compressor, - store=self.count_store, - overwrite=True, - ) + + # Handle compressor/codecs based on zarr version + is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0") + + # Initialize codecs/compressor attributes with proper types + self.codecs: list | None = None + self.value_codecs: list | None = None + self.count_codecs: list | None = None + + if is_zarr_v3: + # For zarr v3, use codecs or convert compressor to codecs + if codecs is not None: + self.codecs = codecs + elif compressor is not None: + # Convert compressor to codec format + if isinstance(compressor, (list, tuple)): + self.codecs = compressor + else: + self.codecs = [compressor] + else: + self.codecs = None + + if value_codecs is not None: + self.value_codecs = value_codecs + elif value_compressor is not None: + if isinstance(value_compressor, (list, tuple)): + self.value_codecs = value_compressor + else: + self.value_codecs = [value_compressor] + else: + self.value_codecs = None + + if count_codecs is not None: + self.count_codecs = count_codecs + elif count_compressor is not None: + if isinstance(count_compressor, (list, tuple)): + self.count_codecs = count_compressor + else: + self.count_codecs = [count_compressor] + else: + self.count_codecs = None + else: + # For zarr v2, use compressors + if codecs is not None: + # If codecs are specified in v2, use the first codec as compressor + self.codecs = codecs[0] if isinstance(codecs, (list, tuple)) else codecs + else: + self.codecs = compressor # type: ignore[assignment] + + if value_codecs is not None: + self.value_codecs = value_codecs[0] if isinstance(value_codecs, (list, tuple)) else value_codecs + else: + self.value_codecs = value_compressor # type: ignore[assignment] + + if count_codecs is not None: + self.count_codecs = count_codecs[0] if isinstance(count_codecs, (list, tuple)) else count_codecs + else: + self.count_codecs = count_compressor # type: ignore[assignment] + + # Create zarr arrays with appropriate parameters based on version + if is_zarr_v3: + self.output = zarr.empty( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.output_dtype, + codecs=self.codecs, + store=self.store, + overwrite=True, + ) + self.values = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.value_dtype, + codecs=self.value_codecs, + store=self.value_store, + overwrite=True, + ) + self.counts = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.count_dtype, + codecs=self.count_codecs, + store=self.count_store, + overwrite=True, + ) + else: + self.output = zarr.empty( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.output_dtype, + compressor=self.codecs, + store=self.store, + overwrite=True, + ) + self.values = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.value_dtype, + compressor=self.value_codecs, + store=self.value_store, + overwrite=True, + ) + self.counts = zarr.zeros( + shape=self.merged_shape, + chunks=self.chunks, + dtype=self.count_dtype, + compressor=self.count_codecs, + store=self.count_store, + overwrite=True, + ) + self.lock: threading.Lock | nullcontext if thread_locking: # use lock to protect the in-place addition during aggregation diff --git a/tests/inferers/test_zarr_avg_merger.py b/tests/inferers/test_zarr_avg_merger.py index b5ba1d9902..8dfd4cd96b 100644 --- a/tests/inferers/test_zarr_avg_merger.py +++ b/tests/inferers/test_zarr_avg_merger.py @@ -24,6 +24,7 @@ np.seterr(divide="ignore", invalid="ignore") zarr, has_zarr = optional_import("zarr") +print(version_geq(get_package_version("zarr"), "3.0.0")) if has_zarr: if version_geq(get_package_version("zarr"), "3.0.0"): directory_store = zarr.storage.LocalStore("test.zarr") @@ -200,9 +201,20 @@ TENSOR_4x4, ] -# test for LZ4 compressor +# Define zarr v3 codec configurations with proper bytes codec +ZARR_V3_LZ4_CODECS = [{"name": "bytes", "configuration": {}}, {"name": "blosc", "configuration": {"cname": "lz4"}}] + +ZARR_V3_PICKLE_CODECS = [{"name": "bytes", "configuration": {}}, {"name": "blosc", "configuration": {"cname": "zstd"}}] + +ZARR_V3_LZMA_CODECS = [{"name": "bytes", "configuration": {}}, {"name": "blosc", "configuration": {"cname": "zlib"}}] + +# test for LZ4 compressor (zarr v2) or codecs (zarr v3) TEST_CASE_13_COMPRESSOR_LZ4 = [ - dict(merged_shape=TENSOR_4x4.shape, compressor="LZ4"), + ( + dict(merged_shape=TENSOR_4x4.shape, compressor="LZ4") + if not version_geq(get_package_version("zarr"), "3.0.0") + else dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZ4_CODECS) + ), [ (TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), @@ -212,9 +224,13 @@ TENSOR_4x4, ] -# test for pickle compressor +# test for pickle compressor (zarr v2) or codecs (zarr v3) TEST_CASE_14_COMPRESSOR_PICKLE = [ - dict(merged_shape=TENSOR_4x4.shape, compressor="Pickle"), + ( + dict(merged_shape=TENSOR_4x4.shape, compressor="Pickle") + if not version_geq(get_package_version("zarr"), "3.0.0") + else dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_PICKLE_CODECS) + ), [ (TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), @@ -224,9 +240,13 @@ TENSOR_4x4, ] -# test for LZMA compressor +# test for LZMA compressor (zarr v2) or codecs (zarr v3) TEST_CASE_15_COMPRESSOR_LZMA = [ - dict(merged_shape=TENSOR_4x4.shape, compressor="LZMA"), + ( + dict(merged_shape=TENSOR_4x4.shape, compressor="LZMA") + if not version_geq(get_package_version("zarr"), "3.0.0") + else dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZMA_CODECS) + ), [ (TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), @@ -260,6 +280,48 @@ TENSOR_4x4, ] +# test with codecs for zarr v3 +TEST_CASE_18_CODECS = [ + dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZ4_CODECS), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test with value_codecs for zarr v3 +TEST_CASE_19_VALUE_CODECS = [ + dict( + merged_shape=TENSOR_4x4.shape, + value_codecs=[{"name": "bytes", "configuration": {}}, {"name": "blosc", "configuration": {"cname": "zstd"}}], + ), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# test with count_codecs for zarr v3 +TEST_CASE_20_COUNT_CODECS = [ + dict( + merged_shape=TENSOR_4x4.shape, + count_codecs=[{"name": "bytes", "configuration": {}}, {"name": "blosc", "configuration": {"cname": "zlib"}}], + ), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + ALL_TESTS = [ TEST_CASE_0_DEFAULT_DTYPE, TEST_CASE_1_DEFAULT_DTYPE, @@ -276,11 +338,15 @@ TEST_CASE_12_CHUNKS, TEST_CASE_16_WITH_LOCK, TEST_CASE_17_WITHOUT_LOCK, + # Add compression/codec tests regardless of zarr version - they're now version-aware + TEST_CASE_13_COMPRESSOR_LZ4, + TEST_CASE_14_COMPRESSOR_PICKLE, + TEST_CASE_15_COMPRESSOR_LZMA, ] -# add compression tests only when using Zarr version before 3.0 -if not version_geq(get_package_version("zarr"), "3.0.0"): - ALL_TESTS += [TEST_CASE_13_COMPRESSOR_LZ4, TEST_CASE_14_COMPRESSOR_PICKLE, TEST_CASE_15_COMPRESSOR_LZMA] +# Add zarr v3 specific codec tests only when using Zarr version 3.0 or later +if version_geq(get_package_version("zarr"), "3.0.0"): + ALL_TESTS += [TEST_CASE_18_CODECS, TEST_CASE_19_VALUE_CODECS, TEST_CASE_20_COUNT_CODECS] @unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)") @@ -288,16 +354,57 @@ class ZarrAvgMergerTests(unittest.TestCase): @parameterized.expand(ALL_TESTS) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0") codec_reg = numcodecs.registry.codec_registry - if "compressor" in arguments: - if arguments["compressor"] != "default": + + # Handle compressor/codecs based on zarr version + if "compressor" in arguments and is_zarr_v3: + # For zarr v3, convert compressor to codecs + if arguments["compressor"] != "default" and arguments["compressor"] is not None: + compressor_name = arguments["compressor"].lower() + if compressor_name == "lz4": + arguments["codecs"] = ZARR_V3_LZ4_CODECS + elif compressor_name == "pickle": + arguments["codecs"] = ZARR_V3_PICKLE_CODECS + elif compressor_name == "lzma": + arguments["codecs"] = ZARR_V3_LZMA_CODECS + # Remove compressor as it's not supported in zarr v3 + del arguments["compressor"] + elif "compressor" in arguments and not is_zarr_v3: + # For zarr v2, use the compressor registry + if arguments["compressor"] != "default" and arguments["compressor"] is not None: arguments["compressor"] = codec_reg[arguments["compressor"].lower()]() - if "value_compressor" in arguments: - if arguments["value_compressor"] != "default": + + # Same for value_compressor + if "value_compressor" in arguments and is_zarr_v3: + if arguments["value_compressor"] != "default" and arguments["value_compressor"] is not None: + compressor_name = arguments["value_compressor"].lower() + if compressor_name == "lz4": + arguments["value_codecs"] = ZARR_V3_LZ4_CODECS + elif compressor_name == "pickle": + arguments["value_codecs"] = ZARR_V3_PICKLE_CODECS + elif compressor_name == "lzma": + arguments["value_codecs"] = ZARR_V3_LZMA_CODECS + del arguments["value_compressor"] + elif "value_compressor" in arguments and not is_zarr_v3: + if arguments["value_compressor"] != "default" and arguments["value_compressor"] is not None: arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]() - if "count_compressor" in arguments: - if arguments["count_compressor"] != "default": + + # Same for count_compressor + if "count_compressor" in arguments and is_zarr_v3: + if arguments["count_compressor"] != "default" and arguments["count_compressor"] is not None: + compressor_name = arguments["count_compressor"].lower() + if compressor_name == "lz4": + arguments["count_codecs"] = ZARR_V3_LZ4_CODECS + elif compressor_name == "pickle": + arguments["count_codecs"] = ZARR_V3_PICKLE_CODECS + elif compressor_name == "lzma": + arguments["count_codecs"] = ZARR_V3_LZMA_CODECS + del arguments["count_compressor"] + elif "count_compressor" in arguments and not is_zarr_v3: + if arguments["count_compressor"] != "default" and arguments["count_compressor"] is not None: arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]() + merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1]) @@ -320,7 +427,3 @@ def test_zarr_avg_merger_finalized_error(self): def test_zarr_avg_merge_none_merged_shape_error(self): with self.assertRaises(ValueError): ZarrAvgMerger(merged_shape=None) - - -if __name__ == "__main__": - unittest.main()