From 4c11161d42f54c6bc622a4e3d79db3f698337d9c Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 10 Sep 2021 02:58:41 -0400 Subject: [PATCH 01/19] CuPy to Tensor (#2919) * Add cupy to tensor Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Add unittest for cupy>tensor Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: ahatamizadeh --- monai/utils/type_conversion.py | 2 ++ tests/test_to_tensor.py | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b0ce187e38..3688b02d26 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -105,6 +105,8 @@ def convert_to_tensor(data, wrap_sequence: bool = False): # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) + elif has_cp and isinstance(data, cp_ndarray): + return torch.as_tensor(data) elif isinstance(data, (float, int, bool)): return torch.as_tensor(data) elif isinstance(data, Sequence) and wrap_sequence: diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 6ac06983f6..3d187a1dba 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -12,9 +12,12 @@ import unittest from parameterized import parameterized +from torch import Tensor from monai.transforms import ToTensor -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, optional_import + +cp, has_cp = optional_import("cupy") im = [[1, 2], [3, 4]] @@ -33,15 +36,25 @@ class TestToTensor(unittest.TestCase): @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): result = ToTensor()(test_data) + self.assertTrue(isinstance(result, Tensor)) assert_allclose(result, test_data) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand(TESTS_SINGLE) def test_single_input(self, test_data): result = ToTensor()(test_data) + self.assertTrue(isinstance(result, Tensor)) assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) + @unittest.skipUnless(has_cp, "CuPy is required.") + def test_cupy(self): + test_data = [[1, 2], [3, 4]] + cupy_array = cp.ascontiguousarray(cp.asarray(test_data)) + result = ToTensor()(cupy_array) + self.assertTrue(isinstance(result, Tensor)) + assert_allclose(result, test_data) + if __name__ == "__main__": unittest.main() From 75c6402a682d0b96131550509fee1a785dc6ad57 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 10 Sep 2021 19:58:34 +0800 Subject: [PATCH 02/19] [DLMED] add dict version shuffle (#2918) Signed-off-by: Nic Ma Co-authored-by: Wenqi Li Signed-off-by: ahatamizadeh --- docs/source/transforms.rst | 6 ++ monai/transforms/__init__.py | 3 + monai/transforms/intensity/array.py | 15 +++++ monai/transforms/intensity/dictionary.py | 78 ++++++++++++++++++++++++ tests/test_rand_coarse_shuffle.py | 2 +- tests/test_rand_coarse_shuffled.py | 56 +++++++++++++++++ 6 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 tests/test_rand_coarse_shuffled.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index d1083a641b..b61da87551 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -977,6 +977,12 @@ Intensity (Dict) :members: :special-members: __call__ +`RandCoarseShuffled` +"""""""""""""""""""" +.. autoclass:: RandCoarseShuffled + :members: + :special-members: __call__ + `HistogramNormalized` """"""""""""""""""""" .. autoclass:: HistogramNormalized diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b9ba303ed7..e4ec38f82d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -144,6 +144,9 @@ RandCoarseDropoutd, RandCoarseDropoutD, RandCoarseDropoutDict, + RandCoarseShuffled, + RandCoarseShuffleD, + RandCoarseShuffleDict, RandGaussianNoised, RandGaussianNoiseD, RandGaussianNoiseDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index d10c3017a3..f6d4dfff5a 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1786,6 +1786,21 @@ class RandCoarseShuffle(RandCoarseTransform): Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). https://arxiv.org/abs/1707.07103 + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + """ def _transform_holes(self, img: np.ndarray): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index bc53fb6b7b..ca24980359 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -34,6 +34,7 @@ NormalizeIntensity, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, RandGaussianNoise, RandKSpaceSpikeNoise, RandRicianNoise, @@ -75,6 +76,7 @@ "RandKSpaceSpikeNoised", "RandHistogramShiftd", "RandCoarseDropoutd", + "RandCoarseShuffled", "HistogramNormalized", "RandGaussianNoiseD", "RandGaussianNoiseDict", @@ -126,6 +128,8 @@ "RandRicianNoiseDict", "RandCoarseDropoutD", "RandCoarseDropoutDict", + "RandCoarseShuffleD", + "RandCoarseShuffleDict", "HistogramNormalizeD", "HistogramNormalizeDict", ] @@ -1478,6 +1482,13 @@ def __init__( prob=prob, ) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseDropoutd": + self.dropper.set_random_state(seed, state) + super().set_random_state(seed, state) + return self + def randomize(self, img_size: Sequence[int]) -> None: self.dropper.randomize(img_size=img_size) @@ -1492,6 +1503,72 @@ def __call__(self, data): return d +class RandCoarseShuffled(Randomizable, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseShuffle`. + Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions + for every key, if want to shuffle different regions for every key, please use this transform separately. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + holes: int, + spatial_size: Union[Sequence[int], int], + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.shuffle = RandCoarseShuffle( + holes=holes, + spatial_size=spatial_size, + max_holes=max_holes, + max_spatial_size=max_spatial_size, + prob=prob, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseShuffled": + self.shuffle.set_random_state(seed, state) + super().set_random_state(seed, state) + return self + + def randomize(self, img_size: Sequence[int]) -> None: + self.shuffle.randomize(img_size=img_size) + + def __call__(self, data): + d = dict(data) + # expect all the specified keys have same spatial shape + self.randomize(d[self.keys[0]].shape[1:]) + if self.shuffle._do_transform: + for key in self.key_iterator(d): + d[key] = self.shuffle(img=d[key]) + + return d + + class HistogramNormalized(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`. @@ -1562,3 +1639,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized +RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py index 97d492fd24..0b8cdc6cf8 100644 --- a/tests/test_rand_coarse_shuffle.py +++ b/tests/test_rand_coarse_shuffle.py @@ -45,7 +45,7 @@ class TestRandCoarseShuffle(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_local_patch_shuffle(self, input_param, input_data, expected_val): + def test_shuffle(self, input_param, input_data, expected_val): g = RandCoarseShuffle(**input_param) g.set_random_state(seed=12) result = g(**input_data) diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py new file mode 100644 index 0000000000..d2845fdaae --- /dev/null +++ b/tests/test_rand_coarse_shuffled.py @@ -0,0 +1,56 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCoarseShuffled + +TEST_CASES = [ + [ + {"keys": "img", "holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"keys": "img", "holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + np.asarray( + [ + [ + [[13, 17, 5], [6, 16, 25], [12, 15, 22]], + [[24, 7, 3], [9, 2, 23], [0, 4, 26]], + [[19, 11, 14], [1, 20, 8], [18, 10, 21]], + ] + ] + ), + ], + [ + {"keys": "img", "holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[7, 2], [1, 4]], [[5, 0], [3, 6]]], [[[8, 13], [10, 15]], [[14, 12], [11, 9]]]]), + ], +] + + +class TestRandCoarseShuffled(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shuffle(self, input_param, input_data, expected_val): + g = RandCoarseShuffled(**input_param) + g.set_random_state(seed=12) + result = g(input_data) + np.testing.assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 176a0fecfeedeb768ccf853826e753c31818c267 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Sat, 11 Sep 2021 06:58:22 -0400 Subject: [PATCH 03/19] Add device to ToTensor (#2926) Signed-off-by: ahatamizadeh --- monai/transforms/utility/array.py | 8 +++++--- monai/utils/type_conversion.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2eb6c447c6..add47e27ca 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -334,11 +334,15 @@ class ToTensor(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, device: Optional[torch.device] = None) -> None: + super().__init__() + self.device = device + def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_tensor(img, wrap_sequence=True) # type: ignore + return convert_to_tensor(img, wrap_sequence=True, device=self.device) # type: ignore class EnsureType(Transform): @@ -399,8 +403,6 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() return cp.ascontiguousarray(cp.asarray(img)) # type: ignore diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 3688b02d26..47b48aa2b8 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -83,7 +83,7 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False): +def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -97,26 +97,26 @@ def convert_to_tensor(data, wrap_sequence: bool = False): """ if isinstance(data, torch.Tensor): - return data.contiguous() + return data.contiguous().to(device) if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) + return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device) elif has_cp and isinstance(data, cp_ndarray): - return torch.as_tensor(data) + return torch.as_tensor(data, device=device) elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data) + return torch.as_tensor(data, device=device) elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data) + return torch.as_tensor(data, device=device) elif isinstance(data, list): - return [convert_to_tensor(i) for i in data] + return [convert_to_tensor(i, device=device) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_tensor(i) for i in data) + return tuple(convert_to_tensor(i, device=device) for i in data) elif isinstance(data, dict): - return {k: convert_to_tensor(v) for k, v in data.items()} + return {k: convert_to_tensor(v, device=device) for k, v in data.items()} return data From 63fbc5fba29cd1682f093397118dcfc6af4e9d0d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 12 Sep 2021 11:02:20 +0800 Subject: [PATCH 04/19] 2920 Enhance padding mode for Tensor data (#2921) * [DLMED] enhance the pad mode Signed-off-by: Nic Ma * [DLMED] update all the tensor pad related Signed-off-by: Nic Ma * [DLMED] fix error tests Signed-off-by: Nic Ma * [DLMED] fix GPU tests Signed-off-by: Nic Ma * [DLMED] update according to comments Signed-off-by: Nic Ma Signed-off-by: ahatamizadeh --- docs/source/transforms.rst | 6 + monai/transforms/__init__.py | 4 +- monai/transforms/croppad/array.py | 173 +++++++++++++++---------- monai/transforms/croppad/dictionary.py | 53 ++++---- monai/transforms/spatial/array.py | 87 +++++++------ monai/transforms/spatial/dictionary.py | 41 +++--- monai/transforms/utils.py | 31 ++++- monai/utils/type_conversion.py | 14 +- tests/test_convert_data_type.py | 13 +- tests/test_divisible_pad.py | 10 +- tests/test_pad_collation.py | 4 +- tests/test_spatial_pad.py | 56 ++++---- 12 files changed, 315 insertions(+), 177 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index b61da87551..add6b9b40c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -59,6 +59,12 @@ Vanilla Transforms Crop and Pad ^^^^^^^^^^^^ +`Pad` +""""" +.. autoclass:: Pad + :members: + :special-members: __call__ + `SpatialPad` """""""""""" .. autoclass:: SpatialPad diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index e4ec38f82d..a07dee867b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -18,6 +18,7 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + Pad, RandCropByLabelClasses, RandCropByPosNegLabel, RandScaleCrop, @@ -48,7 +49,7 @@ DivisiblePadd, DivisiblePadD, DivisiblePadDict, - NumpyPadModeSequence, + PadModeSequence, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, @@ -490,6 +491,7 @@ allow_missing_keys_mode, compute_divisible_spatial_size, convert_inverse_interp_mode, + convert_pad_mode, copypaste_arrays, create_control_grid, create_grid, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 74f556cc1a..7e3bc835dd 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -27,6 +27,7 @@ from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, + convert_pad_mode, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -35,7 +36,15 @@ map_classes_to_indices, weighted_patch_samples, ) -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option +from monai.utils import ( + Method, + NumpyPadMode, + PytorchPadMode, + ensure_tuple, + ensure_tuple_rep, + fall_back_tuple, + look_up_option, +) from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type @@ -61,16 +70,18 @@ class Pad(Transform): """ Perform padding for a given an amount of padding in each dimension. - If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. - Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). - Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad - for additional details. + If input is `torch.Tensor`, `torch.nn.functional.pad` will be used, otherwise, `np.pad` will be used. + Args: to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -78,43 +89,50 @@ class Pad(Transform): def __init__( self, to_pad: List[Tuple[int, int]], - mode: Union[NumpyPadMode, str, None] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.to_pad = to_pad - self.mode = mode or NumpyPadMode.CONSTANT - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs @staticmethod - def _np_pad(img: np.ndarray, all_pad_width, mode, **np_kwargs) -> np.ndarray: + def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray: img_np, *_ = convert_data_type(img, np.ndarray) - return np.pad(img_np, all_pad_width, mode=mode, **np_kwargs) # type: ignore + return np.pad(img_np, all_pad_width, mode=mode, **kwargs) # type: ignore @staticmethod - def _pt_pad(img: torch.Tensor, all_pad_width, mode, **np_kwargs) -> torch.Tensor: - pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1] - return pad_pt(img, pt_pad_width, mode=mode, **np_kwargs) + def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: + pt_pad_width = [val for sublist in all_pad_width[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, + img: NdarrayTensor, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + ) -> NdarrayTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ if not np.asarray(self.to_pad).any(): # all zeros, skip padding return img - mode = mode or self.mode - mode = mode.value if isinstance(mode, NumpyPadMode) else mode - if isinstance(img, torch.Tensor) and mode == "constant" and not self.np_kwargs: + mode = convert_pad_mode(dst=img, mode=mode or self.mode).value + if isinstance(img, torch.Tensor): pad = self._pt_pad else: pad = self._np_pad # type: ignore - return pad(img, self.to_pad, mode, **self.np_kwargs) + return pad(img, self.to_pad, mode, **self.kwargs) class SpatialPad(Transform): @@ -135,12 +153,14 @@ class SpatialPad(Transform): `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -150,13 +170,13 @@ def __init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_size = fall_back_tuple(self.spatial_size, data_shape) @@ -168,15 +188,22 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int return pad_width return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, + img: NdarrayTensor, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + ) -> NdarrayTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ data_pad_width = self._determine_data_pad_width(img.shape[1:]) all_pad_width = [(0, 0)] + data_pad_width @@ -184,8 +211,7 @@ def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] # all zeros, skip padding return img - mode = look_up_option(mode or self.mode, NumpyPadMode) - padder = Pad(all_pad_width, mode, **self.np_kwargs) + padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) return padder(img) @@ -204,13 +230,14 @@ class BorderPad(Transform): for example, image shape(CHW) is [1, 4, 4], spatial_border is [1, 2, 3, 4], pad top of H dim with 1, pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4. the result shape is [1, 7, 11]. - - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -219,22 +246,28 @@ class BorderPad(Transform): def __init__( self, spatial_border: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.spatial_border = spatial_border - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, + img: NdarrayTensor, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + ) -> NdarrayTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html Raises: ValueError: When ``self.spatial_border`` does not contain ints. @@ -261,8 +294,7 @@ def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] ) all_pad_width = [(0, 0)] + data_pad_width - mode = look_up_option(mode or self.mode, NumpyPadMode) - padder = Pad(all_pad_width, mode, **self.np_kwargs) + padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) return padder(img) @@ -276,47 +308,56 @@ class DivisiblePad(Transform): def __init__( self, k: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, - **np_kwargs, + **kwargs, ) -> None: """ Args: k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ self.k = k self.mode: NumpyPadMode = NumpyPadMode(mode) self.method: Method = Method(method) - self.np_kwargs = np_kwargs + self.kwargs = kwargs - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, + img: NdarrayTensor, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + ) -> NdarrayTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) spatial_pad = SpatialPad( spatial_size=new_size, method=self.method, mode=mode or self.mode, - **self.np_kwargs, + **self.kwargs, ) return spatial_pad(img) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9e33ab2db1..5c846b8d04 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -49,11 +49,11 @@ weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple from monai.utils.enums import InverseKeys __all__ = [ - "NumpyPadModeSequence", + "PadModeSequence", "SpatialPadd", "BorderPadd", "DivisiblePadd", @@ -99,6 +99,7 @@ ] NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] +PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] class SpatialPadd(MapTransform, InvertibleTransform): @@ -114,9 +115,9 @@ def __init__( keys: KeysCollection, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -129,19 +130,21 @@ def __init__( the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = SpatialPad(spatial_size, method, **np_kwargs) + self.padder = SpatialPad(spatial_size, method, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) @@ -183,9 +186,9 @@ def __init__( self, keys: KeysCollection, spatial_border: Union[Sequence[int], int], - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -202,19 +205,21 @@ def __init__( pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4. the result shape is [1, 7, 11]. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs) + self.padder = BorderPad(spatial_border=spatial_border, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) @@ -260,10 +265,10 @@ def __init__( self, keys: KeysCollection, k: Union[Sequence[int], int], - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -272,23 +277,25 @@ def __init__( k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k, method=method, **np_kwargs) + self.padder = DivisiblePad(k=k, method=method, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 816e9d58f2..a9cb847b93 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -38,6 +38,7 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, + PytorchPadMode, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -47,7 +48,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -543,54 +544,60 @@ class Zoom(Transform): mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate keep_size: Should keep original size (padding/slicing if needed), default is True. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ - backend = [TransformBackends.TORCH] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, zoom: Union[Sequence[float], float], mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, - **np_kwargs, + **kwargs, ) -> None: self.zoom = zoom self.mode: InterpolateMode = InterpolateMode(mode) - self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) + self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs def __call__( self, img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> torch.Tensor: + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} - The mode to pad data after zooming, default to ``self.padding_mode``. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate @@ -621,12 +628,12 @@ def __call__( elif diff < 0: # need slicing slice_vec[idx] = slice(half, half + od) - padding_mode = look_up_option(padding_mode or self.padding_mode, NumpyPadMode) - padder = Pad(pad_vec, padding_mode) + padder = Pad(pad_vec, padding_mode or self.padding_mode) zoomed = padder(zoomed) zoomed = zoomed[tuple(slice_vec)] - return zoomed + out, *_ = convert_to_dst_type(zoomed, dst=img) + return out class Rotate90(Transform): @@ -887,16 +894,19 @@ class RandZoom(RandomizableTransform): mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate keep_size: Should keep original size (pad if needed), default is True. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -908,10 +918,10 @@ def __init__( min_zoom: Union[Sequence[float], float] = 0.9, max_zoom: Union[Sequence[float], float] = 1.1, mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, - **np_kwargs, + **kwargs, ) -> None: RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) @@ -919,10 +929,10 @@ def __init__( if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) - self.padding_mode: NumpyPadMode = look_up_option(padding_mode, NumpyPadMode) + self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs self._zoom: Sequence[float] = [1.0] @@ -934,19 +944,22 @@ def __call__( self, img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> torch.Tensor: + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} - The mode to pad data after zooming, default to ``self.padding_mode``. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate @@ -967,9 +980,9 @@ def __call__( self._zoom, keep_size=self.keep_size, mode=look_up_option(mode or self.mode, InterpolateMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode), + padding_mode=padding_mode or self.padding_mode, align_corners=align_corners or self.align_corners, - **self.np_kwargs, + **self.kwargs, ) return zoomer(img) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c09d8e8011..96fe21db12 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -50,6 +50,7 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, + PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, @@ -115,7 +116,7 @@ GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] -NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] +PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] class Spacingd(MapTransform, InvertibleTransform): @@ -1520,18 +1521,21 @@ class Zoomd(MapTransform, InvertibleTransform): The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of string, each element corresponds to a key in ``keys``. - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -1542,17 +1546,17 @@ def __init__( keys: KeysCollection, zoom: Union[Sequence[float], float], mode: InterpolateModeSequence = InterpolateMode.AREA, - padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, + padding_mode: PadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **np_kwargs) + self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -1622,17 +1626,20 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of string, each element corresponds to a key in ``keys``. - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ @@ -1646,11 +1653,11 @@ def __init__( min_zoom: Union[Sequence[float], float] = 0.9, max_zoom: Union[Sequence[float], float] = 1.1, mode: InterpolateModeSequence = InterpolateMode.AREA, - padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, + padding_mode: PadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -1663,7 +1670,7 @@ def __init__( self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs self._zoom: Sequence[float] = [1.0] @@ -1683,7 +1690,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N elif len(self._zoom) == 2 and img_dims > 3: # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) + zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.kwargs) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9618723430..f0be87de0b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,7 +22,7 @@ import monai import monai.transforms.transform from monai.config import DtypeLike, IndexSelection -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose, OneOf from monai.transforms.transform import MapTransform, Transform @@ -30,12 +30,15 @@ GridSampleMode, InterpolateMode, InverseKeys, + NumpyPadMode, + PytorchPadMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, issequenceiterable, + look_up_option, min_version, optional_import, ) @@ -85,6 +88,7 @@ "get_number_image_type_conversions", "get_transform_backends", "print_transform_backends", + "convert_pad_mode", ] @@ -1259,5 +1263,30 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) +def convert_pad_mode(dst: NdarrayTensor, mode: Union[NumpyPadMode, PytorchPadMode, str]): + """ + Utility to convert padding mode between numpy array and PyTorch Tensor. + + Args: + dst: target data to convert padding mode for, should be numpy array or PyTorch Tensor. + mode: current padding mode. + + """ + mode = mode.value if isinstance(mode, (NumpyPadMode, PytorchPadMode)) else mode + if isinstance(dst, torch.Tensor): + if mode == "wrap": + mode = "circular" + if mode == "edge": + mode = "replicate" + return look_up_option(mode, PytorchPadMode) + if isinstance(dst, np.ndarray): + if mode == "circular": + mode = "wrap" + if mode == "replicate": + mode = "edge" + return look_up_option(mode, NumpyPadMode) + raise ValueError(f"unsupported data type: {type(dst)}.") + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 47b48aa2b8..b51ff6a9c8 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -206,7 +206,9 @@ def convert_data_type( def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: """ - Convert `src` to the same `torch.Tensor`/`np.ndarray` and data type as `dst`. + If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, + if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, + otherwise, convert to the type of `dst` directly. See Also: :func:`convert_data_type` @@ -214,4 +216,12 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor device = None if isinstance(dst, torch.Tensor): device = dst.device - return convert_data_type(data=src, output_type=type(dst), device=device, dtype=dst.dtype) + + output_type: Any + if isinstance(dst, torch.Tensor): + output_type = torch.Tensor + elif isinstance(dst, np.ndarray): + output_type = np.ndarray + else: + output_type = type(dst) + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype) diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index a7fc64f950..e48f6e8854 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -25,6 +25,10 @@ TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)))) # type: ignore +class TestTensor(torch.Tensor): + pass + + class TestConvertDataType(unittest.TestCase): @parameterized.expand(TESTS) def test_convert_data_type(self, in_image, im_out): @@ -49,7 +53,8 @@ def test_ill_arg(self): class TestConvertDataSame(unittest.TestCase): - @parameterized.expand(TESTS) + # add test for subclass of Tensor + @parameterized.expand(TESTS + [(np.array(1.0), TestTensor(np.array(1.0)))]) def test_convert_data_type(self, in_image, im_out): converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out) # check input is unchanged @@ -57,7 +62,11 @@ def test_convert_data_type(self, in_image, im_out): if isinstance(in_image, torch.Tensor): self.assertEqual(in_image.device, orig_device) # check output is desired type - self.assertEqual(type(converted_im), type(im_out)) + if isinstance(im_out, torch.Tensor): + output_type = torch.Tensor + else: + output_type = np.ndarray + self.assertEqual(type(converted_im), output_type) # check dtype is unchanged if isinstance(in_type, (np.ndarray, torch.Tensor)): self.assertEqual(converted_im.dtype, im_out.dtype) diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index ca15b4b347..810d08252c 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -50,11 +50,13 @@ def test_pad_shape(self, input_param, input_data, expected_val): self.assertAlmostEqual(result.shape, expected_val.shape) def test_pad_kwargs(self): - padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) for p in TEST_NDARRAYS: - result = padder(p(np.zeros((3, 8, 4)))) - result = result.cpu() if isinstance(result, torch.Tensor) else result - torch.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + input_data = p(np.zeros((3, 8, 4))) + if isinstance(input_data, np.ndarray): + result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) + np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + else: + result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 47bfa69582..eda36f4761 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -38,8 +38,8 @@ TESTS: List[Tuple] = [] for pad_collate in [ - lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1), - PadListDataCollate(method="end", mode="constant", constant_values=1), + lambda x: pad_list_data_collate(batch=x, method="end", mode="constant"), + PadListDataCollate(method="end", mode="constant"), ]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 86d010bbad..3d237c6681 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -17,34 +17,41 @@ from parameterized import parameterized from monai.transforms import SpatialPad -from monai.utils.enums import NumpyPadMode +from monai.utils.enums import NumpyPadMode, PytorchPadMode from monai.utils.misc import set_determinism from tests.utils import TEST_NDARRAYS TESTS = [] -# Numpy modes -MODES: List = [ +MODES = [] + +# Test modes +NP_MODES: List = [ "constant", "edge", - "linear_ramp", - "maximum", - "mean", - "median", - "minimum", - "reflect", - "symmetric", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", "wrap", - "empty", ] -MODES += [NumpyPadMode(i) for i in MODES] +MODES += NP_MODES +MODES += [NumpyPadMode(i) for i in NP_MODES] + +PT_MODES: list = [ + "constant", + "replicate", + "circular", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", +] +MODES += PT_MODES +MODES += [PytorchPadMode(i) for i in PT_MODES] for mode in MODES: TESTS.append( [ - {"spatial_size": [50, 50], "method": "end", "mode": mode}, - (1, 2, 2), - (1, 50, 50), + {"spatial_size": [3, 4], "method": "end", "mode": mode}, + (1, 2, 3), + (1, 3, 4), ] ) @@ -86,14 +93,19 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): torch.testing.assert_allclose(results[0], results[-1], atol=0, rtol=1e-5) def test_pad_kwargs(self): - padder = SpatialPad( - spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) - ) for p in TEST_NDARRAYS: - result = padder(p(np.zeros((3, 8, 4)))) - if isinstance(result, torch.Tensor): - result = result.cpu().numpy() - torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) + input_data = p(np.zeros((3, 8, 4))) + if isinstance(input_data, torch.Tensor): + result = ( + SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) + .cpu() + .numpy() + ) + else: + result = SpatialPad( + spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) + )(img=input_data) + torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) From 5627f230fb8402ef6b26088ba25aefe0bb2b6670 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 03:08:24 -0700 Subject: [PATCH 05/19] add multimodal transformers Signed-off-by: ahatamizadeh --- docs/requirements.txt | 1 + docs/source/installation.md | 4 +- monai/config/deviceconfig.py | 1 + monai/networks/nets/vltransformer.py | 343 +++++++++++++++++++++++++++ requirements-dev.txt | 1 + setup.cfg | 3 + tests/test_vltransformer.py | 79 ++++++ 7 files changed, 430 insertions(+), 2 deletions(-) create mode 100644 monai/networks/nets/vltransformer.py create mode 100644 tests/test_vltransformer.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..3530d63c49 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,3 +20,4 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers==4.10.2 diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..902f596dfc 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` , `einops` and `transformers`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..ff45b29531 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -73,6 +73,7 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") return output diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py new file mode 100644 index 0000000000..9638c649f6 --- /dev/null +++ b/monai/networks/nets/vltransformer.py @@ -0,0 +1,343 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path, _ = optional_import("transformers.file_utils", name="cached_path") +BertEmbeddings, _ = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings") +BertLayer, _ = optional_import("transformers.models.bert.modeling_bert", name="BertLayer") + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super(BertPreTrainedModel, self).__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(*inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.att = BertAttention(config) + self.output = BertOutput(config) + + def forward(self, x, y): + output = self.att(x, y) + return self.output(output, x) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + hidden_size, + ) -> None: + super(Pooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, + num_language_layers: int = 2, + num_vision_layers: int = 2, + num_mixed_layers: int = 2, + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + """ + super().__init__() + bert_config = type( + "obj", + (object,), + { + "attention_probs_dropout_prob": 0.1, + "classifier_dropout": None, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.10.2", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522, + "chunk_size_feed_forward": 0, + "is_decoder": False, + "add_cross_attention": False, + }, + ) + + self.config = bert_config + self.embeddings = BertEmbeddings(bert_config) + self.language_encoder = nn.ModuleList([BertLayer(bert_config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(bert_config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(bert_config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + hidden_state_vision = layer(vision_feats, None)[0] + for layer in self.language_encoder: + hidden_state_language = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) + return hidden_state_mixed + + +class VLTransformers(torch.nn.Module): + """ + Vision Language Multimodal Transformers" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], + num_classes: int = 2, + num_language_layers: int = 2, + num_vision_layers: int = 2, + num_mixed_layers: int = 2, + drop_out: float = 0.1, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + + Examples:: + + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, 2 vision layers and 2 mixed modality layers + >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, num_vision_layers=2, num_mixed_layers=2) + + """ + super(VLTransformers, self).__init__() + + if not (0 <= drop_out <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + ).from_pretrained() + self.embed_dim = 768 + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) + self.vision_proj = nn.Conv2d(in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size) + self.norm_vision_pos = nn.LayerNorm(self.embed_dim) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) + self.pooler = Pooler(hidden_size=self.embed_dim) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_mixed = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_mixed) + logits = self.cls_head(self.drop(pooled_features)) + return logits diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..dffaf0779c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,3 +36,4 @@ openslide-python==1.1.2 pandas requests einops +transformers==4.10.2 diff --git a/setup.cfg b/setup.cfg index 6efe768a6f..e5fa282973 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ all = openslide-python==1.1.2 pandas einops + transformers==4.10.2 nibabel = nibabel skimage = @@ -74,6 +75,8 @@ pandas = pandas einops = einops +transformers = + transformers [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 diff --git a/tests/test_vltransformer.py b/tests/test_vltransformer.py new file mode 100644 index 0000000000..841cf0b764 --- /dev/null +++ b/tests/test_vltransformer.py @@ -0,0 +1,79 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vltransformer import VLTransformers + +TEST_CASE_VLTransformers = [] +TEST_CASE_Vit = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), + ] + TEST_CASE_VLTransformers.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_VLTransformers) + def test_shape(self, input_param, expected_shape): + net = VLTransformers(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + VLTransformers( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + VLTransformers( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main() From c9c632aba70a375087d945ba40bd495075efa591 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 03:17:48 -0700 Subject: [PATCH 06/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 9638c649f6..0b853dae9f 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -301,8 +301,10 @@ def __init__( Examples:: - # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, 2 vision layers and 2 mixed modality layers - >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, num_vision_layers=2, num_mixed_layers=2) + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + 2 vision layers and 2 mixed modality layers + >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, + num_vision_layers=2, num_mixed_layers=2) """ super(VLTransformers, self).__init__() From 2cd138e79fd205236b5ed9775fb1e1d10a00af6d Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 03:29:52 -0700 Subject: [PATCH 07/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 0b853dae9f..72b43a2d92 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -21,6 +21,7 @@ from monai.utils import optional_import +transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") cached_path, _ = optional_import("transformers.file_utils", name="cached_path") BertEmbeddings, _ = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings") From 4b5bdd0bfe6018783a71a26d669d32bb5d25f111 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 03:38:37 -0700 Subject: [PATCH 08/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 72b43a2d92..ad9d9e9d28 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -17,11 +17,11 @@ from typing import Sequence, Union import torch +import transformers from torch import nn from monai.utils import optional_import -transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") cached_path, _ = optional_import("transformers.file_utils", name="cached_path") BertEmbeddings, _ = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings") From 9662e81df759098c320560272ee34b8eba408dd2 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 03:41:00 -0700 Subject: [PATCH 09/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index ad9d9e9d28..6ea5ffc888 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -19,13 +19,9 @@ import torch import transformers from torch import nn - -from monai.utils import optional_import - -load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") -cached_path, _ = optional_import("transformers.file_utils", name="cached_path") -BertEmbeddings, _ = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings") -BertLayer, _ = optional_import("transformers.models.bert.modeling_bert", name="BertLayer") +from transformers import load_tf_weights_in_bert +from transformers.file_utils import cached_path +from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer class BertPreTrainedModel(nn.Module): From f2f4f91b224ef44f29167b79616f945bda9062e7 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 07:29:15 -0700 Subject: [PATCH 10/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 13 ++++++++----- requirements-dev.txt | 2 +- setup.cfg | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 6ea5ffc888..83a72def54 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -17,11 +17,15 @@ from typing import Sequence, Union import torch -import transformers from torch import nn -from transformers import load_tf_weights_in_bert -from transformers.file_utils import cached_path -from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path, _ = optional_import("transformers.file_utils", name="cached_path") +BertEmbeddings, _ = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings") +BertLayer, _ = optional_import("transformers.models.bert.modeling_bert", name="BertLayer") class BertPreTrainedModel(nn.Module): @@ -250,7 +254,6 @@ def __init__( "add_cross_attention": False, }, ) - self.config = bert_config self.embeddings = BertEmbeddings(bert_config) self.language_encoder = nn.ModuleList([BertLayer(bert_config) for _ in range(num_language_layers)]) diff --git a/requirements-dev.txt b/requirements-dev.txt index dffaf0779c..ed8739ded8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,4 +36,4 @@ openslide-python==1.1.2 pandas requests einops -transformers==4.10.2 +transformers diff --git a/setup.cfg b/setup.cfg index e5fa282973..f7ed90a14a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ all = openslide-python==1.1.2 pandas einops - transformers==4.10.2 + transformers nibabel = nibabel skimage = From 97c518c69dad3d1c00dd6f3119239f04eca5d8a3 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 08:23:26 -0700 Subject: [PATCH 11/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 83a72def54..ea44007d7a 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -23,9 +23,9 @@ transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") -cached_path, _ = optional_import("transformers.file_utils", name="cached_path") -BertEmbeddings, _ = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings") -BertLayer, _ = optional_import("transformers.models.bert.modeling_bert", name="BertLayer") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] class BertPreTrainedModel(nn.Module): From 6fb363a8b1f338bf4b8f006285e4a88bac63a30e Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 08:59:12 -0700 Subject: [PATCH 12/19] add multimodal transformers Signed-off-by: ahatamizadeh --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..bac6521889 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", + "test_vltransformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From ea89c1f41978f501c0209315cc30dda04a404600 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 09:27:11 -0700 Subject: [PATCH 13/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index ea44007d7a..1e0e368c2a 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -280,8 +280,8 @@ class VLTransformers(torch.nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], - patch_size: Union[Sequence[int], int], + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[Sequence[int], int], # type: ignore num_classes: int = 2, num_language_layers: int = 2, num_vision_layers: int = 2, @@ -323,7 +323,9 @@ def __init__( self.embed_dim = 768 self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) - self.vision_proj = nn.Conv2d(in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size) + self.vision_proj = nn.Conv2d( + in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) # type: ignore self.norm_vision_pos = nn.LayerNorm(self.embed_dim) self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) self.pooler = Pooler(hidden_size=self.embed_dim) From 51abbf1d1bc19f38c23ab8e6ac377b3385b559c6 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 09:48:07 -0700 Subject: [PATCH 14/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 1e0e368c2a..6767359c24 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -14,7 +14,7 @@ import shutil import tarfile import tempfile -from typing import Sequence, Union +from typing import Tuple, Union import torch from torch import nn @@ -280,8 +280,8 @@ class VLTransformers(torch.nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], # type: ignore - patch_size: Union[Sequence[int], int], # type: ignore + img_size: Union[int, Tuple[int, int]], + patch_size: Union[int, Tuple[int, int]], num_classes: int = 2, num_language_layers: int = 2, num_vision_layers: int = 2, @@ -323,9 +323,7 @@ def __init__( self.embed_dim = 768 self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) - self.vision_proj = nn.Conv2d( - in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size - ) # type: ignore + self.vision_proj = nn.Conv2d(in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size) self.norm_vision_pos = nn.LayerNorm(self.embed_dim) self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) self.pooler = Pooler(hidden_size=self.embed_dim) From f3b7c5b19c71cc2683491c6b8e6d2cb1922fa9bf Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 10:14:44 -0700 Subject: [PATCH 15/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 8 +++++--- tests/test_vltransformer.py | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 6767359c24..f3ce3107e4 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -280,8 +280,8 @@ class VLTransformers(torch.nn.Module): def __init__( self, in_channels: int, - img_size: Union[int, Tuple[int, int]], - patch_size: Union[int, Tuple[int, int]], + img_size: Union[int, Tuple[int, int]], # type: ignore + patch_size: Union[int, Tuple[int, int]], # type: ignore num_classes: int = 2, num_language_layers: int = 2, num_vision_layers: int = 2, @@ -323,7 +323,9 @@ def __init__( self.embed_dim = 768 self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) - self.vision_proj = nn.Conv2d(in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size) + self.vision_proj = nn.Conv2d( + in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size # type: ignore + ) # type: ignore self.norm_vision_pos = nn.LayerNorm(self.embed_dim) self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) self.pooler = Pooler(hidden_size=self.embed_dim) diff --git a/tests/test_vltransformer.py b/tests/test_vltransformer.py index 841cf0b764..bd403d3b0d 100644 --- a/tests/test_vltransformer.py +++ b/tests/test_vltransformer.py @@ -18,7 +18,6 @@ from monai.networks.nets.vltransformer import VLTransformers TEST_CASE_VLTransformers = [] -TEST_CASE_Vit = [] for drop_out in [0.4]: for in_channels in [3]: for img_size in [224]: @@ -38,7 +37,7 @@ "num_classes": num_classes, "drop_out": drop_out, }, - (2, num_classes), + (2, num_classes), # type: ignore ] TEST_CASE_VLTransformers.append(test_case) From 19f0e9858dd8a649b1b3c0049d7ed3e17f5033de Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 10:35:37 -0700 Subject: [PATCH 16/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index f3ce3107e4..248823cd17 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -14,7 +14,7 @@ import shutil import tarfile import tempfile -from typing import Tuple, Union +from typing import Sequence, Union import torch from torch import nn @@ -280,8 +280,8 @@ class VLTransformers(torch.nn.Module): def __init__( self, in_channels: int, - img_size: Union[int, Tuple[int, int]], # type: ignore - patch_size: Union[int, Tuple[int, int]], # type: ignore + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[Sequence[int], int], # type: ignore num_classes: int = 2, num_language_layers: int = 2, num_vision_layers: int = 2, From 8829c95688ba2efa5f7d78846c783d74955a8c3e Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 13 Sep 2021 10:37:23 -0700 Subject: [PATCH 17/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 248823cd17..2eba56e8a8 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -312,7 +312,7 @@ def __init__( if not (0 <= drop_out <= 1): raise ValueError("dropout_rate should be between 0 and 1.") - if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore raise ValueError("img_size should be divisible by patch_size.") self.multimodal = MultiModal( @@ -322,7 +322,7 @@ def __init__( ).from_pretrained() self.embed_dim = 768 self.patch_size = patch_size - self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore self.vision_proj = nn.Conv2d( in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size # type: ignore ) # type: ignore From 7be790dac0381cc7a3ed393d351f2a860570cbdd Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 13:55:30 -0700 Subject: [PATCH 18/19] add multimodal transformers Signed-off-by: ahatamizadeh --- monai/networks/nets/vltransformer.py | 116 ++++++++++++++------------- tests/test_vltransformer.py | 2 + 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py index 2eba56e8a8..9ac0475043 100644 --- a/monai/networks/nets/vltransformer.py +++ b/monai/networks/nets/vltransformer.py @@ -33,7 +33,6 @@ class BertPreTrainedModel(nn.Module): Based on: LXMERT https://github.com/airsplay/lxmert - BERT (pytorch-transformer) https://github.com/huggingface/transformers """ @@ -51,7 +50,18 @@ def init_bert_weights(self, module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs): + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) tempdir = None @@ -62,7 +72,7 @@ def from_pretrained(cls, state_dict=None, cache_dir=None, from_tf=False, *inputs with tarfile.open(resolved_archive_file, "r:gz") as archive: archive.extractall(tempdir) serialization_dir = tempdir - model = cls(*inputs, **kwargs) + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, "pytorch_model.bin") state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) @@ -215,50 +225,24 @@ class MultiModal(BertPreTrainedModel): def __init__( self, - num_language_layers: int = 2, - num_vision_layers: int = 2, - num_mixed_layers: int = 2, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + bert_config: dict, # type: ignore ) -> None: """ Args: num_language_layers: number of language transformer layers. num_vision_layers: number of vision transformer layers. - num_mixed_layers: number of mixed transformer layers. + bert_config: configuration for bert language transformer encoder. + """ super().__init__() - bert_config = type( - "obj", - (object,), - { - "attention_probs_dropout_prob": 0.1, - "classifier_dropout": None, - "gradient_checkpointing": False, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-12, - "max_position_embeddings": 512, - "model_type": "bert", - "num_attention_heads": 12, - "num_hidden_layers": 12, - "pad_token_id": 0, - "position_embedding_type": "absolute", - "transformers_version": "4.10.2", - "type_vocab_size": 2, - "use_cache": True, - "vocab_size": 30522, - "chunk_size_feed_forward": 0, - "is_decoder": False, - "add_cross_attention": False, - }, - ) - self.config = bert_config - self.embeddings = BertEmbeddings(bert_config) - self.language_encoder = nn.ModuleList([BertLayer(bert_config) for _ in range(num_language_layers)]) - self.vision_encoder = nn.ModuleList([BertLayer(bert_config) for _ in range(num_vision_layers)]) - self.mixed_encoder = nn.ModuleList([BertMixedLayer(bert_config) for _ in range(num_mixed_layers)]) + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): @@ -282,11 +266,35 @@ def __init__( in_channels: int, img_size: Union[Sequence[int], int], # type: ignore patch_size: Union[Sequence[int], int], # type: ignore - num_classes: int = 2, - num_language_layers: int = 2, - num_vision_layers: int = 2, - num_mixed_layers: int = 2, - drop_out: float = 0.1, + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + drop_out: float = 0.0, + bert_config: dict = { + "attention_probs_dropout_prob": 0.1, + "classifier_dropout": None, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.10.2", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522, + "chunk_size_feed_forward": 0, + "is_decoder": False, + "add_cross_attention": False, + }, ) -> None: """ Args: @@ -298,14 +306,12 @@ def __init__( num_vision_layers: number of vision transformer layers. num_mixed_layers: number of mixed transformer layers. drop_out: faction of the input units to drop. - + bert_config: configuration for bert language transformer encoder. Examples:: - # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, - 2 vision layers and 2 mixed modality layers + 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, - num_vision_layers=2, num_mixed_layers=2) - + num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) """ super(VLTransformers, self).__init__() @@ -315,17 +321,19 @@ def __init__( if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore raise ValueError("img_size should be divisible by patch_size.") - self.multimodal = MultiModal( + self.multimodal = MultiModal.from_pretrained( num_language_layers=num_language_layers, num_vision_layers=num_vision_layers, num_mixed_layers=num_mixed_layers, - ).from_pretrained() + bert_config=bert_config, + ) + self.embed_dim = 768 self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore self.vision_proj = nn.Conv2d( - in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size # type: ignore - ) # type: ignore + in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) self.norm_vision_pos = nn.LayerNorm(self.embed_dim) self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) self.pooler = Pooler(hidden_size=self.embed_dim) diff --git a/tests/test_vltransformer.py b/tests/test_vltransformer.py index bd403d3b0d..a92a9bf79a 100644 --- a/tests/test_vltransformer.py +++ b/tests/test_vltransformer.py @@ -59,6 +59,7 @@ def test_ill_arg(self): num_language_layers=2, num_mixed_layers=4, num_vision_layers=2, + num_classes=2, drop_out=5.0, ) @@ -70,6 +71,7 @@ def test_ill_arg(self): num_language_layers=6, num_mixed_layers=6, num_vision_layers=8, + num_classes=8, drop_out=0.4, ) From 7187c69223af734cd64c98d2f55490ca9bd469ac Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Wed, 15 Sep 2021 14:15:20 -0700 Subject: [PATCH 19/19] add multimodal transformers Signed-off-by: ahatamizadeh --- docs/requirements.txt | 1 + docs/source/installation.md | 4 +- monai/config/deviceconfig.py | 1 + monai/networks/nets/vltransformer.py | 355 +++++++++++++++++++++++++++ monai/transforms/utility/array.py | 77 ++++-- monai/utils/type_conversion.py | 128 +++++++--- requirements-dev.txt | 1 + setup.cfg | 3 + tests/min_tests.py | 1 + tests/test_to_tensor.py | 16 +- tests/test_vltransformer.py | 80 ++++++ 11 files changed, 601 insertions(+), 66 deletions(-) create mode 100644 monai/networks/nets/vltransformer.py create mode 100644 tests/test_vltransformer.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..3530d63c49 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -20,3 +20,4 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers==4.10.2 diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..902f596dfc 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` , `einops` and `transformers`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..ff45b29531 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -73,6 +73,7 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") return output diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/vltransformer.py new file mode 100644 index 0000000000..af095a181c --- /dev/null +++ b/monai/networks/nets/vltransformer.py @@ -0,0 +1,355 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super(BertPreTrainedModel, self).__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.att = BertAttention(config) + self.output = BertOutput(config) + + def forward(self, x, y): + output = self.att(x, y) + return self.output(output, x) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__( + self, + hidden_size, + ) -> None: + super(Pooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + bert_config: dict, # type: ignore + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + bert_config: configuration for bert language transformer encoder. + + """ + super().__init__() + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + hidden_state_vision = layer(vision_feats, None)[0] + for layer in self.language_encoder: + hidden_state_language = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) + return hidden_state_mixed + + +class VLTransformers(torch.nn.Module): + """ + Vision Language Multimodal Transformers" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[Sequence[int], int], # type: ignore + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + drop_out: float = 0.0, + bert_config: dict = { + "attention_probs_dropout_prob": 0.1, + "classifier_dropout": None, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.10.2", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522, + "chunk_size_feed_forward": 0, + "is_decoder": False, + "add_cross_attention": False, + }, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + bert_config: configuration for bert language transformer encoder. + Examples:: + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + >>> net = VLTransformers(in_channels=3, img_size=(224, 224), num_classes=3, num_language_layers=2, + num_vision_layers=2, num_mixed_layers=2, drop_out=0.2) + """ + super(VLTransformers, self).__init__() + + if not (0 <= drop_out <= 1): + raise ValueError("dropout_rate should be in the range of 0 and 1.") + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal.from_pretrained( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + bert_config=bert_config, + ) + + self.embed_dim = 768 + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore + self.vision_proj = nn.Conv2d( + in_channels=in_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.norm_vision_pos = nn.LayerNorm(self.embed_dim) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim)) + self.pooler = Pooler(hidden_size=self.embed_dim) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(self.embed_dim, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_mixed = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_mixed) + logits = self.cls_head(self.drop(pooled_features)) + return logits diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index add47e27ca..824b5b33d3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,10 +32,18 @@ map_classes_to_indices, ) from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis -from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import +from monai.utils import ( + convert_data_type, + convert_to_numpy, + convert_to_tensor, + ensure_tuple, + get_equivalent_dtype, + look_up_option, + min_version, + optional_import, +) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_data_type PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -334,15 +342,16 @@ class ToTensor(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, device: Optional[torch.device] = None) -> None: + def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super().__init__() + self.dtype = dtype self.device = device def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_tensor(img, wrap_sequence=True, device=self.device) # type: ignore + return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=True) # type: ignore class EnsureType(Transform): @@ -354,19 +363,24 @@ class EnsureType(Transform): Args: data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, data_type: str = "tensor") -> None: - data_type = data_type.lower() - if data_type not in ("tensor", "numpy"): - raise ValueError("`data type` must be 'tensor' or 'numpy'.") - - self.data_type = data_type + def __init__( + self, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + ) -> None: + self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) + self.dtype = dtype + self.device = device - def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, data: NdarrayOrTensor): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -375,7 +389,12 @@ def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: if applicable. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore + if self.data_type == "tensor": + dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor) + return convert_to_tensor(data, dtype=dtype_, device=self.device) + else: + dtype_ = get_equivalent_dtype(self.dtype, np.ndarray) + return convert_to_numpy(data, dtype=dtype_) class ToNumpy(Transform): @@ -385,25 +404,36 @@ class ToNumpy(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, dtype: Optional[DtypeLike] = None) -> None: + super().__init__() + self.dtype = dtype + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_numpy(img) # type: ignore + return convert_to_numpy(img, dtype=self.dtype) # type: ignore class ToCupy(Transform): """ Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. + + Args: + dtype: data type specifier. It is inferred from the input by default. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __init__(self, dtype=None) -> None: + super().__init__() + self.dtype = dtype + + def __call__(self, data: NdarrayOrTensor): """ - Apply the transform to `img` and make it contiguous. + Create a CuPy array from `data` and make it contiguous """ - return cp.ascontiguousarray(cp.asarray(img)) # type: ignore + return convert_to_cupy(data, self.dtype) class ToPIL(Transform): @@ -779,6 +809,9 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore if output_shape is None: output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) @@ -828,6 +861,10 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + if output_shape is None: output_shape = self.output_shape indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) @@ -848,6 +885,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ def __call__(self, img: np.ndarray) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = np.squeeze(img, axis=0) @@ -914,6 +952,9 @@ def __call__( if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") + img, *_ = convert_data_type(img, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + # Generate extreme points self.randomize(label[0, :]) @@ -950,6 +991,7 @@ def __call__(self, img: torch.Tensor): img: PyTorch Tensor data for the TorchVision transform. """ + img, *_ = convert_data_type(img, torch.Tensor) # type: ignore return self.trans(img) @@ -980,7 +1022,7 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.dtype = dtype def __call__(self, img: np.ndarray): - img = np.asarray(img) + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_flat = img.flatten() try: out_flat = np.copy(img_flat).astype(self.dtype) @@ -1036,6 +1078,7 @@ def __call__( mask must have the same shape as input `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if meta_data is None: meta_data = {} diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b51ff6a9c8..3636dbc6c0 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -16,6 +16,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", @@ -60,6 +61,8 @@ def get_equivalent_dtype(dtype, data_type): im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) """ + if dtype is None: + return None if data_type is torch.Tensor: if type(dtype) is torch.dtype: return dtype @@ -83,7 +86,12 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None): +def convert_to_tensor( + data, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = False, +): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -92,36 +100,41 @@ def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch. data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. - If `True`, then `[1, 2]` -> `tensor([1, 2])`. + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): - return data.contiguous().to(device) + return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device) - elif has_cp and isinstance(data, cp_ndarray): - return torch.as_tensor(data, device=device) - elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data, device=device) - elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data, device=device) + if data.ndim > 0: + data = np.ascontiguousarray(data) + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif ( + has_cp + and isinstance(data, cp_ndarray) + or isinstance(data, (float, int, bool)) + or (isinstance(data, Sequence) and wrap_sequence) + ): + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore elif isinstance(data, list): - return [convert_to_tensor(i, device=device) for i in data] + return [convert_to_tensor(i, dtype=dtype, device=device) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_tensor(i, device=device) for i in data) + return tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data) elif isinstance(data, dict): - return {k: convert_to_tensor(v, device=device) for k, v in data.items()} + return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()} return data -def convert_to_numpy(data, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -130,23 +143,22 @@ def convert_to_numpy(data, wrap_sequence: bool = False): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to numpy array. wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. """ if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() + data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy() elif has_cp and isinstance(data, cp_ndarray): - data = cp.asnumpy(data) - elif isinstance(data, (float, int, bool)): - data = np.asarray(data) - elif isinstance(data, Sequence) and wrap_sequence: - return np.asarray(data) + data = cp.asnumpy(data).astype(dtype) + elif isinstance(data, (np.ndarray, float, int, bool)) or (isinstance(data, Sequence) and wrap_sequence): + data = np.asarray(data, dtype=dtype) elif isinstance(data, list): - return [convert_to_numpy(i) for i in data] + return [convert_to_numpy(i, dtype=dtype) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_numpy(i) for i in data) + return tuple(convert_to_numpy(i, dtype=dtype) for i in data) elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) @@ -154,6 +166,42 @@ def convert_to_numpy(data, wrap_sequence: bool = False): return data +def convert_to_cupy(data, dtype, wrap_sequence: bool = True): + """ + Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, + recursively check every item and convert it to cupy array. + + Args: + data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc. + Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays + + for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to Cupy array. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. + """ + + # direct calls + if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)) or ( + isinstance(data, Sequence) and wrap_sequence + ): + data = cp.asarray(data, dtype) + elif isinstance(data, list): + return [convert_to_cupy(i, dtype) for i in data] + elif isinstance(data, tuple): + return tuple(convert_to_cupy(i, dtype) for i in data) + elif isinstance(data, dict): + return {k: convert_to_cupy(v, dtype) for k, v in data.items()} + # make it contiguous + if isinstance(data, cp.ndarray): + if data.ndim > 0: + data = cp.ascontiguousarray(data) + else: + raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!") + + return data + + def convert_data_type( data: Any, output_type: Optional[type] = None, @@ -178,6 +226,8 @@ def convert_data_type( orig_type = torch.Tensor elif isinstance(data, np.ndarray): orig_type = np.ndarray + elif has_cp and isinstance(data, cp.ndarray): + orig_type = cp.ndarray else: orig_type = type(data) @@ -185,30 +235,27 @@ def convert_data_type( output_type = output_type or orig_type - dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type) + dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type) if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = convert_to_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) - if device is not None: - data = data.to(device) + data = convert_to_tensor(data, dtype=dtype_, device=device) elif output_type is np.ndarray: - if orig_type is not np.ndarray: - data = convert_to_numpy(data) - if data is not None and dtype != data.dtype: - data = data.astype(dtype) + data = convert_to_numpy(data, dtype=dtype_) + elif has_cp and output_type is cp.ndarray: + data = convert_to_cupy(data, dtype=dtype_) else: raise ValueError(f"Unsupported output type: {output_type}") return data, orig_type, orig_device -def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: +def convert_to_dst_type( + src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None +) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: """ - If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, - if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, + If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, + if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, otherwise, convert to the type of `dst` directly. + `dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type. See Also: :func:`convert_data_type` @@ -217,6 +264,9 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor if isinstance(dst, torch.Tensor): device = dst.device + if dtype is None: + dtype = dst.dtype + output_type: Any if isinstance(dst, torch.Tensor): output_type = torch.Tensor @@ -224,4 +274,4 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor output_type = np.ndarray else: output_type = type(dst) - return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype) + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype) diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..ed8739ded8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,3 +36,4 @@ openslide-python==1.1.2 pandas requests einops +transformers diff --git a/setup.cfg b/setup.cfg index 6efe768a6f..f7ed90a14a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ all = openslide-python==1.1.2 pandas einops + transformers nibabel = nibabel skimage = @@ -74,6 +75,8 @@ pandas = pandas einops = einops +transformers = + transformers [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..bac6521889 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", + "test_vltransformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 3d187a1dba..b065595e89 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -11,8 +11,8 @@ import unittest +import torch from parameterized import parameterized -from torch import Tensor from monai.transforms import ToTensor from tests.utils import TEST_NDARRAYS, assert_allclose, optional_import @@ -35,16 +35,16 @@ class TestToTensor(unittest.TestCase): @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): - result = ToTensor()(test_data) - self.assertTrue(isinstance(result, Tensor)) - assert_allclose(result, test_data) + result = ToTensor(dtype=torch.float32, device="cpu")(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand(TESTS_SINGLE) def test_single_input(self, test_data): result = ToTensor()(test_data) - self.assertTrue(isinstance(result, Tensor)) - assert_allclose(result, test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) @unittest.skipUnless(has_cp, "CuPy is required.") @@ -52,8 +52,8 @@ def test_cupy(self): test_data = [[1, 2], [3, 4]] cupy_array = cp.ascontiguousarray(cp.asarray(test_data)) result = ToTensor()(cupy_array) - self.assertTrue(isinstance(result, Tensor)) - assert_allclose(result, test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_vltransformer.py b/tests/test_vltransformer.py new file mode 100644 index 0000000000..a92a9bf79a --- /dev/null +++ b/tests/test_vltransformer.py @@ -0,0 +1,80 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vltransformer import VLTransformers + +TEST_CASE_VLTransformers = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), # type: ignore + ] + TEST_CASE_VLTransformers.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_VLTransformers) + def test_shape(self, input_param, expected_shape): + net = VLTransformers(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + VLTransformers( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + num_classes=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + VLTransformers( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + num_classes=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main()