diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index fcd9adba94..962e1f3769 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -274,41 +274,47 @@ Intensity :special-members: __call__ `RandHistogramShift` -""""""""""""""""""""" +"""""""""""""""""""" .. autoclass:: RandHistogramShift :members: :special-members: __call__ `DetectEnvelope` -""""""""""""""""""""" +"""""""""""""""" .. autoclass:: DetectEnvelope :members: :special-members: __call__ `GibbsNoise` -"""""""""""""" +"""""""""""" .. autoclass:: GibbsNoise :members: :special-members: __call__ `RandGibbsNoise` -""""""""""""""""" +"""""""""""""""" .. autoclass:: RandGibbsNoise :members: :special-members: __call__ `KSpaceSpikeNoise` -"""""""""""""""""""" +"""""""""""""""""" .. autoclass:: KSpaceSpikeNoise :members: :special-members: __call__ `RandKSpaceSpikeNoise` -"""""""""""""""""""""""" +"""""""""""""""""""""" .. autoclass:: RandKSpaceSpikeNoise :members: :special-members: __call__ +`RandCoarseDropout` +""""""""""""""""""" + .. autoclass:: RandCoarseDropout + :members: + :special-members: __call__ + IO ^^ @@ -889,6 +895,12 @@ Intensity (Dict) :members: :special-members: __call__ +`RandCoarseDropoutd` +"""""""""""""""""""" +.. autoclass:: RandCoarseDropoutd + :members: + :special-members: __call__ + IO (Dict) ^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 21cfce2b82..45eecd266c 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -88,6 +88,7 @@ NormalizeIntensity, RandAdjustContrast, RandBiasField, + RandCoarseDropout, RandGaussianNoise, RandGaussianSharpen, RandGaussianSmooth, @@ -134,6 +135,9 @@ RandBiasFieldd, RandBiasFieldD, RandBiasFieldDict, + RandCoarseDropoutd, + RandCoarseDropoutD, + RandCoarseDropoutDict, RandGaussianNoised, RandGaussianNoiseD, RandGaussianNoiseDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 65c114abcd..dfbac7465c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -21,6 +21,7 @@ import torch from monai.config import DtypeLike +from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import rescale_array @@ -31,6 +32,7 @@ ensure_tuple, ensure_tuple_rep, ensure_tuple_size, + fall_back_tuple, ) __all__ = [ @@ -61,6 +63,7 @@ "RandGibbsNoise", "KSpaceSpikeNoise", "RandKSpaceSpikeNoise", + "RandCoarseDropout", ] @@ -1603,3 +1606,68 @@ def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, t return img.cpu().detach().numpy(), img.device else: return img, torch.device("cpu") + + +class RandCoarseDropout(RandomizableTransform): + """ + Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. + Refer to: https://arxiv.org/abs/1708.04552 and: + https://albumentations.ai/docs/api_reference/augmentations/transforms/ + #albumentations.augmentations.transforms.CoarseDropout. + + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + fill_value: target value to fill the dropout regions. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + + """ + + def __init__( + self, + holes: int, + spatial_size: Union[Sequence[int], int], + fill_value: Union[float, int] = 0, + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + ) -> None: + RandomizableTransform.__init__(self, prob) + if holes < 1: + raise ValueError("number of holes must be greater than 0.") + self.holes = holes + self.spatial_size = spatial_size + self.fill_value = fill_value + self.max_holes = max_holes + self.max_spatial_size = max_spatial_size + self.hole_coords: List = [] + + def randomize(self, img_size: Sequence[int]) -> None: + super().randomize(None) + size = fall_back_tuple(self.spatial_size, img_size) + self.hole_coords = [] # clear previously computed coords + num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1) + for _ in range(num_holes): + if self.max_spatial_size is not None: + max_size = fall_back_tuple(self.max_spatial_size, img_size) + size = tuple(self.R.randint(low=size[i], high=max_size[i] + 1) for i in range(len(img_size))) + valid_size = get_valid_patch_size(img_size, size) + self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) + + def __call__(self, img: np.ndarray): + self.randomize(img.shape[1:]) + if self._do_transform: + for h in self.hole_coords: + img[h] = self.fill_value + + return img diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index ae0b83e0ea..49f20ea419 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -22,6 +22,7 @@ import torch from monai.config import DtypeLike, KeysCollection +from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.intensity.array import ( AdjustContrast, GaussianSharpen, @@ -41,7 +42,7 @@ ThresholdIntensity, ) from monai.transforms.transform import MapTransform, RandomizableTransform -from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size +from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple __all__ = [ "RandGaussianNoised", @@ -69,6 +70,7 @@ "KSpaceSpikeNoised", "RandKSpaceSpikeNoised", "RandHistogramShiftd", + "RandCoarseDropoutd", "RandGaussianNoiseD", "RandGaussianNoiseDict", "ShiftIntensityD", @@ -117,13 +119,16 @@ "RandHistogramShiftDict", "RandRicianNoiseD", "RandRicianNoiseDict", + "RandCoarseDropoutD", + "RandCoarseDropoutDict", ] class RandGaussianNoised(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`. - Add Gaussian noise to image. This transform assumes all the expected fields have same shape. + Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add + different noise for every field, please use this transform separately. Args: keys: keys of the corresponding items to be transformed. @@ -172,7 +177,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda class RandRicianNoised(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`. - Add Rician noise to image. This transform assumes all the expected fields have same shape. + Add Rician noise to image. This transform assumes all the expected fields have same shape, if want to add + different noise for every field, please use this transform separately. Args: keys: Keys of the corresponding items to be transformed. @@ -1324,6 +1330,78 @@ def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: return d_numpy +class RandCoarseDropoutd(RandomizableTransform, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseDropout`. + Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions + for every key, if want to dropout differently for every key, please use this transform separately. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + fill_value: target value to fill the dropout regions. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + holes: int, + spatial_size: Union[Sequence[int], int], + fill_value: Union[float, int] = 0, + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + if holes < 1: + raise ValueError("number of holes must be greater than 0.") + self.holes = holes + self.spatial_size = spatial_size + self.fill_value = fill_value + self.max_holes = max_holes + self.max_spatial_size = max_spatial_size + self.hole_coords: List = [] + + def randomize(self, img_size: Sequence[int]) -> None: + super().randomize(None) + size = fall_back_tuple(self.spatial_size, img_size) + self.hole_coords = [] # clear previously computed coords + num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1) + for _ in range(num_holes): + if self.max_spatial_size is not None: + max_size = fall_back_tuple(self.max_spatial_size, img_size) + size = tuple(self.R.randint(low=size[i], high=max_size[i] + 1) for i in range(len(img_size))) + valid_size = get_valid_patch_size(img_size, size) + self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) + + def __call__(self, data): + d = dict(data) + # expect all the specified keys have same spatial shape + self.randomize(d[self.keys[0]].shape[1:]) + if self._do_transform: + for key in self.key_iterator(d): + for h in self.hole_coords: + d[key][h] = self.fill_value + return d + + RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd @@ -1349,3 +1427,4 @@ def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: GibbsNoiseD = GibbsNoiseDict = GibbsNoised KSpaceSpikeNoiseD = KSpaceSpikeNoiseDict = KSpaceSpikeNoised RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised +RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 37dd9b47c6..06b98cdd2e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -337,7 +337,7 @@ class Resize(Transform): Args: spatial_size: expected shape of spatial dimensions after resize operation. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} @@ -1297,7 +1297,7 @@ def __init__( spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} @@ -1390,7 +1390,7 @@ def __init__( spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} @@ -1553,7 +1553,7 @@ def __init__( spatial_size: specifying output image spatial size [h, w]. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} @@ -1681,7 +1681,7 @@ def __init__( spatial_size: specifying output image spatial size [h, w, d]. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted to `(32, 32, 64)` if the third spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b961ef7c92..2c9cac8438 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -500,7 +500,7 @@ class Resized(MapTransform, InvertibleTransform): keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` spatial_size: expected shape of spatial dimensions after resize operation. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} @@ -589,7 +589,7 @@ def __init__( spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} @@ -695,7 +695,7 @@ def __init__( spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. prob: probability of returning a randomized affine grid. @@ -860,7 +860,7 @@ def __init__( spatial_size: specifying output image spatial size [h, w]. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. prob: probability of returning a randomized affine grid. @@ -980,7 +980,7 @@ def __init__( spatial_size: specifying output image spatial size [h, w, d]. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. - if the components of the `spatial_size` are non-positive values, the transform will use the + if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted to `(32, 32, 64)` if the third spatial dimension size of img is `64`. prob: probability of returning a randomized affine grid. diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py new file mode 100644 index 0000000000..235a391567 --- /dev/null +++ b/tests/test_rand_coarse_dropout.py @@ -0,0 +1,73 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCoarseDropout +from monai.utils import fall_back_tuple + +TEST_CASE_0 = [ + {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, + np.random.randint(0, 2, size=[3, 3, 3, 4]), + (3, 3, 3, 4), +] + +TEST_CASE_1 = [ + {"holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, + np.random.randint(0, 2, size=[3, 3, 3, 4]), + (3, 3, 3, 4), +] + +TEST_CASE_2 = [ + {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, + np.random.randint(0, 2, size=[3, 3, 3, 4]), + (3, 3, 3, 4), +] + +TEST_CASE_3 = [ + {"holes": 2, "spatial_size": [2, -1, 2], "fill_value": 5, "max_spatial_size": [4, 4, -1], "prob": 1.0}, + np.random.randint(0, 2, size=[3, 3, 3, 4]), + (3, 3, 3, 4), +] + + +class TestRandCoarseDropout(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, input_param, input_data, expected_shape): + dropout = RandCoarseDropout(**input_param) + result = dropout(input_data) + holes = input_param.get("holes") + max_holes = input_param.get("max_holes") + spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data.shape[1:]) + max_spatial_size = fall_back_tuple(input_param.get("max_spatial_size"), input_data.shape[1:]) + + if max_holes is None: + self.assertEqual(len(dropout.hole_coords), holes) + else: + self.assertGreaterEqual(len(dropout.hole_coords), holes) + self.assertLessEqual(len(dropout.hole_coords), max_holes) + + for h in dropout.hole_coords: + data = result[h] + np.testing.assert_allclose(data, input_param.get("fill_value", 0)) + if max_spatial_size is None: + self.assertTupleEqual(data.shape[1:], tuple(spatial_size)) + else: + for d, s, m in zip(data.shape[1:], spatial_size, max_spatial_size): + self.assertGreaterEqual(d, s) + self.assertLessEqual(d, m) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py new file mode 100644 index 0000000000..d189a80f56 --- /dev/null +++ b/tests/test_rand_coarse_dropoutd.py @@ -0,0 +1,87 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCoarseDropoutd +from monai.utils import fall_back_tuple + +TEST_CASE_0 = [ + {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, + (3, 3, 3, 4), +] + +TEST_CASE_1 = [ + {"keys": "img", "holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, + (3, 3, 3, 4), +] + +TEST_CASE_2 = [ + { + "keys": "img", + "holes": 2, + "spatial_size": [2, 2, 2], + "fill_value": 5, + "max_spatial_size": [4, 4, 3], + "prob": 1.0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, + (3, 3, 3, 4), +] + +TEST_CASE_3 = [ + { + "keys": "img", + "holes": 2, + "spatial_size": [2, -1, 2], + "fill_value": 5, + "max_spatial_size": [4, 4, -1], + "prob": 1.0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, + (3, 3, 3, 4), +] + + +class TestRandCoarseDropoutd(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, input_param, input_data, expected_shape): + dropout = RandCoarseDropoutd(**input_param) + result = dropout(input_data)["img"] + holes = input_param.get("holes") + max_holes = input_param.get("max_holes") + spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data["img"].shape[1:]) + max_spatial_size = fall_back_tuple(input_param.get("max_spatial_size"), input_data["img"].shape[1:]) + + if max_holes is None: + self.assertEqual(len(dropout.hole_coords), holes) + else: + self.assertGreaterEqual(len(dropout.hole_coords), holes) + self.assertLessEqual(len(dropout.hole_coords), max_holes) + + for h in dropout.hole_coords: + data = result[h] + np.testing.assert_allclose(data, input_param.get("fill_value", 0)) + if max_spatial_size is None: + self.assertTupleEqual(data.shape[1:], tuple(spatial_size)) + else: + for d, s, m in zip(data.shape[1:], spatial_size, max_spatial_size): + self.assertGreaterEqual(d, s) + self.assertLessEqual(d, m) + + +if __name__ == "__main__": + unittest.main()