diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b0ba1e39d9..2ea7e3aa63 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -518,4 +518,4 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import moveaxis +from .utils_pytorch_numpy_unification import in1d, moveaxis diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f38a94302e..918763405f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,7 +22,7 @@ import numpy as np import torch -from monai.config import DtypeLike, NdarrayTensor +from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( @@ -31,9 +31,10 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import from monai.utils.enums import TransformBackends +from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_data_type PILImageImage, has_pil = optional_import("PIL.Image", name="Image") @@ -445,6 +446,8 @@ class SqueezeDim(Transform): Squeeze a unitary dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, dim: Optional[int] = 0) -> None: """ Args: @@ -459,12 +462,17 @@ def __init__(self, dim: Optional[int] = 0) -> None: raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.") self.dim = dim - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: numpy arrays with required dimension `dim` removed """ - return img.squeeze(self.dim) # type: ignore + if self.dim is None: + return img.squeeze() + # for pytorch/numpy unification + if img.shape[self.dim] != 1: + raise ValueError("Can only squeeze singleton dimension") + return img.squeeze(self.dim) class DataStats(Transform): @@ -475,6 +483,8 @@ class DataStats(Transform): so it can be used in pre-processing and post-processing. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, prefix: str = "Data", @@ -523,14 +533,14 @@ def __init__( def __call__( self, - img: NdarrayTensor, + img: NdarrayOrTensor, prefix: Optional[str] = None, data_type: Optional[bool] = None, data_shape: Optional[bool] = None, value_range: Optional[bool] = None, data_value: Optional[bool] = None, additional_info: Optional[Callable] = None, - ) -> NdarrayTensor: + ) -> NdarrayOrTensor: """ Apply the transform to `img`, optionally take arguments similar to the class constructor. """ @@ -570,6 +580,8 @@ class SimulateDelay(Transform): to sub-optimal design choices. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, delay_time: float = 0.0) -> None: """ Args: @@ -579,7 +591,7 @@ def __init__(self, delay_time: float = 0.0) -> None: super().__init__() self.delay_time: float = delay_time - def __call__(self, img: NdarrayTensor, delay_time: Optional[float] = None) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor, delay_time: Optional[float] = None) -> NdarrayOrTensor: """ Args: img: data remain unchanged throughout this transform. @@ -612,12 +624,14 @@ class Lambda(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, func: Optional[Callable] = None) -> None: if func is not None and not callable(func): raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func - def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None): """ Apply `self.func` to `img`. @@ -648,14 +662,15 @@ class RandLambda(Lambda, RandomizableTransform): prob: probability of executing the random function, default to 1.0, with 100% probability to execute. For more details, please check :py:class:`monai.transforms.Lambda`. - """ + backend = Lambda.backend + def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None: Lambda.__init__(self=self, func=func) RandomizableTransform.__init__(self=self, prob=prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None): self.randomize(img) return super().__call__(img=img, func=func) if self._do_transform else img @@ -679,6 +694,8 @@ class LabelToMask(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( # pytype: disable=annotation-type-mismatch self, select_labels: Union[Sequence[int], int], @@ -688,8 +705,11 @@ def __init__( # pytype: disable=annotation-type-mismatch self.merge_channels = merge_channels def __call__( - self, img: np.ndarray, select_labels: Optional[Union[Sequence[int], int]] = None, merge_channels: bool = False - ): + self, + img: NdarrayOrTensor, + select_labels: Optional[Union[Sequence[int], int]] = None, + merge_channels: bool = False, + ) -> NdarrayOrTensor: """ Args: select_labels: labels to generate mask from. for 1 channel label, the `select_labels` @@ -706,26 +726,40 @@ def __call__( if img.shape[0] > 1: data = img[[*select_labels]] else: - data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape) + where = np.where if isinstance(img, np.ndarray) else torch.where + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + data = where(in1d(img, select_labels), True, False).reshape(img.shape) + # pre pytorch 1.8.0, need to use 1/0 instead of True/False + else: + data = where( + in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) + ).reshape(img.shape) - return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + if merge_channels or self.merge_channels: + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + return data.any(0)[None] + # pre pytorch 1.8.0 compatibility + return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore + + return data class FgBgToIndices(Transform): - def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: - """ - Compute foreground and background of the input label data, return the indices. - If no output_shape specified, output data will be 1 dim indices after flattening. - This transform can help pre-compute foreground and background regions for other transforms. - A typical usage is to randomly select foreground and background to crop. - The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`. + """ + Compute foreground and background of the input label data, return the indices. + If no output_shape specified, output data will be 1 dim indices after flattening. + This transform can help pre-compute foreground and background regions for other transforms. + A typical usage is to randomly select foreground and background to crop. + The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`. - Args: - image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to - determine the valid image content area and select background only in this area. - output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + Args: + image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to + determine the valid image content area and select background only in this area. + output_shape: expected shape of output indices. if not None, unravel indices to specified shape. - """ + """ + + def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: self.image_threshold = image_threshold self.output_shape = output_shape diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1b63b308d9..e9bcce93b0 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import numpy as np import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform @@ -59,7 +59,7 @@ ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import InverseKeys, TransformBackends __all__ = [ "AddChannelD", @@ -650,6 +650,8 @@ class SqueezeDimd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`. """ + backend = SqueezeDim.backend + def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -661,7 +663,7 @@ def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -673,6 +675,8 @@ class DataStatsd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.DataStats`. """ + backend = DataStats.backend + def __init__( self, keys: KeysCollection, @@ -719,7 +723,7 @@ def __init__( self.logger_handler = logger_handler self.printer = DataStats(logger_handler=logger_handler) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info @@ -741,6 +745,8 @@ class SimulateDelayd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SimulateDelay`. """ + backend = SimulateDelay.backend + def __init__( self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False ) -> None: @@ -757,7 +763,7 @@ def __init__( self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, delay_time in self.key_iterator(d, self.delay_time): d[key] = self.delayer(d[key], delay_time=delay_time) @@ -768,9 +774,10 @@ class CopyItemsd(MapTransform): """ Copy specified items from data dictionary and save with different key names. It can copy several items together and copy several times. - """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, keys: KeysCollection, times: int, names: KeysCollection, allow_missing_keys: bool = False ) -> None: @@ -802,7 +809,7 @@ def __init__( ) self.names = names - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Raises: KeyError: When a key in ``self.names`` already exists in ``data``. @@ -814,10 +821,11 @@ def __call__(self, data): for key, new_key in self.key_iterator(d, self.names[i * key_len : (i + 1) * key_len]): if new_key in d: raise KeyError(f"Key {new_key} already exists in data.") - if isinstance(d[key], torch.Tensor): - d[new_key] = d[key].detach().clone() + val = d[key] + if isinstance(val, torch.Tensor): + d[new_key] = val.detach().clone() else: - d[new_key] = copy.deepcopy(d[key]) + d[new_key] = copy.deepcopy(val) return d @@ -825,9 +833,10 @@ class ConcatItemsd(MapTransform): """ Concatenate specified items from data dictionary together on the first dim to construct a big array. Expect all the items are numpy array or PyTorch Tensor. - """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -841,7 +850,7 @@ def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_ self.name = name self.dim = dim - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Raises: TypeError: When items in ``data`` differ in type. @@ -857,10 +866,10 @@ def __call__(self, data): elif not isinstance(d[key], data_type): raise TypeError("All items in data must have the same type.") output.append(d[key]) - if data_type == np.ndarray: + if data_type is np.ndarray: d[self.name] = np.concatenate(output, axis=self.dim) - elif data_type == torch.Tensor: - d[self.name] = torch.cat(output, dim=self.dim) + elif data_type is torch.Tensor: + d[self.name] = torch.cat(output, dim=self.dim) # type: ignore else: raise TypeError(f"Unsupported data type: {data_type}, available options are (numpy.ndarray, torch.Tensor).") return d @@ -896,6 +905,8 @@ class Lambdad(MapTransform, InvertibleTransform): """ + backend = Lambda.backend + def __init__( self, keys: KeysCollection, @@ -913,7 +924,7 @@ def __init__( def _transform(self, data: Any, func: Callable): return self._lambd(data, func=func) - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): ret = self._transform(data=d[key], func=func) @@ -958,9 +969,10 @@ class RandLambdad(Lambdad, RandomizableTransform): Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the image's original size. If need these complicated information, please write a new InvertibleTransform directly. - """ + backend = Lambda.backend + def __init__( self, keys: KeysCollection, @@ -1007,6 +1019,8 @@ class LabelToMaskd(MapTransform): """ + backend = LabelToMask.backend + def __init__( # pytype: disable=annotation-type-mismatch self, keys: KeysCollection, @@ -1017,7 +1031,7 @@ def __init__( # pytype: disable=annotation-type-mismatch super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index e6dc151596..2eebe3eda3 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -16,26 +16,37 @@ __all__ = [ "moveaxis", + "in1d", ] def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: + """`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8""" if isinstance(x, torch.Tensor): if hasattr(torch, "moveaxis"): return torch.moveaxis(x, src, dst) - # moveaxis only available in pytorch as of 1.8.0 - else: - # get original indices - indices = list(range(x.ndim)) - # make src and dst positive - if src < 0: - src = len(indices) + src - if dst < 0: - dst = len(indices) + dst - # remove desired index and insert it in new position - indices.pop(src) - indices.insert(dst, src) - return x.permute(indices) - elif isinstance(x, np.ndarray): + return _moveaxis_with_permute(x, src, dst) # type: ignore + if isinstance(x, np.ndarray): return np.moveaxis(x, src, dst) raise RuntimeError() + + +def _moveaxis_with_permute(x, src, dst): + # get original indices + indices = list(range(x.ndim)) + # make src and dst positive + if src < 0: + src = len(indices) + src + if dst < 0: + dst = len(indices) + dst + # remove desired index and insert it in new position + indices.pop(src) + indices.insert(dst, src) + return x.permute(indices) + + +def in1d(x, y): + """`np.in1d` with equivalent implementation for torch.""" + if isinstance(x, np.ndarray): + return np.in1d(x, y) + return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 0ea5afc40c..aa8f02f815 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -46,6 +46,7 @@ first, get_seed, has_option, + is_module_ver_at_least, is_scalar, is_scalar_tensor, issequenceiterable, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 66f6557032..3b287b3fe4 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,7 +22,7 @@ import numpy as np import torch -from monai.utils.module import get_torch_version_tuple +from monai.utils.module import get_torch_version_tuple, version_leq __all__ = [ "zip_with", @@ -42,6 +42,7 @@ "MAX_SEED", "copy_to_device", "ImageMetaKey", + "is_module_ver_at_least", ] _seed = None @@ -355,3 +356,16 @@ def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: return False sig = inspect.signature(obj) return all(key in sig.parameters for key in ensure_tuple(keywords)) + + +def is_module_ver_at_least(module, version): + """Determine if a module's version is at least equal to the given value. + + Args: + module: imported module's name, e.g., `np` or `torch`. + version: required version, given as a tuple, e.g., `(1, 8, 0)`. + Returns: + `True` if module is the given version or newer. + """ + test_ver = ".".join(map(str, version)) + return module.__version__ != test_ver and version_leq(test_ver, module.__version__) diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 43068797a3..50536f2a5c 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -117,7 +117,7 @@ "additional_info": lambda x: torch.mean(x.float()), "logger_handler": None, }, - torch.tensor([[0, 1], [1, 2]]), + torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu"), ( "test data statistics:\nType: \nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" "Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0" diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index be7e54bc25..aea0f1e721 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -124,7 +124,7 @@ "additional_info": lambda x: torch.mean(x.float()), "logger_handler": None, }, - {"img": torch.tensor([[0, 1], [1, 2]])}, + {"img": torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu")}, ( "test data statistics:\nType: \nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" "Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0" diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 2a84c7bea6..9caa7252f3 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -12,46 +12,59 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMask +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"select_labels": [2, 3], "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"select_labels": [1, 2], "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"select_labels": [1, 2], "merge_channels": True}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"select_labels": [2, 3], "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": True}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMask(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = LabelToMask(**argments)(image) - np.testing.assert_allclose(result, expected_data) + self.assertEqual(type(result), type(image)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, image.device) + assert_allclose(result, expected_data) if __name__ == "__main__": diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index f046390c19..b8f0d3c171 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -12,46 +12,60 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMaskd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMaskd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, argments, image, expected_data): - result = LabelToMaskd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TESTS) + def test_value(self, argments, input_data, expected_data): + result = LabelToMaskd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + assert_allclose(r, expected_data) if __name__ == "__main__": diff --git a/tests/test_lambda.py b/tests/test_lambda.py index e71eb3e5b0..738c81130d 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -11,30 +11,30 @@ import unittest -import numpy as np - from monai.transforms.utility.array import Lambda -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestLambda(NumpyImageTestCase2D): def test_lambda_identity(self): - img = self.imt + for p in TEST_NDARRAYS: + img = p(self.imt) - def identity_func(x): - return x + def identity_func(x): + return x - lambd = Lambda(func=identity_func) - self.assertTrue(np.allclose(identity_func(img), lambd(img))) + lambd = Lambda(func=identity_func) + assert_allclose(identity_func(img), lambd(img)) def test_lambda_slicing(self): - img = self.imt + for p in TEST_NDARRAYS: + img = p(self.imt) - def slice_func(x): - return x[:, :, :6, ::-2] + def slice_func(x): + return x[:, :, :6, ::2] - lambd = Lambda(func=slice_func) - self.assertTrue(np.allclose(slice_func(img), lambd(img))) + lambd = Lambda(func=slice_func) + assert_allclose(slice_func(img), lambd(img)) if __name__ == "__main__": diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index ca28af778b..05ba0ff6bc 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -11,37 +11,36 @@ import unittest -import numpy as np - from monai.transforms.utility.dictionary import Lambdad -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestLambdad(NumpyImageTestCase2D): def test_lambdad_identity(self): - img = self.imt - data = {"img": img, "prop": 1.0} + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {"img": img, "prop": 1.0} - def noise_func(x): - return x + 1.0 + def noise_func(x): + return x + 1.0 - expected = {"img": noise_func(data["img"]), "prop": 1.0} - ret = Lambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) - self.assertTrue(np.allclose(expected["img"], ret["img"])) - self.assertTrue(np.allclose(expected["prop"], ret["prop"])) + expected = {"img": noise_func(data["img"]), "prop": 1.0} + ret = Lambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) + assert_allclose(expected["img"], ret["img"]) + assert_allclose(expected["prop"], ret["prop"]) def test_lambdad_slicing(self): - img = self.imt - data = {} - data["img"] = img + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {"img": img} - def slice_func(x): - return x[:, :, :6, ::-2] + def slice_func(x): + return x[:, :, :6, ::2] - lambd = Lambdad(keys=data.keys(), func=slice_func) - expected = {} - expected["img"] = slice_func(data["img"]) - self.assertTrue(np.allclose(expected["img"], lambd(data)["img"])) + lambd = Lambdad(keys=data.keys(), func=slice_func) + expected = {} + expected["img"] = slice_func(data["img"]) + assert_allclose(expected["img"], lambd(data)["img"]) if __name__ == "__main__": diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index 01ea489320..15ff7e94d6 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -12,34 +12,32 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SqueezeDim +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"dim": None}, np.random.rand(1, 2, 1, 3), (2, 3)] +TESTS, TESTS_FAIL = [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"dim": None}, p(np.random.rand(1, 2, 1, 3)), (2, 3)]) + TESTS.append([{"dim": 2}, p(np.random.rand(1, 2, 1, 8, 16)), (1, 2, 8, 16)]) + TESTS.append([{"dim": -1}, p(np.random.rand(1, 1, 16, 8, 1)), (1, 1, 16, 8)]) + TESTS.append([{}, p(np.random.rand(1, 2, 1, 3)), (2, 1, 3)]) -TEST_CASE_2 = [{"dim": 2}, np.random.rand(1, 2, 1, 8, 16), (1, 2, 8, 16)] - -TEST_CASE_3 = [{"dim": -1}, np.random.rand(1, 1, 16, 8, 1), (1, 1, 16, 8)] - -TEST_CASE_4 = [{}, np.random.rand(1, 2, 1, 3), (2, 1, 3)] - -TEST_CASE_4_PT = [{}, torch.rand(1, 2, 1, 3), (2, 1, 3)] - -TEST_CASE_5 = [ValueError, {"dim": -2}, np.random.rand(1, 1, 16, 8, 1)] - -TEST_CASE_6 = [TypeError, {"dim": 0.5}, np.random.rand(1, 1, 16, 8, 1)] + TESTS_FAIL.append([ValueError, {"dim": -2}, p(np.random.rand(1, 1, 16, 8, 1))]) + TESTS_FAIL.append([TypeError, {"dim": 0.5}, p(np.random.rand(1, 1, 16, 8, 1))]) class TestSqueezeDim(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_4_PT]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): + result = SqueezeDim(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): + with self.assertRaises(exception): SqueezeDim(**input_param)(test_data) diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py index dcbd9212c7..35e7cd5d74 100644 --- a/tests/test_squeezedimd.py +++ b/tests/test_squeezedimd.py @@ -12,62 +12,78 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SqueezeDimd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"], "dim": None}, - {"img": np.random.rand(1, 2, 1, 3), "seg": np.random.randint(0, 2, size=[1, 2, 1, 3])}, - (2, 3), -] +TESTS, TESTS_FAIL = [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": None}, + {"img": p(np.random.rand(1, 2, 1, 3)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))}, + (2, 3), + ] + ) -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "dim": 2}, - {"img": np.random.rand(1, 2, 1, 8, 16), "seg": np.random.randint(0, 2, size=[1, 2, 1, 8, 16])}, - (1, 2, 8, 16), -] + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": 2}, + {"img": p(np.random.rand(1, 2, 1, 8, 16)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 8, 16]))}, + (1, 2, 8, 16), + ] + ) -TEST_CASE_3 = [ - {"keys": ["img", "seg"], "dim": -1}, - {"img": np.random.rand(1, 1, 16, 8, 1), "seg": np.random.randint(0, 2, size=[1, 1, 16, 8, 1])}, - (1, 1, 16, 8), -] + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": -1}, + {"img": p(np.random.rand(1, 1, 16, 8, 1)), "seg": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))}, + (1, 1, 16, 8), + ] + ) -TEST_CASE_4 = [ - {"keys": ["img", "seg"]}, - {"img": np.random.rand(1, 2, 1, 3), "seg": np.random.randint(0, 2, size=[1, 2, 1, 3])}, - (2, 1, 3), -] + TESTS.append( + [ + {"keys": ["img", "seg"]}, + {"img": p(np.random.rand(1, 2, 1, 3)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))}, + (2, 1, 3), + ] + ) -TEST_CASE_4_PT = [ - {"keys": ["img", "seg"], "dim": 0}, - {"img": torch.rand(1, 2, 1, 3), "seg": torch.randint(0, 2, size=[1, 2, 1, 3])}, - (2, 1, 3), -] + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": 0}, + {"img": p(np.random.rand(1, 2, 1, 3)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))}, + (2, 1, 3), + ] + ) -TEST_CASE_5 = [ - ValueError, - {"keys": ["img", "seg"], "dim": -2}, - {"img": np.random.rand(1, 1, 16, 8, 1), "seg": np.random.randint(0, 2, size=[1, 1, 16, 8, 1])}, -] + TESTS_FAIL.append( + [ + ValueError, + {"keys": ["img", "seg"], "dim": -2}, + {"img": p(np.random.rand(1, 1, 16, 8, 1)), "seg": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))}, + ] + ) -TEST_CASE_6 = [ - TypeError, - {"keys": ["img", "seg"], "dim": 0.5}, - {"img": np.random.rand(1, 1, 16, 8, 1), "seg": np.random.randint(0, 2, size=[1, 1, 16, 8, 1])}, -] + TESTS_FAIL.append( + [ + TypeError, + {"keys": ["img", "seg"], "dim": 0.5}, + {"img": p(np.random.rand(1, 1, 16, 8, 1)), "seg": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))}, + ] + ) class TestSqueezeDim(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_4_PT]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SqueezeDimd(**input_param)(test_data) self.assertTupleEqual(result["img"].shape, expected_shape) self.assertTupleEqual(result["seg"].shape, expected_shape) - @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): with self.assertRaises(exception): SqueezeDimd(**input_param)(test_data) diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index a9460bc825..8b00e12539 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -24,7 +24,7 @@ class TestToCupy(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index 2f3c42dd1f..6f40bafe1c 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -24,7 +24,7 @@ class TestToCupyd(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index fd49a3d473..b48727c01d 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -24,7 +24,7 @@ class TestToNumpy(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index adfab65904..5acaef39c7 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -24,7 +24,7 @@ class TestToNumpyd(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/utils.py b/tests/utils.py index 22720849f1..1375cd2d72 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,6 +36,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq nib, _ = optional_import("nibabel") @@ -142,8 +143,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.min_version)) - self.version_too_old = torch.__version__ != test_ver and version_leq(torch.__version__, test_ver) + self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf(