diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 240836ce0b..0eca034950 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -572,6 +572,7 @@ def __init__( return_coords: bool = False, k_divisible: Union[Sequence[int], int] = 1, mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, ) -> None: """ Args: @@ -586,6 +587,8 @@ def __init__( ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} one of the listed string values or a user supplied function. Defaults to ``"constant"``. see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ self.select_fn = select_fn @@ -594,6 +597,7 @@ def __init__( self.return_coords = return_coords self.k_divisible = k_divisible self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) + self.np_kwargs = np_kwargs def compute_bounding_box(self, img: np.ndarray): """ @@ -621,7 +625,7 @@ def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0) pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - return BorderPad(spatial_border=pad, mode=self.mode)(cropped) + return BorderPad(spatial_border=pad, mode=self.mode, **self.np_kwargs)(cropped) def __call__(self, img: np.ndarray): """ diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4a2ae32607..cdcc861c82 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -794,6 +794,7 @@ def __init__( start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, + **np_kwargs, ) -> None: """ Args: @@ -813,6 +814,9 @@ def __init__( start_coord_key: key to record the start coordinate of spatial bounding box for foreground. end_coord_key: key to record the end coordinate of spatial bounding box for foreground. allow_missing_keys: don't raise exception if key is missing. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ super().__init__(keys, allow_missing_keys) self.source_key = source_key @@ -824,6 +828,7 @@ def __init__( margin=margin, k_divisible=k_divisible, mode=mode, + **np_kwargs, ) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index dcbc7aa2f6..86f0e84249 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -546,6 +546,9 @@ class Zoom(Transform): 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate keep_size: Should keep original size (padding/slicing if needed), default is True. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( @@ -555,12 +558,14 @@ def __init__( padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, + **np_kwargs, ) -> None: self.zoom = zoom self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) self.align_corners = align_corners self.keep_size = keep_size + self.np_kwargs = np_kwargs def __call__( self, @@ -607,7 +612,7 @@ def __call__( slice_vec[idx] = slice(half, half + od) padding_mode = look_up_option(self.padding_mode if padding_mode is None else padding_mode, NumpyPadMode) - zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value) # type: ignore + zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value, **self.np_kwargs) # type: ignore return zoomed[tuple(slice_vec)] @@ -868,6 +873,9 @@ class RandZoom(RandomizableTransform): 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate keep_size: Should keep original size (pad if needed), default is True. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( @@ -879,6 +887,7 @@ def __init__( padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, + **np_kwargs, ) -> None: RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) @@ -889,6 +898,7 @@ def __init__( self.padding_mode: NumpyPadMode = look_up_option(padding_mode, NumpyPadMode) self.align_corners = align_corners self.keep_size = keep_size + self.np_kwargs = np_kwargs self._zoom: Sequence[float] = [1.0] @@ -928,7 +938,7 @@ def __call__( elif len(self._zoom) == 2 and img.ndim > 3: # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size) + zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) return np.asarray( zoomer( img, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a7eeceacf9..d953fd63ea 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1534,6 +1534,9 @@ class Zoomd(MapTransform, InvertibleTransform): It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( @@ -1545,12 +1548,13 @@ def __init__( align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, + **np_kwargs, ) -> None: super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size) + self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **np_kwargs) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -1630,6 +1634,9 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( @@ -1643,6 +1650,7 @@ def __init__( align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, + **np_kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -1655,6 +1663,7 @@ def __init__( self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.keep_size = keep_size + self.np_kwargs = np_kwargs self._zoom: Sequence[float] = [1.0] @@ -1674,7 +1683,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda elif len(self._zoom) == 2 and img_dims > 3: # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size) + zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 8eae8f484e..71e488cac8 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -53,7 +53,7 @@ ] TEST_CASE_7 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10, "constant_values": 2}, np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), np.zeros((1, 0, 0)), ] diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index 37abfb8c55..f51ca7e2df 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -23,6 +23,8 @@ "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, + "mode": "constant", + "constant_values": 2, }, { "img": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index 35cf30bcb1..c21bc8b9e9 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -41,7 +41,14 @@ def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): np.testing.assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + random_zoom = RandZoom( + prob=1.0, + min_zoom=0.6, + max_zoom=0.7, + keep_size=True, + padding_mode="constant", + constant_values=2, + ) zoomed = random_zoom(self.imt[0]) self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) zoomed = random_zoom(self.imt[0]) diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index fd50c490d5..4ccb1aad64 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -45,7 +45,15 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz def test_keep_size(self): key = "img" - random_zoom = RandZoomd(key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + random_zoom = RandZoomd( + keys=key, + prob=1.0, + min_zoom=0.6, + max_zoom=0.7, + keep_size=True, + padding_mode="constant", + constant_values=2, + ) zoomed = random_zoom({key: self.imt[0]}) self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) diff --git a/tests/test_zoom.py b/tests/test_zoom.py index dcc401f16c..e6710ede29 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -38,7 +38,7 @@ def test_correct_results(self, zoom, mode): np.testing.assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) + zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True, padding_mode="constant", constant_values=2) zoomed = zoom_fn(self.imt[0], mode="bilinear") np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index b17ecd1bf0..1a1a905d80 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -45,7 +45,7 @@ def test_correct_results(self, zoom, mode, keep_size): def test_keep_size(self): key = "img" - zoom_fn = Zoomd(key, zoom=0.6, keep_size=True) + zoom_fn = Zoomd(key, zoom=0.6, keep_size=True, padding_mode="constant", constant_values=2) zoomed = zoom_fn({key: self.imt[0]}) self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))