diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 1613d3e645..21b7d9df79 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -27,6 +27,7 @@ from monai.data.utils import ( affine_to_spacing, correct_nifti_header_if_necessary, + is_no_channel, is_supported_format, orientation_ras_lps, ) @@ -162,7 +163,7 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict): def _stack_images(image_list: list, meta_dict: dict): if len(image_list) <= 1: return image_list[0] - if meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) not in ("no_channel", None): + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) return np.concatenate(image_list, axis=channel_dim) # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified @@ -213,7 +214,7 @@ def __init__( ): super().__init__() self.kwargs = kwargs - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.series_name = series_name self.reverse_indexing = reverse_indexing self.series_meta = series_meta @@ -305,7 +306,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) if self.channel_dim is None: # default to "no_channel" or -1 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim @@ -435,7 +436,7 @@ class PydicomReader(ImageReader): def __init__( self, - channel_dim: int | None = None, + channel_dim: str | int | None = None, affine_lps_to_ras: bool = True, swap_ij: bool = True, prune_metadata: bool = True, @@ -444,7 +445,7 @@ def __init__( ): super().__init__() self.kwargs = kwargs - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.affine_lps_to_ras = affine_lps_to_ras self.swap_ij = swap_ij self.prune_metadata = prune_metadata @@ -629,7 +630,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: metadata[MetaKeys.AFFINE] = affine.copy() if self.channel_dim is None: # default to "no_channel" or -1 metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim @@ -883,14 +884,14 @@ class NibabelReader(ImageReader): @deprecated_arg("dtype", since="1.0", msg_suffix="please modify dtype of the returned by ``get_data`` instead.") def __init__( self, - channel_dim: int | None = None, + channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, dtype: DtypeLike = np.float32, **kwargs, ): super().__init__() - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims self.dtype = dtype # deprecated @@ -965,7 +966,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img_array.append(data) if self.channel_dim is None: # default to "no_channel" or -1 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim @@ -1018,8 +1019,8 @@ def _get_spatial_shape(self, img): dim = np.insert(dim, 0, 3) ndim = dim[0] size = list(dim[1:]) - if self.channel_dim is not None: - size.pop(self.channel_dim) + if not is_no_channel(self.channel_dim): + size.pop(int(self.channel_dim)) # type: ignore spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) @@ -1049,12 +1050,12 @@ class NumpyReader(ImageReader): """ - def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: int | None = None, **kwargs): + def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs): super().__init__() if npz_keys is not None: npz_keys = ensure_tuple(npz_keys) self.npz_keys = npz_keys - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.kwargs = kwargs def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: @@ -1126,7 +1127,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.SPACE] = SpaceKeys.RAS img_array.append(i) header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - self.channel_dim if isinstance(self.channel_dim, int) else "no_channel" + self.channel_dim if isinstance(self.channel_dim, int) else float("nan") ) _copy_compatible_dict(header, compatible_meta) @@ -1214,7 +1215,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i) img_array.append(data) header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 ) _copy_compatible_dict(header, compatible_meta) @@ -1532,13 +1533,13 @@ class NrrdReader(ImageReader): def __init__( self, - channel_dim: int | None = None, + channel_dim: str | int | None = None, dtype: np.dtype | type | str | None = np.float32, index_order: str = "F", affine_lps_to_ras: bool = True, **kwargs, ): - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.dtype = dtype self.index_order = index_order self.affine_lps_to_ras = affine_lps_to_ras @@ -1605,7 +1606,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: if self.channel_dim is None: # default to "no_channel" or -1 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim diff --git a/monai/data/utils.py b/monai/data/utils.py index 91358b2c63..6501122e2a 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -94,6 +94,7 @@ "remove_extra_metadata", "get_extra_metadata_keys", "PICKLE_KEY_SUFFIX", + "is_no_channel", ] # module to be used by `torch.save` @@ -1529,3 +1530,14 @@ def get_extra_metadata_keys() -> list[str]: # ] return keys + + +def is_no_channel(val) -> bool: + """Returns whether `val` indicates "no_channel", for MetaKeys.ORIGINAL_CHANNEL_DIM.""" + if isinstance(val, torch.Tensor): + return bool(torch.isnan(val)) + if isinstance(val, str): + return val == "no_channel" + if np.isscalar(val): + return bool(np.isnan(val)) + return val is None diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index dac8ea8ad5..3cbaab1430 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,7 +32,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.data.utils import no_collation +from monai.data.utils import is_no_channel, no_collation from monai.networks.layers.simplelayers import ( ApplyFilter, EllipticalFilter, @@ -54,6 +54,7 @@ ) from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices from monai.utils import ( + MetaKeys, TraceKeys, convert_data_type, convert_to_cupy, @@ -267,9 +268,9 @@ def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch if isinstance(img, MetaTensor): meta_dict = img.meta - channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None + channel_dim = meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) if isinstance(meta_dict, Mapping) else None if self.input_channel_dim is not None: - channel_dim = self.input_channel_dim + channel_dim = float("nan") if self.input_channel_dim == "no_channel" else self.input_channel_dim if channel_dim is None: msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`." @@ -280,12 +281,12 @@ def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch # track the original channel dim if isinstance(meta_dict, dict): - meta_dict["original_channel_dim"] = channel_dim + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = channel_dim - if channel_dim == "no_channel": + if is_no_channel(channel_dim): result = img[None] else: - result = moveaxis(img, channel_dim, 0) # type: ignore + result = moveaxis(img, int(channel_dim), 0) # type: ignore return convert_to_tensor(result, track_meta=get_track_meta()) # type: ignore @@ -371,8 +372,6 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: Apply the transform to `img`. """ n_out = img.shape[self.dim] - if n_out <= 1: - raise RuntimeError(f"Input image is singleton along dimension to be split, got shape {img.shape}.") if isinstance(img, torch.Tensor): outputs = list(torch.split(img, 1, self.dim)) else: diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 3fb2a5bc6b..77529acdef 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -518,7 +518,7 @@ class MetaKeys(StrEnum): ORIGINAL_AFFINE = "original_affine" # the affine after image loading before any data processing SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` - ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or "no_channel" + ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") class ColorOrder(StrEnum): diff --git a/tests/test_splitdim.py b/tests/test_splitdim.py index 38775dba5c..6c678a6bc2 100644 --- a/tests/test_splitdim.py +++ b/tests/test_splitdim.py @@ -40,13 +40,12 @@ def test_correct_shape(self, shape, keepdim, im_type): arr[0, 0, 0, 0] *= 2 self.assertEqual(arr.flatten()[0], out[0].flatten()[0]) - def test_error(self): - """Should fail because splitting along singleton dimension""" + def test_singleton(self): shape = (2, 1, 8, 7) for p in TEST_NDARRAYS: arr = p(np.random.rand(*shape)) - with self.assertRaises(RuntimeError): - _ = SplitDim(dim=1)(arr) + out = SplitDim(dim=1)(arr) + self.assertEqual(out[0].shape, shape) if __name__ == "__main__": diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index e512de6b03..b01913269d 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -40,7 +40,7 @@ def setUpClass(cls) -> None: affine = make_rand_affine() data = {"i": make_nifti_image(arr, affine)} - loader = LoadImaged("i") + loader = LoadImaged("i", image_only=True) cls.data = loader(data) @parameterized.expand(TESTS) @@ -84,13 +84,12 @@ def test_correct(self, keepdim, im_type, update_meta, list_output): arr[0, 0, 0, 0] *= 2 self.assertEqual(arr.flatten()[0], out.flatten()[0]) - def test_error(self): - """Should fail because splitting along singleton dimension""" + def test_singleton(self): shape = (2, 1, 8, 7) for p in TEST_NDARRAYS: arr = p(np.random.rand(*shape)) - with self.assertRaises(RuntimeError): - _ = SplitDimd("i", dim=1)({"i": arr}) + out = SplitDimd("i", dim=1)({"i": arr}) + self.assertEqual(out["i"].shape, shape) if __name__ == "__main__":