From 6c9e081a50f35f401cc757130ea79e74983fe270 Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Thu, 28 Jul 2022 10:02:49 +0100
Subject: [PATCH 01/30] Pulling across more from standalone prototype

---
 monai/utils/mapping_stack.py | 105 +++++++++++++++++++++++++++++++++++
 tests/test_mapping_stack.py  |  30 ++++++++++
 2 files changed, 135 insertions(+)
 create mode 100644 monai/utils/mapping_stack.py
 create mode 100644 tests/test_mapping_stack.py

diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py
new file mode 100644
index 0000000000..94e952f3ab
--- /dev/null
+++ b/monai/utils/mapping_stack.py
@@ -0,0 +1,105 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+import torch
+
+from monai.utils.enums import TransformBackends
+from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear,
+                                    _create_translate)
+
+class MatrixFactory:
+
+    def __init__(self,
+                 dims: int,
+                 backend: TransformBackends,
+                 device: Optional[torch.device] = None):
+
+        if backend == TransformBackends.NUMPY:
+            if device is not None:
+                raise ValueError("'device' must be None with TransformBackends.NUMPY")
+            self._device = None
+            self._sin = np.sin
+            self._cos = np.cos
+            self._eye = np.eye
+            self._diag = np.diag
+        else:
+            if device is None:
+                raise ValueError("'device' must be set with TransformBackends.TORCH")
+            self._device = device
+            self._sin = lambda th: torch.sin(torch.as_tensor(th,
+                                                             dtype=torch.float32,
+                                                             device=self._device))
+            self._cos = lambda th: torch.cos(torch.as_tensor(th,
+                                                             dtype=torch.float32,
+                                                             device=self._device))
+            self._eye = lambda rank: torch.eye(rank, device=self._device)
+            self._diag = lambda size: torch.diag(torch.as_tensor(size, device=self._device))
+
+        self._backend = backend
+        self._dims = dims
+
+    def identity(self):
+        return self._eye(self._dims + 1)
+
+    def rotate_euler(self, radians: Union[Sequence[float], float]):
+        return _create_rotate(self._dims, radians, self._sin, self._cos, self._eye)
+
+    def shear(self, coefs: Union[Sequence[float], float]):
+        return _create_shear(self._dims, coefs, self._eye)
+
+    def scale(self, factors: Union[Sequence[float], float]):
+        return _create_scale(self._dims, factors, self._diag)
+
+    def translate(self, offsets: Union[Sequence[float], float]):
+        return _create_translate(self._dims, offsets, self._eye)
+
+
+class Mapping:
+
+    def __init__(self, matrix):
+        self._matrix = matrix
+
+    def apply(self, other):
+        return Mapping(other @ self._matrix)
+
+
+class MappingStack:
+    """
+    This class keeps track of a series of mappings and applies them / calculates their inverse (if
+    mappings are invertible). Mapping stacks are used to generate a mapping that gets applied during a `Resample` /
+    `Resampled` transform.
+
+    A mapping is one of:
+    - a description of a change to a numpy array that only requires index manipulation instead of an actual resample.
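+      (for example, a flip or a 90-degree rotation, which can be realised by slicing and
+      transposing the array rather than by interpolating it)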
+ - a homogeneous matrix representing a geometric transform to be applied during a resample + - a field representing a deformation to be applied during a resample + """ + + def __init__(self, factory: MatrixFactory): + self.factory = factory + self.stack = [] + self.applied_stack = [] + + def push(self, mapping): + self.stack.append(mapping) + + def pop(self): + raise NotImplementedError() + + def transform(self): + m = Mapping(self.factory.identity()) + for t in self.stack: + m = m.apply(t) + return m diff --git a/tests/test_mapping_stack.py b/tests/test_mapping_stack.py new file mode 100644 index 0000000000..cd0c72417f --- /dev/null +++ b/tests/test_mapping_stack.py @@ -0,0 +1,30 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.utils import TransformBackends + +from monai.utils.mapping_stack import MappingStack, Mapping, MatrixFactory + + +class MappingStackTest(unittest.TestCase): + + def test_scale_then_translate(self): + + f = MatrixFactory(3, TransformBackends.NUMPY) + m_scale = f.scale((2, 2, 2)) + m_trans = f.translate((20, 20, 0)) + ms = MappingStack(f) + ms.push(m_scale) + ms.push(m_trans) + + print(ms.transform()._matrix) From 0b0670dd6a57d25e0f0f8ab78002cb6204f38b0b Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Mon, 1 Aug 2022 10:30:14 +0100 Subject: [PATCH 02/30] Applyd function --- monai/transforms/atmostonce/__init__.py | 0 monai/transforms/atmostonce/apply.py | 79 +++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 monai/transforms/atmostonce/__init__.py create mode 100644 monai/transforms/atmostonce/apply.py diff --git a/monai/transforms/atmostonce/__init__.py b/monai/transforms/atmostonce/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py new file mode 100644 index 0000000000..59bbe61ff0 --- /dev/null +++ b/monai/transforms/atmostonce/apply.py @@ -0,0 +1,79 @@ +from typing import Dict, Hashable, Mapping, Optional, Sequence, Union + +import numpy as np + +import torch + +from monai.config import USE_COMPILED, DtypeLike, KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.spatial.array import Resample +from monai.transforms.transform import MapTransform +from monai.transforms.utils import create_grid +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils.enums import GridPatchSort, TransformBackends +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type +from monai.utils.mapping_stack import MappingStack + +# TODO: This should move to a common place to be shared with dictionary +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + +class 
Applyd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + modes: GridSampleModeSequence, + padding_modes: GridSamplePadModeSequence, + normalized: bool = False, + device: Optional[torch.device] = None, + dtypes: Optional[DtypeSequence] = np.float32): + self.keys = keys + self.modes = modes + self.padding_modes = padding_modes + self.device = device + self.dtypes = dtypes + self.resamplers = dict() + + if isinstance(dtypes, (list, tuple)): + if len(keys) != len(dtypes): + raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence") + + # create a resampler for each output data type + unique_resamplers = dict() + for d in dtypes: + if d not in unique_resamplers: + unique_resamplers[d] = Resample(norm_coords=not normalized, device=device, dtype=d) + + # assign each named data input the appropriate resampler for that data type + for k, d in zip(keys, dtypes): + if k not in self.resamplers: + self.resamplers[k] = unique_resamplers[d] + + else: + # share the resampler across all named data inputs + resampler = Resample(norm_coords=not normalized, device=device, dtype=dtypes) + for k in keys: + self.resamplers[k] = resampler + + def __call__(self, + data: Mapping[Hashable, NdarrayOrTensor], + allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + mapping_stack = d["mapping_stack"] + affine = mapping_stack.transform() + for key_tuple in self.key_iterator(d, self.modes, self.padding_modes, self.dtypes): + key, mode, padding_mode, dtype = key_tuple + data = d[key] + spatial_size = data.shape[1:] + grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype) + _device = grid.device + + _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY + + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=grid.dtype) + affine, *_ = convert_to_dst_type(affine, grid) + d[key] = self.resamplers[key](data, grid=grid, mode=mode, padding_mode=padding_mode) + + return d From b1e6f3c22802adb6ecb04807ef8845cb279c0097 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Aug 2022 10:42:21 +0100 Subject: [PATCH 03/30] more atmostonce functionality; baseline atmostonce (non-dictionary) transforms --- monai/transforms/atmostonce/apply.py | 19 ++- monai/transforms/atmostonce/array.py | 140 ++++++++++++++++ monai/transforms/atmostonce/dictionary.py | 192 ++++++++++++++++++++++ monai/utils/type_conversion.py | 28 +++- tests/test_atmostonce.py | 80 +++++++++ 5 files changed, 453 insertions(+), 6 deletions(-) create mode 100644 monai/transforms/atmostonce/array.py create mode 100644 monai/transforms/atmostonce/dictionary.py create mode 100644 tests/test_atmostonce.py diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index 59bbe61ff0..4ccdf402b1 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -11,8 +11,9 @@ from monai.transforms.transform import MapTransform from monai.transforms.utils import create_grid from monai.utils import GridSampleMode, GridSamplePadMode -from monai.utils.enums import GridPatchSort, TransformBackends -from monai.utils.type_conversion import convert_data_type, convert_to_dst_type +from monai.utils.enums import TransformBackends +from monai.utils.type_conversion import (convert_data_type, convert_to_dst_type, + expand_scalar_to_tuple) from monai.utils.mapping_stack import MappingStack # TODO: This should move to a common place to be shared with 
dictionary @@ -20,6 +21,10 @@ GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + + + class Applyd(MapTransform, InvertibleTransform): def __init__(self, @@ -61,10 +66,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor], allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - mapping_stack = d["mapping_stack"] - affine = mapping_stack.transform() - for key_tuple in self.key_iterator(d, self.modes, self.padding_modes, self.dtypes): + mapping_stack = d["mappings"] + keys = d.keys() + for key_tuple in self.key_iterator(d, + expand_scalar_to_tuple(self.modes, len(keys)), + expand_scalar_to_tuple(self.padding_modes, len(keys)), + expand_scalar_to_tuple(self.dtypes, len(keys))): key, mode, padding_mode, dtype = key_tuple + affine = mapping_stack[key].transform() data = d[key] spatial_size = data.shape[1:] grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype) diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py new file mode 100644 index 0000000000..50b26a5bcf --- /dev/null +++ b/monai/transforms/atmostonce/array.py @@ -0,0 +1,140 @@ +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +import torch + +from monai.config import DtypeLike, NdarrayOrTensor + +from monai.transforms import Transform + +from monai.utils import (GridSampleMode, GridSamplePadMode, + InterpolateMode, NumpyPadMode, PytorchPadMode) +from monai.utils.mapping_stack import MappingStack + + +class Rotate(Transform): + + def __init__( + self, + angle: Union[Sequence[float], float], + keep_size: bool = True, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: Union[DtypeLike, torch.dtype] = np.float32 + ): + self.angle = angle + self.keep_size = keep_size + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + + def __call__( + self, + img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + align_corners: Optional[bool] = None, + ) -> NdarrayOrTensor: + pass + + +class Zoom(Transform): + """ + Zoom into / out of the image applying the `zoom` factor as a scalar, or if `zoom` is a tuple of + values, apply each zoom factor to the appropriate dimension. 
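+
+    Intended usage (a sketch only: at this point in the series `__call__` is still a
+    stub, so the call below shows the planned API rather than working behaviour):
+
+        zoom = Zoom(zoom=(2.0, 2.0, 1.0))
+        result = zoom(img)  # 2.0 applied to the first two spatial dims, 1.0 (identity) to the third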
+ """ + + def __init__( + self, + zoom: Union[Sequence[float], float], + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = None, + keep_size: bool = True, + **kwargs + ): + self.zoom = zoom + self.mode: InterpolateMode = InterpolateMode(mode) + self.padding_mode = padding_mode + self.align_corners = align_corners + self.keep_size = keep_size + self.kwargs = kwargs + + def __call__( + self, + img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + align_corners: Optional[bool] = None + ) -> NdarrayOrTensor: + pass + + +class Resize(Transform): + + def __init__( + self, + spatial_size: Union[Sequence[int], int], + size_mode: str = "all", + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + anti_aliasing: bool = False, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None + ): + self.spatial_size = spatial_size + self.size_mode = size_mode + self.mode = mode, + self.align_corners = align_corners + self.anti_aliasing = anti_aliasing + self.anti_aliasing_sigma = anti_aliasing_sigma + + def __call__( + self, + img: NdarrayOrTensor, + mode: Optional[Union[InterpolateMode, str]] = None, + align_corners: Optional[bool] = None, + anti_aliasing: Optional[bool] = None, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None + ) -> NdarrayOrTensor: + pass + + +class Spacing(Transform): + + def __init__( + self, + pixdim: Union[Sequence[float], float, np.ndarray], + diagonal: bool = False, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + image_only: bool = False + ): + self.pixdim = pixdim + self.diagonal = diagonal + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + self.image_only = image_only + + def __call__( + self, + img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, + affine: Optional[NdarrayOrTensor] = None, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + align_corners: Optional[bool] = None, + dtype: DtypeLike = None, + output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: + pass + + diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py new file mode 100644 index 0000000000..f89f8756d7 --- /dev/null +++ b/monai/transforms/atmostonce/dictionary.py @@ -0,0 +1,192 @@ +from typing import Mapping, Optional, Sequence, Union + +import numpy as np + +import torch + +from monai.config import KeysCollection +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.utils.enums import TransformBackends +from monai.utils.mapping_stack import MappingStack, MatrixFactory +from monai.utils.type_conversion import expand_scalar_to_tuple + + +def get_device_from_data(data): + if isinstance(data, np.ndarray): + return None + elif isinstance(data, torch.Tensor): + return data.device + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise 
ValueError(msg.format(type(data))) + + +def get_backend_from_data(data): + if isinstance(data, np.ndarray): + return TransformBackends.NUMPY + elif isinstance(data, torch.Tensor): + return TransformBackends.TORCH + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + +# TODO: reconcile multiple definitions to one in utils +def expand_potential_tuple(keys, value): + if not isinstance(value, (tuple, list)): + return tuple(value for _ in keys) + return value + + +class MappingStackTransformd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection): + super().__init__(self) + self.keys = keys + + def __call__(self, + d: Mapping, + *args, + **kwargs): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + v = mappings.get(k, MappingStack(MatrixFactory(dims, backend, device))) + v.push(self.get_matrix(dims, backend, device, *args, **kwargs)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + def get_matrix(self, dims, backend, device, *args, **kwargs): + msg = "get_matrix must be implemented in a subclass of MappingStackTransform" + raise NotImplementedError(msg) + + +class RotateEulerd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + euler_radians: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.euler_radians = expand_scalar_to_tuple(euler_radians, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.rotate_euler(self.euler_radians)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + +class Translated(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + translate: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.translate = expand_scalar_to_tuple(translate, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.translate(self.translate)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + +class Zoomd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + scale: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.scale = expand_scalar_to_tuple(scale, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.scale(self.scale)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + +# class 
RotateEulerd(MappingStackTransformd): +# +# def __init__(self, +# keys: KeysCollection, +# euler_radians: Union[Sequence[float], float]): +# super().__init__(keys) +# self.euler_radians = euler_radians +# +# def get_matrix(self, dims, backend, device, *args, **kwargs): +# euler_radians = args[0] if len(args) > 0 else None +# euler_radians = kwargs.get("euler_radians", None) if euler_radians is None +# euler_radians = self.euler_radians if euler_radians is None else self.euler_radians +# if euler_radians is None: +# raise ValueError("'euler_radians' must be set during initialisation or passed in" +# "during __call__") +# arg = euler_radians if self.euler_radians is None else euler_radians +# return MatrixFactory(dims, backend, device).rotate_euler(arg) +# +# +# class ScaleEulerd(MappingStackTransformd): +# +# def __init__(self, +# keys: KeysCollection, +# scale: Union[Sequence[float], float]): +# super().__init__(keys) +# self.scale = scale +# +# def get_matrix(self, dims, backend, device, *args, **kwargs): +# scale = args[0] if len(args) > 0 else None +# scale = kwargs.get("scale", None) if scale is None +# scale = self.scale if scale is None else self.euler_radians +# if scale is None: +# raise ValueError("'scale' must be set during initialisation or passed in" +# "during __call__") +# arg = scale if self.scale is None else scale +# return MatrixFactory(dims, backend, device).scale(arg) \ No newline at end of file diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index a6cd2522d7..0452df5bf4 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -10,7 +10,7 @@ # limitations under the License. import re -from typing import Any, Optional, Sequence, Tuple, Type, Union +from typing import Any, SupportsInt, Optional, Sequence, Tuple, Type, TypeVar, Union import numpy as np import torch @@ -33,6 +33,7 @@ "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", + "expand_scalar_to_tuple" ] @@ -309,3 +310,28 @@ def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list: """ return data.tolist() if isinstance(data, (torch.Tensor, np.ndarray)) else list(data) + + +TValue = TypeVar('TValue') +def expand_scalar_to_tuple(value: Union[Tuple[TValue], float], + length: int): + """ + If `value` is not a tuple, it will be converted to a tuple of the given `length`. + Otherwise, it is returned as is. Note that if `value` is a tuple, its length must be + the same as the `length` parameter or the conversion will fail. + Args: + value: the value to be converted to a tuple if necessary + length: the length of the resulting tuple + Returns: + If `value` is already a tuple, then `value` is returned. Otherwise, return a tuple + of length `length`, each element of which is `value`. 
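+    Example (illustrating both branches of the implementation below):
+        >>> expand_scalar_to_tuple(0.5, 3)
+        (0.5, 0.5, 0.5)
+        >>> expand_scalar_to_tuple((1, 2, 3), 3)
+        (1, 2, 3)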
+ """ + if not isinstance(length, int): + raise ValueError("'length' must be an integer value") + + if not isinstance(value, tuple): + return tuple(value for _ in range(length)) + else: + if length != len(value): + raise ValueError("if 'value' is a tuple it must be the same length as 'length'") + return value diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py new file mode 100644 index 0000000000..44a3f008bd --- /dev/null +++ b/tests/test_atmostonce.py @@ -0,0 +1,80 @@ +import unittest + +import numpy as np + +import matplotlib.pyplot as plt + +import torch + +from monai.transforms import Affined +from monai.transforms.atmostonce.apply import Applyd +from monai.transforms.atmostonce.dictionary import RotateEulerd +from monai.transforms.compose import Compose +from monai.utils.enums import GridSampleMode, GridSamplePadMode + +class TestRotateEulerd(unittest.TestCase): + + def test_rotate_numpy(self): + r = RotateEulerd(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': np.zeros((1, 64, 64, 32), dtype=np.float32), + 'label': np.ones((1, 64, 64, 32), dtype=np.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, np.ndarray): + print(k, v.shape) + else: + print(k, v) + + def test_rotate_tensor(self): + r = RotateEulerd(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(k, v.shape) + else: + print(k, v) + + def test_rotate_apply(self): + c = Compose([ + RotateEulerd(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), + Applyd(('image', 'label'), + modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) + ]) + + image = torch.zeros((1, 16, 16, 4), device="cpu", dtype=torch.float32) + for y in range(image.shape[-2]): + for z in range(image.shape[-1]): + image[0, :, y, z] = y + z * 16 + label = torch.ones((1, 16, 16, 4), device="cpu", dtype=torch.int8) + d = { + 'image': image, + 'label': label + } + # plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + d = c(d) + # plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + print(d['image'].shape) + + def test_old_affine(self): + c = Compose([ + Affined(('image', 'label'), + rotate_params=(0.0, 0.0, 3.14159265 / 2)) + ]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = c(d) + print(d['image'].shape) From 02a95a4692084affff954dc3de573f61719988d0 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 28 Jul 2022 10:02:49 +0100 Subject: [PATCH 04/30] Pulling across more from standalone prototype --- monai/utils/mapping_stack.py | 105 +++++++++++++++++++++++++++++++++++ tests/test_mapping_stack.py | 30 ++++++++++ 2 files changed, 135 insertions(+) create mode 100644 monai/utils/mapping_stack.py create mode 100644 tests/test_mapping_stack.py diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py new file mode 100644 index 0000000000..94e952f3ab --- /dev/null +++ b/monai/utils/mapping_stack.py @@ -0,0 +1,105 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+import torch
+
+from monai.utils.enums import TransformBackends
+from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear,
+                                    _create_translate)
+
+class MatrixFactory:
+
+    def __init__(self,
+                 dims: int,
+                 backend: TransformBackends,
+                 device: Optional[torch.device] = None):
+
+        if backend == TransformBackends.NUMPY:
+            if device is not None:
+                raise ValueError("'device' must be None with TransformBackends.NUMPY")
+            self._device = None
+            self._sin = np.sin
+            self._cos = np.cos
+            self._eye = np.eye
+            self._diag = np.diag
+        else:
+            if device is None:
+                raise ValueError("'device' must be set with TransformBackends.TORCH")
+            self._device = device
+            self._sin = lambda th: torch.sin(torch.as_tensor(th,
+                                                             dtype=torch.float32,
+                                                             device=self._device))
+            self._cos = lambda th: torch.cos(torch.as_tensor(th,
+                                                             dtype=torch.float32,
+                                                             device=self._device))
+            self._eye = lambda rank: torch.eye(rank, device=self._device)
+            self._diag = lambda size: torch.diag(torch.as_tensor(size, device=self._device))
+
+        self._backend = backend
+        self._dims = dims
+
+    def identity(self):
+        return self._eye(self._dims + 1)
+
+    def rotate_euler(self, radians: Union[Sequence[float], float]):
+        return _create_rotate(self._dims, radians, self._sin, self._cos, self._eye)
+
+    def shear(self, coefs: Union[Sequence[float], float]):
+        return _create_shear(self._dims, coefs, self._eye)
+
+    def scale(self, factors: Union[Sequence[float], float]):
+        return _create_scale(self._dims, factors, self._diag)
+
+    def translate(self, offsets: Union[Sequence[float], float]):
+        return _create_translate(self._dims, offsets, self._eye)
+
+
+class Mapping:
+
+    def __init__(self, matrix):
+        self._matrix = matrix
+
+    def apply(self, other):
+        return Mapping(other @ self._matrix)
+
+
+class MappingStack:
+    """
+    This class keeps track of a series of mappings and applies them / calculates their inverse (if
+    mappings are invertible). Mapping stacks are used to generate a mapping that gets applied during a `Resample` /
+    `Resampled` transform.
+
+    A mapping is one of:
+    - a description of a change to a numpy array that only requires index manipulation instead of an actual resample.
+    - a homogeneous matrix representing a geometric transform to be applied during a resample
+    - a field representing a deformation to be applied during a resample
+    """
+
+    def __init__(self, factory: MatrixFactory):
+        self.factory = factory
+        self.stack = []
+        self.applied_stack = []
+
+    def push(self, mapping):
+        self.stack.append(mapping)
+
+    def pop(self):
+        raise NotImplementedError()
+
+    def transform(self):
+        m = Mapping(self.factory.identity())
+        for t in self.stack:
+            m = m.apply(t)
+        return m
diff --git a/tests/test_mapping_stack.py b/tests/test_mapping_stack.py
new file mode 100644
index 0000000000..cd0c72417f
--- /dev/null
+++ b/tests/test_mapping_stack.py
@@ -0,0 +1,30 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.utils import TransformBackends + +from monai.utils.mapping_stack import MappingStack, Mapping, MatrixFactory + + +class MappingStackTest(unittest.TestCase): + + def test_scale_then_translate(self): + + f = MatrixFactory(3, TransformBackends.NUMPY) + m_scale = f.scale((2, 2, 2)) + m_trans = f.translate((20, 20, 0)) + ms = MappingStack(f) + ms.push(m_scale) + ms.push(m_trans) + + print(ms.transform()._matrix) From 4fa26c5767ff3bfe53e40a58523c4008fd5a2de7 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Mon, 1 Aug 2022 10:30:14 +0100 Subject: [PATCH 05/30] Applyd function --- monai/transforms/atmostonce/__init__.py | 0 monai/transforms/atmostonce/apply.py | 79 +++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 monai/transforms/atmostonce/__init__.py create mode 100644 monai/transforms/atmostonce/apply.py diff --git a/monai/transforms/atmostonce/__init__.py b/monai/transforms/atmostonce/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py new file mode 100644 index 0000000000..59bbe61ff0 --- /dev/null +++ b/monai/transforms/atmostonce/apply.py @@ -0,0 +1,79 @@ +from typing import Dict, Hashable, Mapping, Optional, Sequence, Union + +import numpy as np + +import torch + +from monai.config import USE_COMPILED, DtypeLike, KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.spatial.array import Resample +from monai.transforms.transform import MapTransform +from monai.transforms.utils import create_grid +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils.enums import GridPatchSort, TransformBackends +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type +from monai.utils.mapping_stack import MappingStack + +# TODO: This should move to a common place to be shared with dictionary +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + +class Applyd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + modes: GridSampleModeSequence, + padding_modes: GridSamplePadModeSequence, + normalized: bool = False, + device: Optional[torch.device] = None, + dtypes: Optional[DtypeSequence] = np.float32): + self.keys = keys + self.modes = modes + self.padding_modes = padding_modes + self.device = device + self.dtypes = dtypes + self.resamplers = dict() + + if isinstance(dtypes, (list, tuple)): + if len(keys) != len(dtypes): + raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence") + + # create a resampler for each output data type + unique_resamplers = dict() + for d in dtypes: + if d not in unique_resamplers: + unique_resamplers[d] = Resample(norm_coords=not normalized, device=device, dtype=d) + + # assign each named data input the appropriate 
resampler for that data type + for k, d in zip(keys, dtypes): + if k not in self.resamplers: + self.resamplers[k] = unique_resamplers[d] + + else: + # share the resampler across all named data inputs + resampler = Resample(norm_coords=not normalized, device=device, dtype=dtypes) + for k in keys: + self.resamplers[k] = resampler + + def __call__(self, + data: Mapping[Hashable, NdarrayOrTensor], + allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + mapping_stack = d["mapping_stack"] + affine = mapping_stack.transform() + for key_tuple in self.key_iterator(d, self.modes, self.padding_modes, self.dtypes): + key, mode, padding_mode, dtype = key_tuple + data = d[key] + spatial_size = data.shape[1:] + grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype) + _device = grid.device + + _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY + + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=grid.dtype) + affine, *_ = convert_to_dst_type(affine, grid) + d[key] = self.resamplers[key](data, grid=grid, mode=mode, padding_mode=padding_mode) + + return d From 65fafa9867d9a612c3b17a210db79a7b1cb57133 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Aug 2022 10:42:21 +0100 Subject: [PATCH 06/30] more atmostonce functionality; baseline atmostonce (non-dictionary) transforms --- monai/transforms/atmostonce/apply.py | 19 ++- monai/transforms/atmostonce/array.py | 140 ++++++++++++++++ monai/transforms/atmostonce/dictionary.py | 192 ++++++++++++++++++++++ monai/utils/type_conversion.py | 28 +++- tests/test_atmostonce.py | 80 +++++++++ 5 files changed, 453 insertions(+), 6 deletions(-) create mode 100644 monai/transforms/atmostonce/array.py create mode 100644 monai/transforms/atmostonce/dictionary.py create mode 100644 tests/test_atmostonce.py diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index 59bbe61ff0..4ccdf402b1 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -11,8 +11,9 @@ from monai.transforms.transform import MapTransform from monai.transforms.utils import create_grid from monai.utils import GridSampleMode, GridSamplePadMode -from monai.utils.enums import GridPatchSort, TransformBackends -from monai.utils.type_conversion import convert_data_type, convert_to_dst_type +from monai.utils.enums import TransformBackends +from monai.utils.type_conversion import (convert_data_type, convert_to_dst_type, + expand_scalar_to_tuple) from monai.utils.mapping_stack import MappingStack # TODO: This should move to a common place to be shared with dictionary @@ -20,6 +21,10 @@ GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + + + class Applyd(MapTransform, InvertibleTransform): def __init__(self, @@ -61,10 +66,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor], allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - mapping_stack = d["mapping_stack"] - affine = mapping_stack.transform() - for key_tuple in self.key_iterator(d, self.modes, self.padding_modes, self.dtypes): + mapping_stack = d["mappings"] + keys = d.keys() + for key_tuple in self.key_iterator(d, + expand_scalar_to_tuple(self.modes, len(keys)), + expand_scalar_to_tuple(self.padding_modes, len(keys)), + expand_scalar_to_tuple(self.dtypes, len(keys))): key, mode, padding_mode, dtype = 
key_tuple + affine = mapping_stack[key].transform() data = d[key] spatial_size = data.shape[1:] grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype) diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py new file mode 100644 index 0000000000..50b26a5bcf --- /dev/null +++ b/monai/transforms/atmostonce/array.py @@ -0,0 +1,140 @@ +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +import torch + +from monai.config import DtypeLike, NdarrayOrTensor + +from monai.transforms import Transform + +from monai.utils import (GridSampleMode, GridSamplePadMode, + InterpolateMode, NumpyPadMode, PytorchPadMode) +from monai.utils.mapping_stack import MappingStack + + +class Rotate(Transform): + + def __init__( + self, + angle: Union[Sequence[float], float], + keep_size: bool = True, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: Union[DtypeLike, torch.dtype] = np.float32 + ): + self.angle = angle + self.keep_size = keep_size + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + + def __call__( + self, + img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + align_corners: Optional[bool] = None, + ) -> NdarrayOrTensor: + pass + + +class Zoom(Transform): + """ + Zoom into / out of the image applying the `zoom` factor as a scalar, or if `zoom` is a tuple of + values, apply each zoom factor to the appropriate dimension. + """ + + def __init__( + self, + zoom: Union[Sequence[float], float], + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = None, + keep_size: bool = True, + **kwargs + ): + self.zoom = zoom + self.mode: InterpolateMode = InterpolateMode(mode) + self.padding_mode = padding_mode + self.align_corners = align_corners + self.keep_size = keep_size + self.kwargs = kwargs + + def __call__( + self, + img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + align_corners: Optional[bool] = None + ) -> NdarrayOrTensor: + pass + + +class Resize(Transform): + + def __init__( + self, + spatial_size: Union[Sequence[int], int], + size_mode: str = "all", + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + anti_aliasing: bool = False, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None + ): + self.spatial_size = spatial_size + self.size_mode = size_mode + self.mode = mode, + self.align_corners = align_corners + self.anti_aliasing = anti_aliasing + self.anti_aliasing_sigma = anti_aliasing_sigma + + def __call__( + self, + img: NdarrayOrTensor, + mode: Optional[Union[InterpolateMode, str]] = None, + align_corners: Optional[bool] = None, + anti_aliasing: Optional[bool] = None, + anti_aliasing_sigma: Union[Sequence[float], float, None] = None + ) -> NdarrayOrTensor: + pass + + +class Spacing(Transform): + + def __init__( + self, + pixdim: Union[Sequence[float], float, np.ndarray], + diagonal: bool = False, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + 
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + image_only: bool = False + ): + self.pixdim = pixdim + self.diagonal = diagonal + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + self.image_only = image_only + + def __call__( + self, + img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, + affine: Optional[NdarrayOrTensor] = None, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + align_corners: Optional[bool] = None, + dtype: DtypeLike = None, + output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: + pass + + diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py new file mode 100644 index 0000000000..f89f8756d7 --- /dev/null +++ b/monai/transforms/atmostonce/dictionary.py @@ -0,0 +1,192 @@ +from typing import Mapping, Optional, Sequence, Union + +import numpy as np + +import torch + +from monai.config import KeysCollection +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.utils.enums import TransformBackends +from monai.utils.mapping_stack import MappingStack, MatrixFactory +from monai.utils.type_conversion import expand_scalar_to_tuple + + +def get_device_from_data(data): + if isinstance(data, np.ndarray): + return None + elif isinstance(data, torch.Tensor): + return data.device + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def get_backend_from_data(data): + if isinstance(data, np.ndarray): + return TransformBackends.NUMPY + elif isinstance(data, torch.Tensor): + return TransformBackends.TORCH + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + +# TODO: reconcile multiple definitions to one in utils +def expand_potential_tuple(keys, value): + if not isinstance(value, (tuple, list)): + return tuple(value for _ in keys) + return value + + +class MappingStackTransformd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection): + super().__init__(self) + self.keys = keys + + def __call__(self, + d: Mapping, + *args, + **kwargs): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + v = mappings.get(k, MappingStack(MatrixFactory(dims, backend, device))) + v.push(self.get_matrix(dims, backend, device, *args, **kwargs)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + def get_matrix(self, dims, backend, device, *args, **kwargs): + msg = "get_matrix must be implemented in a subclass of MappingStackTransform" + raise NotImplementedError(msg) + + +class RotateEulerd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + euler_radians: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.euler_radians = expand_scalar_to_tuple(euler_radians, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = 
get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.rotate_euler(self.euler_radians)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + +class Translated(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + translate: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.translate = expand_scalar_to_tuple(translate, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.translate(self.translate)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + +class Zoomd(MapTransform, InvertibleTransform): + + def __init__(self, + keys: KeysCollection, + scale: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.scale = expand_scalar_to_tuple(scale, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.scale(self.scale)) + mappings[k] = v + rd[k] = data + + rd["mappings"] = mappings + + return rd + + +# class RotateEulerd(MappingStackTransformd): +# +# def __init__(self, +# keys: KeysCollection, +# euler_radians: Union[Sequence[float], float]): +# super().__init__(keys) +# self.euler_radians = euler_radians +# +# def get_matrix(self, dims, backend, device, *args, **kwargs): +# euler_radians = args[0] if len(args) > 0 else None +# euler_radians = kwargs.get("euler_radians", None) if euler_radians is None +# euler_radians = self.euler_radians if euler_radians is None else self.euler_radians +# if euler_radians is None: +# raise ValueError("'euler_radians' must be set during initialisation or passed in" +# "during __call__") +# arg = euler_radians if self.euler_radians is None else euler_radians +# return MatrixFactory(dims, backend, device).rotate_euler(arg) +# +# +# class ScaleEulerd(MappingStackTransformd): +# +# def __init__(self, +# keys: KeysCollection, +# scale: Union[Sequence[float], float]): +# super().__init__(keys) +# self.scale = scale +# +# def get_matrix(self, dims, backend, device, *args, **kwargs): +# scale = args[0] if len(args) > 0 else None +# scale = kwargs.get("scale", None) if scale is None +# scale = self.scale if scale is None else self.euler_radians +# if scale is None: +# raise ValueError("'scale' must be set during initialisation or passed in" +# "during __call__") +# arg = scale if self.scale is None else scale +# return MatrixFactory(dims, backend, device).scale(arg) \ No newline at end of file diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 9c9fb1a4b2..9e08e56fad 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -10,7 +10,7 @@ # limitations under the License. 
import re -from typing import Any, Optional, Sequence, Tuple, Type, Union +from typing import Any, SupportsInt, Optional, Sequence, Tuple, Type, TypeVar, Union import numpy as np import torch @@ -34,6 +34,7 @@ "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", + "expand_scalar_to_tuple" ] @@ -347,3 +348,28 @@ def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list: """ return data.tolist() if isinstance(data, (torch.Tensor, np.ndarray)) else list(data) + + +TValue = TypeVar('TValue') +def expand_scalar_to_tuple(value: Union[Tuple[TValue], float], + length: int): + """ + If `value` is not a tuple, it will be converted to a tuple of the given `length`. + Otherwise, it is returned as is. Note that if `value` is a tuple, its length must be + the same as the `length` parameter or the conversion will fail. + Args: + value: the value to be converted to a tuple if necessary + length: the length of the resulting tuple + Returns: + If `value` is already a tuple, then `value` is returned. Otherwise, return a tuple + of length `length`, each element of which is `value`. + """ + if not isinstance(length, int): + raise ValueError("'length' must be an integer value") + + if not isinstance(value, tuple): + return tuple(value for _ in range(length)) + else: + if length != len(value): + raise ValueError("if 'value' is a tuple it must be the same length as 'length'") + return value diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py new file mode 100644 index 0000000000..44a3f008bd --- /dev/null +++ b/tests/test_atmostonce.py @@ -0,0 +1,80 @@ +import unittest + +import numpy as np + +import matplotlib.pyplot as plt + +import torch + +from monai.transforms import Affined +from monai.transforms.atmostonce.apply import Applyd +from monai.transforms.atmostonce.dictionary import RotateEulerd +from monai.transforms.compose import Compose +from monai.utils.enums import GridSampleMode, GridSamplePadMode + +class TestRotateEulerd(unittest.TestCase): + + def test_rotate_numpy(self): + r = RotateEulerd(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': np.zeros((1, 64, 64, 32), dtype=np.float32), + 'label': np.ones((1, 64, 64, 32), dtype=np.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, np.ndarray): + print(k, v.shape) + else: + print(k, v) + + def test_rotate_tensor(self): + r = RotateEulerd(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(k, v.shape) + else: + print(k, v) + + def test_rotate_apply(self): + c = Compose([ + RotateEulerd(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), + Applyd(('image', 'label'), + modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) + ]) + + image = torch.zeros((1, 16, 16, 4), device="cpu", dtype=torch.float32) + for y in range(image.shape[-2]): + for z in range(image.shape[-1]): + image[0, :, y, z] = y + z * 16 + label = torch.ones((1, 16, 16, 4), device="cpu", dtype=torch.int8) + d = { + 'image': image, + 'label': label + } + # plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + d = c(d) + # plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + print(d['image'].shape) + + def test_old_affine(self): + c = Compose([ + Affined(('image', 'label'), + rotate_params=(0.0, 0.0, 3.14159265 / 2)) + 
]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = c(d) + print(d['image'].shape) From 3530432f194b674b83f69be36d66f411b0e7c3dc Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Mon, 15 Aug 2022 13:21:55 +0100 Subject: [PATCH 07/30] Working on ground-up refactor of array transforms --- monai/transforms/atmostonce/array.py | 63 ++++++++++++++++++++++++++-- monai/utils/mapping_stack.py | 10 +++-- monai/utils/misc.py | 23 ++++++++++ 3 files changed, 88 insertions(+), 8 deletions(-) diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 50b26a5bcf..42cf5bbdaa 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -10,7 +10,8 @@ from monai.utils import (GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, PytorchPadMode) -from monai.utils.mapping_stack import MappingStack +from monai.utils.mapping_stack import MappingStack, MatrixFactory +from monai.utils.misc import get_backend_from_data, get_device_from_data class Rotate(Transform): @@ -39,7 +40,24 @@ def __call__( padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, ) -> NdarrayOrTensor: - pass + mode = self.mode if mode is None else mode + padding_mode = self.padding_mode if padding_mode is None else padding_mode + align_corners = self.align_corners if align_corners is None else align_corners + keep_size = self.keep_size + dtype = self.dtype + matrix_factory = MatrixFactory(len(img.shape)-1, + get_backend_from_data(img), + get_device_from_data(img)) + if mapping_stack is None: + mapping_stack = MappingStack(matrix_factory) + mapping_stack.push(matrix_factory.rotate_euler(self.angle, + **{ + "padding_mode": padding_mode, + "mode": mode, + "align_corners": align_corners, + "keep_size": keep_size, + "dtype": dtype + })) class Zoom(Transform): @@ -55,6 +73,7 @@ def __init__( padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, + dtype: Union[DtypeLike, torch.dtype] = np.float32, **kwargs ): self.zoom = zoom @@ -62,6 +81,7 @@ def __init__( self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size + self.dtype = dtype self.kwargs = kwargs def __call__( @@ -72,7 +92,26 @@ def __call__( padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None ) -> NdarrayOrTensor: - pass + + mode = self.mode if mode is None else mode + padding_mode = self.padding_mode if padding_mode is None else padding_mode + align_corners = self.align_corners if align_corners is None else align_corners + keep_size = self.keep_size + dtype = self.dtype + matrix_factory = MatrixFactory(len(img.shape)-1, + get_backend_from_data(img), + get_device_from_data(img)) + if mapping_stack is None: + mapping_stack = MappingStack(matrix_factory) + mapping_stack.push(matrix_factory.scale(self.zoom, + **{ + "padding_mode": padding_mode, + "mode": mode, + "align_corners": align_corners, + "keep_size": keep_size, + "dtype": dtype + })) + img.add class Resize(Transform): @@ -96,12 +135,28 @@ def __init__( def __call__( self, img: NdarrayOrTensor, + mapping_stack: Optional[MappingStack] = None, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, anti_aliasing: Optional[bool] = None, anti_aliasing_sigma: 
Union[Sequence[float], float, None] = None ) -> NdarrayOrTensor: - pass + mode = self.mode if mode is None else mode + align_corners = self.align_corners if align_corners is None else align_corners + keep_size = self.keep_size + dtype = self.dtype + matrix_factory = MatrixFactory(len(img.shape)-1, + get_backend_from_data(img), + get_device_from_data(img)) + if mapping_stack is None: + mapping_stack = MappingStack(matrix_factory) + mapping_stack.push(matrix_factory.scale(self.zoom, + **{ + "mode": mode, + "align_corners": align_corners, + "keep_size": keep_size, + "dtype": dtype + })) class Spacing(Transform): diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py index 94e952f3ab..0c214df4ef 100644 --- a/monai/utils/mapping_stack.py +++ b/monai/utils/mapping_stack.py @@ -53,14 +53,16 @@ def __init__(self, def identity(self): return self._eye(self._dims + 1) - def rotate_euler(self, radians: Union[Sequence[float], float]): - return _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) + def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args): + matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) + return {"matrix": matrix, "args": extra_args} def shear(self, coefs: Union[Sequence[float], float]): return _create_shear(self._dims, coefs, self._eye) - def scale(self, factors: Union[Sequence[float], float]): - return _create_scale(self._dims, factors, self._diag) + def scale(self, factors: Union[Sequence[float], float], **extra_args): + matrix = _create_scale(self._dims, factors, self._diag) + return {"matrix": matrix, "args": extra_args} def translate(self, offsets: Union[Sequence[float], float]): return _create_translate(self._dims, offsets, self._eye) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index fc38dc5056..c4ad86cc68 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -25,6 +25,7 @@ import numpy as np import torch +from monai.utils import TransformBackends from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.utils.module import version_leq @@ -52,6 +53,8 @@ "sample_slices", "check_parent_dir", "save_obj", + "get_backend_from_data", + "get_device_from_data", ] _seed = None @@ -471,3 +474,23 @@ def save_obj( shutil.move(str(temp_path), path) except PermissionError: # project-monai/monai issue #3613 pass + + +def get_device_from_data(data): + if isinstance(data, np.ndarray): + return None + elif isinstance(data, torch.Tensor): + return data.device + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def get_backend_from_data(data): + if isinstance(data, np.ndarray): + return TransformBackends.NUMPY + elif isinstance(data, torch.Tensor): + return TransformBackends.TORCH + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) \ No newline at end of file From d74fb564de0af4cec6f46fc4a03bf977f8bb6d4d Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 17 Aug 2022 13:46:27 +0100 Subject: [PATCH 08/30] Re-re-re-factored function / array / dict based rotate and others; refactored apply; initial compose implementation + helpers --- monai/data/meta_tensor.py | 22 ++ monai/transforms/atmostonce/apply.py | 183 ++++++---- monai/transforms/atmostonce/array.py | 339 ++++++++++++------ monai/transforms/atmostonce/compose.py | 193 ++++++++++ monai/transforms/atmostonce/dictionary.py | 112 +++--- 
 monai/transforms/atmostonce/functional.py     | 278 ++++++++++++++
 monai/transforms/atmostonce/lazy_transform.py |  75 ++++
 monai/utils/mapping_stack.py                  | 142 ++++++--
 tests/test_atmostonce.py                      | 183 +++++++++-
 9 files changed, 1260 insertions(+), 267 deletions(-)
 create mode 100644 monai/transforms/atmostonce/compose.py
 create mode 100644 monai/transforms/atmostonce/functional.py
 create mode 100644 monai/transforms/atmostonce/lazy_transform.py

diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 8897371903..8c6c7278d7 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -11,6 +11,8 @@

 from __future__ import annotations

+import copy
+
 import warnings
 from copy import deepcopy
 from typing import Any, Sequence
@@ -151,6 +153,26 @@ def __init__(
         if MetaKeys.SPACE not in self.meta:
             self.meta[MetaKeys.SPACE] = SpaceKeys.RAS  # defaulting to the right-anterior-superior space

+        self._pending_transforms = list()
+
+    def push_pending_transform(self, meta_matrix):
+        self._pending_transforms.append(meta_matrix)
+
+    def peek_pending_transform(self):
+        return copy.deepcopy(self._pending_transforms[0])
+
+    def pop_pending_transform(self):
+        transform = self._pending_transforms[0]
+        self._pending_transforms.pop(0)
+        return transform
+
+    @property
+    def pending_transforms(self):
+        return copy.deepcopy(self._pending_transforms)
+
+    def clear_pending_transforms(self):
+        self._pending_transforms = list()
+
     @staticmethod
     def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
         """
diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py
index 4ccdf402b1..d39557dac5 100644
--- a/monai/transforms/atmostonce/apply.py
+++ b/monai/transforms/atmostonce/apply.py
@@ -1,20 +1,24 @@
 from typing import Dict, Hashable, Mapping, Optional, Sequence, Union

+import itertools as it
+
 import numpy as np

 import torch

 from monai.config import USE_COMPILED, DtypeLike, KeysCollection
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data import MetaTensor
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.array import Resample
 from monai.transforms.transform import MapTransform
 from monai.transforms.utils import create_grid
 from monai.utils import GridSampleMode, GridSamplePadMode
 from monai.utils.enums import TransformBackends
+from monai.utils.misc import get_backend_from_data, get_device_from_data
 from monai.utils.type_conversion import (convert_data_type, convert_to_dst_type,
                                          expand_scalar_to_tuple)
-from monai.utils.mapping_stack import MappingStack
+from monai.utils.mapping_stack import MatrixFactory

 # TODO: This should move to a common place to be shared with dictionary
 GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
@@ -22,67 +26,126 @@
 DtypeSequence = Union[Sequence[DtypeLike], DtypeLike]


+# TODO: move to mapping_stack.py
+def extents_from_shape(shape):
+    # corner points of the spatial extent only; shape[0] is the channel dimension
+    extents = [[0, shape[i]] for i in range(1, len(shape))]

+    extents = it.product(*extents)
+    return list(np.asarray(e + (1,)) for e in extents)
+
+
+# TODO: move to mapping_stack.py
+def shape_from_extents(extents):
+    aextents = np.asarray(extents)
+    mins = aextents.min(axis=0)
+    maxes = aextents.max(axis=0)
+    return np.ceil(maxes - mins)[:-1].astype(int)
+
+
+def apply(data: MetaTensor):
+    pending = data.pending_transforms
+
+    if len(pending) == 0:
+        return data
+
+    dim_count = len(data.shape) - 1
+    matrix_factory = MatrixFactory(dim_count,
+                                   get_backend_from_data(data),
+                                   get_device_from_data(data))
+
+    # set up the identity matrix and metadata
+    cumulative_matrix = matrix_factory.identity()
+    cumulative_extents = extents_from_shape(data.shape)
+
+    # pre-translate origin to centre of image
+    translate_to_centre = matrix_factory.translate([-s / 2 for s in data.shape[1:]])
+    cumulative_matrix = translate_to_centre @ cumulative_matrix
+    cumulative_extents = [translate_to_centre.matrix.matrix @ e for e in cumulative_extents]
+
+    for meta_matrix in pending:
+        next_matrix = meta_matrix.matrix
+        cumulative_matrix = meta_matrix @ cumulative_matrix
+        cumulative_extents = [next_matrix.matrix @ e for e in cumulative_extents]
+
+    # TODO: figure out how to propagate extents properly
+    # TODO: resampling strategy: augment resample or perform multiple stages if necessary
+    # TODO: resampling strategy - antialiasing: can resample just be augmented?
+
+    data.clear_pending_transforms()
+    return data
+
+
+class Apply(InvertibleTransform):
+
+    def __init__(self):
+        super().__init__()
+        pass
+

 class Applyd(MapTransform, InvertibleTransform):

-    def __init__(self,
-                 keys: KeysCollection,
-                 modes: GridSampleModeSequence,
-                 padding_modes: GridSamplePadModeSequence,
-                 normalized: bool = False,
-                 device: Optional[torch.device] = None,
-                 dtypes: Optional[DtypeSequence] = np.float32):
-        self.keys = keys
-        self.modes = modes
-        self.padding_modes = padding_modes
-        self.device = device
-        self.dtypes = dtypes
-        self.resamplers = dict()
-
-        if isinstance(dtypes, (list, tuple)):
-            if len(keys) != len(dtypes):
-                raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence")
-
-            # create a resampler for each output data type
-            unique_resamplers = dict()
-            for d in dtypes:
-                if d not in unique_resamplers:
-                    unique_resamplers[d] = Resample(norm_coords=not normalized, device=device, dtype=d)
-
-            # assign each named data input the appropriate resampler for that data type
-            for k, d in zip(keys, dtypes):
-                if k not in self.resamplers:
-                    self.resamplers[k] = unique_resamplers[d]
-
-        else:
-            # share the resampler across all named data inputs
-            resampler = Resample(norm_coords=not normalized, device=device, dtype=dtypes)
-            for k in keys:
-                self.resamplers[k] = resampler
-
-    def __call__(self,
-                 data: Mapping[Hashable, NdarrayOrTensor],
-                 allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]:
-        d = dict(data)
-        mapping_stack = d["mappings"]
-        keys = d.keys()
-        for key_tuple in self.key_iterator(d,
-                                           expand_scalar_to_tuple(self.modes, len(keys)),
-                                           expand_scalar_to_tuple(self.padding_modes, len(keys)),
-                                           expand_scalar_to_tuple(self.dtypes, len(keys))):
-            key, mode, padding_mode, dtype = key_tuple
-            affine = mapping_stack[key].transform()
-            data = d[key]
-            spatial_size = data.shape[1:]
-            grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype)
-            _device = grid.device
-
-            _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY
-
-            grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=grid.dtype)
-            affine, *_ = convert_to_dst_type(affine, grid)
-            d[key] = self.resamplers[key](data, grid=grid, mode=mode, padding_mode=padding_mode)
-
-        return d
+    def __init__(self):
+        super().__init__()
+        pass
+
+# class Applyd(MapTransform, InvertibleTransform):
+#
+#     def __init__(self,
+#                  keys: KeysCollection,
+#                  modes: GridSampleModeSequence,
+#                  padding_modes: GridSamplePadModeSequence,
+#                  normalized: bool = False,
+#                  device: Optional[torch.device] = None,
+#                  dtypes: Optional[DtypeSequence] = np.float32):
+#         self.keys = keys
+#         self.modes = modes
+#         self.padding_modes = 
padding_modes +# self.device = device +# self.dtypes = dtypes +# self.resamplers = dict() +# +# if isinstance(dtypes, (list, tuple)): +# if len(keys) != len(dtypes): +# raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence") +# +# # create a resampler for each output data type +# unique_resamplers = dict() +# for d in dtypes: +# if d not in unique_resamplers: +# unique_resamplers[d] = Resample(norm_coords=not normalized, device=device, dtype=d) +# +# # assign each named data input the appropriate resampler for that data type +# for k, d in zip(keys, dtypes): +# if k not in self.resamplers: +# self.resamplers[k] = unique_resamplers[d] +# +# else: +# # share the resampler across all named data inputs +# resampler = Resample(norm_coords=not normalized, device=device, dtype=dtypes) +# for k in keys: +# self.resamplers[k] = resampler +# +# def __call__(self, +# data: Mapping[Hashable, NdarrayOrTensor], +# allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]: +# d = dict(data) +# mapping_stack = d["mappings"] +# keys = d.keys() +# for key_tuple in self.key_iterator(d, +# expand_scalar_to_tuple(self.modes, len(keys)), +# expand_scalar_to_tuple(self.padding_modes, len(keys)), +# expand_scalar_to_tuple(self.dtypes, len(keys))): +# key, mode, padding_mode, dtype = key_tuple +# affine = mapping_stack[key].transform() +# data = d[key] +# spatial_size = data.shape[1:] +# grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype) +# _device = grid.device +# +# _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY +# +# grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=grid.dtype) +# affine, *_ = convert_to_dst_type(affine, grid) +# d[key] = self.resamplers[key](data, grid=grid, mode=mode, padding_mode=padding_mode) +# +# return d diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 42cf5bbdaa..df998e622b 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -1,20 +1,126 @@ -from typing import Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, NdarrayOrTensor +from monai.data import get_track_meta -from monai.transforms import Transform +from monai.transforms import Transform, InvertibleTransform, RandomizableTransform + +from monai.transforms.atmostonce.apply import apply +from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing +from monai.transforms.atmostonce.lazy_transform import LazyTransform, push_transform from monai.utils import (GridSampleMode, GridSamplePadMode, - InterpolateMode, NumpyPadMode, PytorchPadMode) -from monai.utils.mapping_stack import MappingStack, MatrixFactory -from monai.utils.misc import get_backend_from_data, get_device_from_data + InterpolateMode, NumpyPadMode, PytorchPadMode, convert_to_tensor) +from monai.utils.mapping_stack import MatrixFactory, MetaMatrix +from monai.utils.misc import get_backend_from_data, get_device_from_data, ensure_tuple + + +# TODO: these transforms are intended to replace array transforms once development is done + +# TODO: why doesn't Spacing have antialiasing options? 
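+
+# Editor's note: the lazy transforms below all share one skeleton: a functional
+# computes (img_t, transform, metadata) without touching any pixels, the
+# transform is wrapped in a MetaMatrix and pushed onto the MetaTensor's pending
+# list, and actual resampling is deferred to apply(). A minimal sketch of that
+# shared skeleton ('my_functional' is a hypothetical stand-in for the real
+# rotate / zoom / spacing / resize functionals):
+#
+#     class MyLazyTransform(LazyTransform, InvertibleTransform):
+#
+#         def __call__(self, img: NdarrayOrTensor):
+#             # describe the operation; no resampling happens here
+#             img_t, transform, metadata = my_functional(img)
+#             # record the deferred operation on the tensor itself
+#             img_t.push_pending_transform(MetaMatrix(transform, metadata))
+#             # eager mode: fuse pending matrices and resample immediately
+#             if not self.lazy_evaluation:
+#                 img_t = apply(img_t)
+#             return img_t
+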
+class Spacing(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        pixdim: Union[Sequence[float], float, np.ndarray],
+        src_pixdim: Optional[Union[Sequence[float], float, np.ndarray]],
+        diagonal: Optional[bool] = False,
+        mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        align_corners: Optional[bool] = False,
+        dtype: Optional[DtypeLike] = np.float64,
+        image_only: Optional[bool] = False,
+        lazy_evaluation: Optional[bool] = False
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.pixdim = pixdim
+        self.src_pixdim = src_pixdim
+        self.diagonal = diagonal
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+        self.dtype = dtype
+        self.image_only = image_only

+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[GridSampleMode, str]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        dtype: DtypeLike = None
+    ):
+
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+        align_corners_ = align_corners or self.align_corners
+        dtype_ = dtype or self.dtype
+
+        img_t, transform, metadata = spacing(img, self.pixdim, self.src_pixdim, self.diagonal,
+                                             mode_, padding_mode_, align_corners_, dtype_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Resize(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        spatial_size: Union[Sequence[int], int],
+        size_mode: Optional[str] = "all",
+        mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
+        align_corners: Optional[bool] = False,
+        anti_aliasing: Optional[bool] = False,
+        anti_aliasing_sigma: Optional[Union[Sequence[float], float, None]] = None,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+        lazy_evaluation: Optional[bool] = False
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.spatial_size = spatial_size
+        self.size_mode = size_mode
+        self.mode = mode
+        self.align_corners = align_corners
+        self.anti_aliasing = anti_aliasing
+        self.anti_aliasing_sigma = anti_aliasing_sigma
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[InterpolateMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        anti_aliasing: Optional[bool] = None,
+        anti_aliasing_sigma: Union[Sequence[float], float, None] = None
+    ) -> NdarrayOrTensor:
+        mode_ = mode or self.mode
+        align_corners_ = align_corners or self.align_corners
+        anti_aliasing_ = anti_aliasing or self.anti_aliasing
+        anti_aliasing_sigma_ = anti_aliasing_sigma or self.anti_aliasing_sigma
+
+        img_t, transform, metadata = resize(img, self.spatial_size, self.size_mode, mode_,
+                                            align_corners_, anti_aliasing_, anti_aliasing_sigma_,
+                                            self.dtype)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+
+class Rotate(LazyTransform, InvertibleTransform):

     def __init__(
         self,
@@ -23,8 +129,10 @@ def __init__(
         mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
         padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
         align_corners: bool = False,
-        dtype: Union[DtypeLike, torch.dtype] = 
np.float32 + dtype: Union[DtypeLike, torch.dtype] = np.float32, + lazy_evaluation: Optional[bool] = False ): + LazyTransform.__init__(self, lazy_evaluation) self.angle = angle self.keep_size = keep_size self.mode = mode @@ -35,32 +143,32 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - mapping_stack: Optional[MappingStack] = None, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, ) -> NdarrayOrTensor: - mode = self.mode if mode is None else mode - padding_mode = self.padding_mode if padding_mode is None else padding_mode - align_corners = self.align_corners if align_corners is None else align_corners + angle = self.angle + mode = mode or self.mode + padding_mode = padding_mode or self.padding_mode + align_corners = align_corners or self.align_corners keep_size = self.keep_size dtype = self.dtype - matrix_factory = MatrixFactory(len(img.shape)-1, - get_backend_from_data(img), - get_device_from_data(img)) - if mapping_stack is None: - mapping_stack = MappingStack(matrix_factory) - mapping_stack.push(matrix_factory.rotate_euler(self.angle, - **{ - "padding_mode": padding_mode, - "mode": mode, - "align_corners": align_corners, - "keep_size": keep_size, - "dtype": dtype - })) - - -class Zoom(Transform): + + img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, + align_corners, dtype) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse(self, data): + raise NotImplementedError() + + +class Zoom(LazyTransform, InvertibleTransform): """ Zoom into / out of the image applying the `zoom` factor as a scalar, or if `zoom` is a tuple of values, apply each zoom factor to the appropriate dimension. 
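
# Editor's note: under lazy evaluation, Zoom contributes nothing but a
# homogeneous scale matrix to the pending stack; pixels are only touched when
# apply() runs. A sketch of the matrix it pushes, assuming the NUMPY backend
# (so no device is involved):
#
#     from monai.utils.enums import TransformBackends
#     from monai.utils.mapping_stack import MatrixFactory
#
#     factory = MatrixFactory(2, TransformBackends.NUMPY)
#     meta = factory.scale((2.0, 3.0))
#     # meta.matrix.matrix is the homogeneous scale:
#     # [[2., 0., 0.],
#     #  [0., 3., 0.],
#     #  [0., 0., 1.]]
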
@@ -87,109 +195,128 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - mapping_stack: Optional[MappingStack] = None, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None ) -> NdarrayOrTensor: - mode = self.mode if mode is None else mode - padding_mode = self.padding_mode if padding_mode is None else padding_mode - align_corners = self.align_corners if align_corners is None else align_corners + mode = self.mode or mode + padding_mode = self.padding_mode or padding_mode + align_corners = self.align_corners or align_corners keep_size = self.keep_size dtype = self.dtype - matrix_factory = MatrixFactory(len(img.shape)-1, - get_backend_from_data(img), - get_device_from_data(img)) - if mapping_stack is None: - mapping_stack = MappingStack(matrix_factory) - mapping_stack.push(matrix_factory.scale(self.zoom, - **{ - "padding_mode": padding_mode, - "mode": mode, - "align_corners": align_corners, - "keep_size": keep_size, - "dtype": dtype - })) - img.add - - -class Resize(Transform): - def __init__( - self, - spatial_size: Union[Sequence[int], int], - size_mode: str = "all", - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - align_corners: Optional[bool] = None, - anti_aliasing: bool = False, - anti_aliasing_sigma: Union[Sequence[float], float, None] = None - ): - self.spatial_size = spatial_size - self.size_mode = size_mode - self.mode = mode, - self.align_corners = align_corners - self.anti_aliasing = anti_aliasing - self.anti_aliasing_sigma = anti_aliasing_sigma + img_t, transform, metadata = zoom(img, self.zoom, mode, padding_mode, align_corners, + keep_size, dtype) - def __call__( - self, - img: NdarrayOrTensor, - mapping_stack: Optional[MappingStack] = None, - mode: Optional[Union[InterpolateMode, str]] = None, - align_corners: Optional[bool] = None, - anti_aliasing: Optional[bool] = None, - anti_aliasing_sigma: Union[Sequence[float], float, None] = None - ) -> NdarrayOrTensor: - mode = self.mode if mode is None else mode - align_corners = self.align_corners if align_corners is None else align_corners - keep_size = self.keep_size - dtype = self.dtype - matrix_factory = MatrixFactory(len(img.shape)-1, - get_backend_from_data(img), - get_device_from_data(img)) - if mapping_stack is None: - mapping_stack = MappingStack(matrix_factory) - mapping_stack.push(matrix_factory.scale(self.zoom, - **{ - "mode": mode, - "align_corners": align_corners, - "keep_size": keep_size, - "dtype": dtype - })) - - -class Spacing(Transform): + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse(self, data): + raise NotImplementedError() + + +class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): def __init__( self, - pixdim: Union[Sequence[float], float, np.ndarray], - diagonal: bool = False, + range_x: Optional[Union[Tuple[float, float], float]] = 0.0, + range_y: Optional[Union[Tuple[float, float], float]] = 0.0, + range_z: Optional[Union[Tuple[float, float], float]] = 0.0, + prob: Optional[float] = 0.1, + keep_size: bool = True, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, - image_only: bool = False + dtype: Union[DtypeLike, torch.dtype] = np.float32 ): - 
self.pixdim = pixdim - self.diagonal = diagonal + RandomizableTransform.__init__(self, prob) + self.range_x = ensure_tuple(range_x) + if len(self.range_x) == 1: + self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) + self.range_y = ensure_tuple(range_y) + if len(self.range_y) == 1: + self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) + self.range_z = ensure_tuple(range_z) + if len(self.range_z) == 1: + self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) + + self.keep_size = keep_size self.mode = mode self.padding_mode = padding_mode self.align_corners = align_corners self.dtype = dtype - self.image_only = image_only + + self.x = 0.0 + self.y = 0.0 + self.z = 0.0 + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + if not self._do_transform: + return None + self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) + self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) + self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) def __call__( self, img: NdarrayOrTensor, - mapping_stack: Optional[MappingStack] = None, - affine: Optional[NdarrayOrTensor] = None, - mode: Optional[Union[GridSampleMode, str]] = None, - padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None - ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: - pass + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + randomize: Optional[bool] = True, + get_matrix: Optional[bool] = False + + ) -> NdarrayOrTensor: + + if randomize: + self.randomize() + + img_dims = len(img.shape) - 1 + if self._do_transform: + angle = self.x if img_dims == 2 else (self.x, self.y, self.z) + else: + angle = 0 if img_dims == 2 else (0, 0, 0) + + mode = self.mode or mode + padding_mode = self.padding_mode or padding_mode + align_corners = self.align_corners or align_corners + keep_size = self.keep_size + dtype = self.dtype + + img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, + align_corners, dtype) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse( + self, + data: NdarrayOrTensor, + ): + raise NotImplementedError() +# Snippet of code for pushing transform to metadata - pulled from Rotate + # img = self._post_process(img, img.spatial_shape, sp_size, *args) + # img.spatial_shape = sp_size # type: ignore + # self.update_meta(img, orig_size, sp_size) + # self.push_transform( + # img, + # orig_size=orig_size, + # extra_info={ + # "mode": mode, + # "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + # "new_dim": len(orig_size) - ndim, # additional dims appended + # }, + # ) \ No newline at end of file diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py new file mode 100644 index 0000000000..8ac6d3beba --- /dev/null +++ b/monai/transforms/atmostonce/compose.py @@ -0,0 +1,193 @@ +import warnings +from typing import Any, Callable, Optional, Sequence, Union + +import numpy as np + +import torch + +from monai.transforms.atmostonce.lazy_transform import LazyTransform, 
compile_transforms, flatten_sequences +from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED + +from monai.transforms import Randomizable, InvertibleTransform, OneOf, apply_transform + + +# TODO: this is intended to replace Compose once development is done + + +class Compose(Randomizable, InvertibleTransform): + """ + ``Compose`` provides the ability to chain a series of callables together in + a sequential manner. Each transform in the sequence must take a single + argument and return a single value. + + ``Compose`` can be used in two ways: + + #. With a series of transforms that accept and return a single + ndarray / tensor / tensor-like parameter. + #. With a series of transforms that accept and return a dictionary that + contains one or more parameters. Such transforms must have pass-through + semantics that unused values in the dictionary must be copied to the return + dictionary. It is required that the dictionary is copied between input + and output of each transform. + + If some transform takes a data item dictionary as input, and returns a + sequence of data items in the transform chain, all following transforms + will be applied to each item of this list if `map_items` is `True` (the + default). If `map_items` is `False`, the returned sequence is passed whole + to the next callable in the chain. + + For example: + + A `Compose([transformA, transformB, transformC], + map_items=True)(data_dict)` could achieve the following patch-based + transformation on the `data_dict` input: + + #. transformA normalizes the intensity of 'img' field in the `data_dict`. + #. transformB crops out image patches from the 'img' and 'seg' of + `data_dict`, and return a list of three patch samples:: + + {'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)} + applying transformB + ----------> + [{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)}, + {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)}, + {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},] + + #. transformC then randomly rotates or flips 'img' and 'seg' of + each dictionary item in the list returned by transformB. + + The composed transforms will be set the same global random seed if user called + `set_determinism()`. + + When using the pass-through dictionary operation, you can make use of + :class:`monai.transforms.adaptors.adaptor` to wrap transforms that don't conform + to the requirements. This approach allows you to use transforms from + otherwise incompatible libraries with minimal additional work. + + Note: + + In many cases, Compose is not the best way to create pre-processing + pipelines. Pre-processing is often not a strictly sequential series of + operations, and much of the complexity arises when a not-sequential + set of functions must be called as if it were a sequence. + + Example: images and labels + Images typically require some kind of normalization that labels do not. + Both are then typically augmented through the use of random rotations, + flips, and deformations. + Compose can be used with a series of transforms that take a dictionary + that contains 'image' and 'label' entries. This might require wrapping + `torchvision` transforms before passing them to compose. + Alternatively, one can create a class with a `__call__` function that + calls your pre-processing functions taking into account that not all of + them are called on the labels. + + Args: + transforms: sequence of callables. 
+ map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. + defaults to `True`. + unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. + defaults to `False`. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other metadata, log the values directly. default to `False`. + lazy_resample: whether to compute consecutive spatial transforms resampling lazily. Default to False. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode when ``lazy_resample=True``. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values when ``lazy_resample=True``. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + + """ + + def __init__( + self, + transforms: Optional[Union[Sequence[Callable], Callable]] = None, + map_items: bool = True, + unpack_items: bool = False, + log_stats: bool = False, + mode=GridSampleMode.BILINEAR, + padding_mode=GridSamplePadMode.BORDER, + lazy_evaluation: bool = False + ) -> None: + if transforms is None: + transforms = [] + self.transforms = ensure_tuple(transforms) + + if lazy_evaluation is True: + self.dst_transforms = compile_transforms(self.transforms) + else: + self.dst_transforms = flatten_sequences(self.transforms) + + self.map_items = map_items + self.unpack_items = unpack_items + self.log_stats = log_stats + self.mode = mode + self.padding_mode = padding_mode + self.lazy_evaluation = lazy_evaluation + self.set_random_state(seed=get_seed()) + + if self.lazy_evaluation: + for t in self.dst_transforms: + if isinstance(t, LazyTransform): + t.lazy_evaluation = True + + def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": + super().set_random_state(seed=seed, state=state) + for _transform in self.transforms: + if not isinstance(_transform, Randomizable): + continue + _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32")) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + for _transform in self.transforms: + if not isinstance(_transform, Randomizable): + continue + try: + _transform.randomize(data) + except TypeError as type_error: + tfm_name: str = type(_transform).__name__ + warnings.warn( + f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning + ) + + # TODO: this is a more general function that could be implemented elsewhere + def flatten(self): + """Return a Composition with a simple list of transforms, as opposed to any nested Compositions. + + e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()` + will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`. 
+ + """ + new_transforms = [] + for t in self.transforms: + if isinstance(t, Compose) and not isinstance(t, OneOf): + new_transforms += t.flatten().transforms + else: + new_transforms.append(t) + + return Compose(new_transforms) + + def __len__(self): + """Return number of transformations.""" + return len(self.flatten().transforms) + + def __call__(self, input_): + for _transform in self.dst_transforms: + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats) + return input_ + + def inverse(self, data): + invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] + if not invertible_transforms: + warnings.warn("inverse has been called but no invertible transforms have been supplied") + + # loop backwards over transforms + for t in reversed(invertible_transforms): + data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats) + return data \ No newline at end of file diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py index f89f8756d7..dbdf331298 100644 --- a/monai/transforms/atmostonce/dictionary.py +++ b/monai/transforms/atmostonce/dictionary.py @@ -1,14 +1,19 @@ -from typing import Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Optional, Sequence, Union import numpy as np import torch -from monai.config import KeysCollection +from monai.transforms.atmostonce.functional import rotate +from monai.utils import ensure_tuple, ensure_tuple_rep + +from monai.config import KeysCollection, DtypeLike, SequenceStr +from monai.transforms.atmostonce.apply import apply +from monai.transforms.atmostonce.lazy_transform import LazyTransform from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform -from monai.utils.enums import TransformBackends -from monai.utils.mapping_stack import MappingStack, MatrixFactory +from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode +from monai.utils.mapping_stack import MatrixFactory, MetaMatrix from monai.utils.type_conversion import expand_scalar_to_tuple @@ -56,7 +61,7 @@ def __call__(self, dims = len(data.shape)-1 device = get_device_from_data(data) backend = get_backend_from_data(data) - v = mappings.get(k, MappingStack(MatrixFactory(dims, backend, device))) + v = None # mappings.get(k, MappingStack(MatrixFactory(dims, backend, device))) v.push(self.get_matrix(dims, backend, device, *args, **kwargs)) mappings[k] = v rd[k] = data @@ -70,33 +75,55 @@ def get_matrix(self, dims, backend, device, *args, **kwargs): raise NotImplementedError(msg) -class RotateEulerd(MapTransform, InvertibleTransform): +class Rotated(LazyTransform, MapTransform, InvertibleTransform): def __init__(self, keys: KeysCollection, - euler_radians: Union[Sequence[float], float]): + angle: Union[Sequence[float], float], + keep_size: bool = True, + mode: Optional[SequenceStr] = GridSampleMode.BILINEAR, + padding_mode: Optional[SequenceStr] = GridSamplePadMode.BORDER, + align_corners: Optional[Union[Sequence[bool], bool]] = False, + dtype: Optional[Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype]] = np.float32, + allow_missing_keys: Optional[bool] = False, + lazy_evaluation: Optional[bool] = False + ): super().__init__(self) self.keys = keys - self.euler_radians = expand_scalar_to_tuple(euler_radians, len(keys)) + self.angle = angle + self.keep_size = keep_size + self.modes = ensure_tuple_rep(mode, len(keys)) + 
self.padding_modes = ensure_tuple_rep(padding_mode, len(keys)) + self.align_corners = align_corners + self.dtypes = ensure_tuple_rep(dtype, len(keys)) + self.allow_missing_keys = allow_missing_keys def __call__(self, d: Mapping): - mappings = d.get("mappings", dict()) - rd = dict() - for k in self.keys: - data = d[k] - dims = len(data.shape)-1 - device = get_device_from_data(data) - backend = get_backend_from_data(data) - matrix_factory = MatrixFactory(dims, backend, device) - v = mappings.get(k, MappingStack(matrix_factory)) - v.push(matrix_factory.rotate_euler(self.euler_radians)) - mappings[k] = v - rd[k] = data + rd = dict(d) + if self.allow_missing_keys is True: + keys_present = {k for k in self.keys if k in d} + else: + keys_present = self.keys - rd["mappings"] = mappings + for ik, k in enumerate(keys_present): + img = d[k] + + img_t, transform, metadata = rotate(img, self.angle, self.keep_size, + self.modes[ik], self.padding_modes[ik], + self.align_corners, self.dtypes[ik]) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + rd[k] = img_t return rd + def inverse(self, data: Any): + raise NotImplementedError() + class Translated(MapTransform, InvertibleTransform): @@ -116,13 +143,11 @@ def __call__(self, d: Mapping): device = get_device_from_data(data) backend = get_backend_from_data(data) matrix_factory = MatrixFactory(dims, backend, device) - v = mappings.get(k, MappingStack(matrix_factory)) + v = None # mappings.get(k, MappingStack(matrix_factory)) v.push(matrix_factory.translate(self.translate)) mappings[k] = v rd[k] = data - rd["mappings"] = mappings - return rd @@ -144,7 +169,7 @@ def __call__(self, d: Mapping): device = get_device_from_data(data) backend = get_backend_from_data(data) matrix_factory = MatrixFactory(dims, backend, device) - v = mappings.get(k, MappingStack(matrix_factory)) + v = None # mappings.get(k, MappingStack(matrix_factory)) v.push(matrix_factory.scale(self.scale)) mappings[k] = v rd[k] = data @@ -153,40 +178,3 @@ def __call__(self, d: Mapping): return rd - -# class RotateEulerd(MappingStackTransformd): -# -# def __init__(self, -# keys: KeysCollection, -# euler_radians: Union[Sequence[float], float]): -# super().__init__(keys) -# self.euler_radians = euler_radians -# -# def get_matrix(self, dims, backend, device, *args, **kwargs): -# euler_radians = args[0] if len(args) > 0 else None -# euler_radians = kwargs.get("euler_radians", None) if euler_radians is None -# euler_radians = self.euler_radians if euler_radians is None else self.euler_radians -# if euler_radians is None: -# raise ValueError("'euler_radians' must be set during initialisation or passed in" -# "during __call__") -# arg = euler_radians if self.euler_radians is None else euler_radians -# return MatrixFactory(dims, backend, device).rotate_euler(arg) -# -# -# class ScaleEulerd(MappingStackTransformd): -# -# def __init__(self, -# keys: KeysCollection, -# scale: Union[Sequence[float], float]): -# super().__init__(keys) -# self.scale = scale -# -# def get_matrix(self, dims, backend, device, *args, **kwargs): -# scale = args[0] if len(args) > 0 else None -# scale = kwargs.get("scale", None) if scale is None -# scale = self.scale if scale is None else self.euler_radians -# if scale is None: -# raise ValueError("'scale' must be set during initialisation or passed in" -# "during __call__") -# arg = scale if self.scale is None else scale -# return MatrixFactory(dims, 
backend, device).scale(arg) \ No newline at end of file diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py new file mode 100644 index 0000000000..b054b3a454 --- /dev/null +++ b/monai/transforms/atmostonce/functional.py @@ -0,0 +1,278 @@ +from typing import Optional, Sequence, Union + +import numpy as np + +import torch +from monai.transforms import create_rotate, create_translate, GaussianSmooth + +from monai.data import get_track_meta +from monai.transforms.atmostonce.apply import extents_from_shape, shape_from_extents +from monai.utils import convert_to_tensor, get_equivalent_dtype, ensure_tuple_rep, convert_to_dst_type, look_up_option, \ + GridSampleMode, GridSamplePadMode, fall_back_tuple, ensure_tuple_size, ensure_tuple, InterpolateMode, NumpyPadMode + +from monai.config import DtypeLike +from monai.utils.mapping_stack import MetaMatrix, MatrixFactory + + +def spacing( + img: torch.Tensor, + pixdim: Union[Sequence[float], float], + src_pixdim: Union[Sequence[float], float], + diagonal: Optional[bool] = False, + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = False, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None +): + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool, optional + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. 
+ + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_ndim = len(img.shape) - 1 + + pixdim_ = ensure_tuple_rep(pixdim, input_ndim) + src_pixdim_ = ensure_tuple_rep(src_pixdim, input_ndim) + + if diagonal is True: + raise ValueError("'diagonal' value of True is not currently supported") + + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + zoom_factors = [i / j for i, j in zip(src_pixdim_, pixdim_)] + + transform = MatrixFactory.from_tensor(img).scale(zoom_factors) + im_extents = extents_from_shape(img.shape) + im_extents = [transform.matrix.matrix @ e for e in im_extents] + spatial_shape_ = shape_from_extents(im_extents) + + metadata = { + "pixdim": pixdim_, + "src_pixdim": src_pixdim_, + "diagonal": diagonal, + "mode": mode_, + "padding_mode": padding_mode_, + "align_corners": align_corners, + "dtype": dtype_, + "im_extents": im_extents, + "spatial_shape": spatial_shape_ + } + return img_, transform, metadata + + +def resize( + img: torch.Tensor, + spatial_size: Union[Sequence[int], int], + size_mode: str = "all", + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + align_corners: Optional[bool] = False, + anti_aliasing: Optional[bool] = None, + anti_aliasing_sigma: Optional[Union[Sequence[float], float]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None +): + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool, optional + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. + + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_ndim = len(img.shape) - 1 + + if size_mode == "all": + output_ndim = len(ensure_tuple(spatial_size)) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." 
+ ) + spatial_size_ = fall_back_tuple(spatial_size, img.shape[1:]) + else: # for the "longest" mode + img_size = img.shape[1:] + if not isinstance(spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") + scale = spatial_size / max(img_size) + spatial_size_ = tuple(int(round(s * scale)) for s in img_size) + + mode_ = look_up_option(mode, GridSampleMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + zoom_factors = [i / j for i, j in zip(spatial_size, img.shape[1:])] + transform = MatrixFactory.from_tensor(img).scale(zoom_factors) + im_extents = extents_from_shape(img.shape) + im_extents = [transform.matrix.matrix @ e for e in im_extents] + spatial_shape_ = shape_from_extents(im_extents) + + metadata = { + "spatial_size": spatial_size_, + "size_mode": size_mode, + "mode": mode_, + "align_corners": align_corners, + "anti_aliasing": anti_aliasing, + "anti_aliasing_sigma": anti_aliasing_sigma, + "dtype": dtype_, + "im_extents": im_extents, + "spatial_shape": spatial_shape_ + } + return img_, transform, metadata + + +def rotate( + img: torch.Tensor, + angle: Union[Sequence[float], float], + keep_size: Optional[bool] = True, + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = False, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None +): + """ + Args: + img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. + angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. + keep_size: If it is True, the output shape is kept the same as the input. + If it is False, the output shape is adapted so that the + input array is contained completely in the output. Default is True. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``self.padding_mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + + Raises: + ValueError: When ``img`` spatially is not one of [2D, 3D]. 
+ + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + input_ndim = len(img_.shape) - 1 + if input_ndim not in (2, 3): + raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") + angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) + transform = create_rotate(input_ndim, angle_) + im_extents = extents_from_shape(img.shape) + if not keep_size: + im_extents = [transform @ e for e in im_extents] + spatial_shape = shape_from_extents(im_extents) + else: + spatial_shape = img_.shape + + metadata = { + "angle": angle_, + "keep_size": keep_size, + "mode": mode_, + "padding_mode": padding_mode_, + "align_corners": align_corners, + "dtype": dtype_, + "im_extents": im_extents, + "spatial_shape": spatial_shape + } + return img_, transform, metadata + + +def zoom( + img: torch.Tensor, + zoom: Union[Sequence[float], float], + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = False, + keep_size: Optional[bool] = True, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None +): + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. 
+ + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_ndim = len(img.shape) - 1 + + zoom_factors = ensure_tuple_rep(zoom, input_ndim) + + + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + + transform = MatrixFactory.from_tensor(img).scale(zoom_factors) + im_extents = extents_from_shape(img.shape) + if keep_size is False: + im_extents = [transform.matrix.matrix @ e for e in im_extents] + spatial_shape_ = shape_from_extents(im_extents) + else: + spatial_shape_ = img_.shape + + metadata = { + "zoom": zoom_factors, + "mode": mode_, + "padding_mode": padding_mode_, + "align_corners": align_corners, + "keep_size": keep_size, + "dtype": dtype_ + } + return img_, transform, metadata diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py new file mode 100644 index 0000000000..2731d58f47 --- /dev/null +++ b/monai/transforms/atmostonce/lazy_transform.py @@ -0,0 +1,75 @@ +import itertools as it + +from monai.config import NdarrayOrTensor + +from monai.data import MetaTensor +from monai.transforms import Transform +from monai.transforms.atmostonce.apply import Applyd +from monai.utils.mapping_stack import MetaMatrix + + +# TODO: move to mapping_stack.py +def push_transform( + data: MetaTensor, + meta_matrix: MetaMatrix +): + data.push_pending_transform(meta_matrix) + + +# TODO: move to mapping_stack.py +def update_metadata( + data: MetaTensor, + transform: NdarrayOrTensor, + extra_info +): + pass + + +# TODO: move to utils +def flatten_sequences(seq): + + def flatten_impl(s, accum): + if isinstance(s, (list, tuple)): + for inner_t in s: + accum = flatten_impl(inner_t, accum) + else: + accum.append(s) + return accum + + dest = [] + for s in seq: + dest = flatten_impl(s, dest) + + return dest + + +def transforms_compatible(current, next): + raise NotImplementedError() + + +def compile_transforms(transforms): + flat = flatten_sequences(transforms) + for i in range(len(flat)-1): + cur_t, next_t = flat[i], flat[i + 1] + if not transforms_compatible(cur_t, next_t): + flat.insert(i + 1, Applyd()) + return flat + + + + + +class LazyTransform: + + def __init__(self, lazy_evaluation): + self.lazy_evaluation = lazy_evaluation + + # TODO: determine whether to have an 'eval' defined here that implements laziness + # def __call__(self, *args, **kwargs): + # """Call this method after calculating your meta data""" + # if self.lazily_evaluate: + # # forward the transform to metatensor + # pass + # else: + # # apply the transform and reset the stack on metatensor + # pass \ No newline at end of file diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py index 0c214df4ef..39ffac01ff 100644 --- a/monai/utils/mapping_stack.py +++ b/monai/utils/mapping_stack.py @@ -18,6 +18,8 @@ from monai.utils.enums import TransformBackends from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear, _create_translate) +from monai.utils.misc import get_backend_from_data, get_device_from_data + class MatrixFactory: @@ -50,58 +52,130 @@ def __init__(self, self._backend = backend self._dims = dims + @staticmethod + def from_tensor(data): + return MatrixFactory(len(data.shape)-1, + get_backend_from_data(data), + get_device_from_data(data)) + def identity(self): - return self._eye(self._dims + 1) + matrix = self._eye(self._dims + 1) + return MetaMatrix(matrix, {}) def rotate_euler(self, radians: 
Union[Sequence[float], float], **extra_args): matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) - return {"matrix": matrix, "args": extra_args} + return MetaMatrix(matrix, extra_args) - def shear(self, coefs: Union[Sequence[float], float]): - return _create_shear(self._dims, coefs, self._eye) + def shear(self, coefs: Union[Sequence[float], float], **extra_args): + matrix = _create_shear(self._dims, coefs, self._eye) + return MetaMatrix(matrix, extra_args) def scale(self, factors: Union[Sequence[float], float], **extra_args): matrix = _create_scale(self._dims, factors, self._diag) - return {"matrix": matrix, "args": extra_args} + return MetaMatrix(matrix, extra_args) - def translate(self, offsets: Union[Sequence[float], float]): - return _create_translate(self._dims, offsets, self._eye) + def translate(self, offsets: Union[Sequence[float], float], **extra_args): + matrix = _create_translate(self._dims, offsets, self._eye) + return MetaMatrix(matrix, extra_args) -class Mapping: +# class Mapping: +# +# def __init__(self, matrix): +# self._matrix = matrix +# +# def apply(self, other): +# return Mapping(other @ self._matrix) + + +class Dimensions: + + def __init__(self, flips, permutes): + raise NotImplementedError() + + def __matmul__(self, other): + raise NotImplementedError() + + def __rmatmul__(self, other): + raise NotImplementedError() - def __init__(self, matrix): - self._matrix = matrix - def apply(self, other): - return Mapping(other @ self._matrix) +class Matrix: + def __init__(self, matrix): + self.matrix = matrix -class MappingStack: - """ - This class keeps track of a series of mappings and apply them / calculate their inverse (if - mappings are invertible). Mapping stacks are used to generate a mapping that gets applied during a `Resample` / - `Resampled` transform. + def __matmul__(self, other): + if isinstance(other, Matrix): + other_matrix = other.matrix + else: + other_matrix = other + return self.matrix @ other_matrix - A mapping is one of: - - a description of a change to a numpy array that only requires index manipulation instead of an actual resample. 
- a homogeneous matrix representing a geometric transform to be applied during a resample
+#   - a field representing a deformation to be applied during a resample
+#     """
+#
+#     def __init__(self, factory: MatrixFactory):
+#         self.factory = factory
+#         self.stack = []
+#         self.applied_stack = []
+#
+#     def push(self, mapping):
+#         self.stack.append(mapping)
+#
+#     def pop(self):
+#         raise NotImplementedError()
+#
+#     def transform(self):
+#         m = Mapping(self.factory.identity())
+#         for t in self.stack:
+#             m = m.apply(t)
+#         return m
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index 44a3f008bd..d0f3ab79aa 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -1,21 +1,194 @@
 import unittest

+import math
+
 import numpy as np

 import matplotlib.pyplot as plt

 import torch

+from monai.transforms.atmostonce import array as amoa
+from monai.transforms.atmostonce.lazy_transform import compile_transforms
+from monai.utils import TransformBackends
+
 from monai.transforms import Affined
-from monai.transforms.atmostonce.apply import Applyd
-from monai.transforms.atmostonce.dictionary import RotateEulerd
+from monai.transforms.atmostonce.functional import resize, rotate, spacing
+from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents
+from monai.transforms.atmostonce.dictionary import Rotated
 from monai.transforms.compose import Compose
 from monai.utils.enums import GridSampleMode, GridSamplePadMode
+from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
+
+
+def enumerate_results_of_op(results):
+    if isinstance(results, dict):
+        for k, v in results.items():
+            if isinstance(v, (np.ndarray, torch.Tensor)):
+                print(k, v.shape, v[tuple(slice(0, 8) for _ in v.shape)])
+            else:
+                print(k, v)
+    else:
+        for ir, r in enumerate(results):
+            if isinstance(r, (np.ndarray, 
torch.Tensor)): + print(ir, r.shape, r[tuple(slice(0, 8) for _ in r.shape)]) + else: + print(ir, r) + + +class TestLowLevel(unittest.TestCase): + + def test_extents_2(self): + actual = extents_from_shape([1, 24, 32]) + expected = [np.asarray(v) for v in ((0, 0, 1), (0, 32, 1), (24, 0, 1), (24, 32, 1))] + self.assertTrue(np.all([np.array_equal(a, e) for a, e in zip(actual, expected)])) + + def test_extents_3(self): + actual = extents_from_shape([1, 12, 16, 8]) + expected = [np.asarray(v) for v in ((0, 0, 0, 1), (0, 0, 8, 1), (0, 16, 0, 1), (0, 16, 8, 1), + (12, 0, 0, 1), (12, 0, 8, 1), (12, 16, 0, 1), (12, 16, 8, 1))] + self.assertTrue(np.all([np.array_equal(a, e) for a, e in zip(actual, expected)])) + + def test_shape_from_extents(self): + actual = shape_from_extents([np.asarray([-16, -20, 1]), + np.asarray([-16, 20, 1]), + np.asarray([16, -20, 1]), + np.asarray([16, 20, 1])]) + print(actual) + + + def test_compile_transforms(self): + values = ["a", "b", ["c", ["d"], "e"], "f", ["g", "h"], "i"] + result = compile_transforms(values) + print(result) + + +class TestMappingStack(unittest.TestCase): + + def test_rotation_pi_by_2(self): + + fac = MatrixFactory(2, TransformBackends.NUMPY) + mat = fac.rotate_euler(torch.pi / 2) + expected = np.asarray([[0, -1, 0], + [1, 0, 0], + [0, 0, 1]]) + self.assertTrue(np.allclose(mat.matrix.matrix, expected)) + + def test_rotation_pi_by_4(self): + + fac = MatrixFactory(2, TransformBackends.NUMPY) + mat = fac.rotate_euler(torch.pi / 4) + piby4 = math.cos(torch.pi / 4) + expected = np.asarray([[piby4, -piby4, 0], + [piby4, piby4, 0], + [0, 0, 1]]) + self.assertTrue(np.allclose(mat.matrix.matrix, expected)) + + def test_rotation_pi_by_8(self): + fac = MatrixFactory(2, TransformBackends.NUMPY) + mat = fac.rotate_euler(torch.pi / 8) + cospi = math.cos(torch.pi / 8) + sinpi = math.sin(torch.pi / 8) + expected = np.asarray([[cospi, -sinpi, 0], + [sinpi, cospi, 0], + [0, 0, 1]]) + self.assertTrue(np.allclose(mat.matrix.matrix, expected)) + + def scale_by_2(self): + fac = MatrixFactory(2, TransformBackends.NUMPY) + mat = fac.scale(2) + expected = np.asarray([[2, 0, 0], + [0, 2, 0], + [0, 0, 1]]) + self.assertTrue(np.allclose(mat.matrix.matrix, expected)) + + # TODO: turn into proper test + def test_mult_matrices(self): + + fac = MatrixFactory(2, TransformBackends.NUMPY) + matrix1 = fac.translate((-16, -16)) + matrix2 = fac.rotate_euler(torch.pi / 4) + + matrix12 = matrix1 @ matrix2 + matrix21 = matrix2 @ matrix1 + + print("matrix12\n", matrix12.matrix.matrix) + print("matrix21\n", matrix21.matrix.matrix) + + extents = extents_from_shape([1, 32, 32]) + + print("matrix1") + for e in extents: + print(" ", e, matrix1.matrix.matrix @ e) + print("matrix2") + for e in extents: + print(" ", e, matrix2.matrix.matrix @ e) + print("matrix12") + for e in extents: + print(" ", e, matrix12.matrix.matrix @ e) + print("matrix21") + for e in extents: + print(" ", e, matrix21.matrix.matrix @ e) + + +class TestFunctional(unittest.TestCase): + + # TODO: turn into proper test + def test_spacing(self): + results = spacing(np.zeros((1, 24, 32), dtype=np.float32), + (0.5, 0.6), + (1.0, 1.0), + False, + "bilinear", + "border", + False) + + + # TODO: turn into proper test + def test_resize(self): + results = resize(np.zeros((1, 24, 32), dtype=np.float32), + (40, 40), + "all", + "bilinear", + False) + enumerate_results_of_op(results) + + # TODO: turn into proper test + def test_rotate(self): + results = rotate(np.zeros((1, 64, 64), dtype=np.float32), + torch.pi / 4, + True, + "bilinear", + 
"border") + enumerate_results_of_op(results) + + results = rotate(np.zeros((1, 64, 64), dtype=np.float32), + torch.pi / 4, + False, + "bilinear", + "border") + enumerate_results_of_op(results) + + +class TestArrayTransforms(unittest.TestCase): + + def test_rand_rotate(self): + r = amoa.RandRotate((-torch.pi / 4, torch.pi / 4), + prob=0.0, + keep_size=True, + mode="bilinear", + padding_mode="border", + align_corners=False) + img = np.zeros((1, 32, 32), dtype=np.float32) + results = r(img) + enumerate_results_of_op(results) + enumerate_results_of_op(results.pending_transforms[-1].metadata) + class TestRotateEulerd(unittest.TestCase): def test_rotate_numpy(self): - r = RotateEulerd(('image', 'label'), [0.0, 1.0, 0.0]) + r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) d = { 'image': np.zeros((1, 64, 64, 32), dtype=np.float32), @@ -30,7 +203,7 @@ def test_rotate_numpy(self): print(k, v) def test_rotate_tensor(self): - r = RotateEulerd(('image', 'label'), [0.0, 1.0, 0.0]) + r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) d = { 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), @@ -46,7 +219,7 @@ def test_rotate_tensor(self): def test_rotate_apply(self): c = Compose([ - RotateEulerd(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), + Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), Applyd(('image', 'label'), modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) From 2dab7e57101bb3f474d406571c9dbd12935fc29a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Aug 2022 13:00:51 +0000 Subject: [PATCH 09/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/atmostonce/apply.py | 3 +-- monai/transforms/atmostonce/array.py | 2 -- monai/transforms/atmostonce/dictionary.py | 4 ++-- monai/utils/type_conversion.py | 2 +- tests/test_atmostonce.py | 1 - tests/test_mapping_stack.py | 2 +- 6 files changed, 5 insertions(+), 9 deletions(-) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index 4ccdf402b1..38ef0a2517 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -4,7 +4,7 @@ import torch -from monai.config import USE_COMPILED, DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import Resample @@ -14,7 +14,6 @@ from monai.utils.enums import TransformBackends from monai.utils.type_conversion import (convert_data_type, convert_to_dst_type, expand_scalar_to_tuple) -from monai.utils.mapping_stack import MappingStack # TODO: This should move to a common place to be shared with dictionary GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 50b26a5bcf..cff038bef5 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -136,5 +136,3 @@ def __call__( output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: pass - - diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py index f89f8756d7..ddc4dea133 100644 
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
index f89f8756d7..ddc4dea133 100644
--- a/monai/transforms/atmostonce/dictionary.py
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -1,4 +1,4 @@
-from typing import Mapping, Optional, Sequence, Union
+from typing import Mapping, Sequence, Union
 
 import numpy as np
 
@@ -189,4 +189,4 @@ def __call__(self, d: Mapping):
 #             raise ValueError("'scale' must be set during initialisation or passed in"
 #                              "during __call__")
 #         arg = scale if self.scale is None else scale
-#         return MatrixFactory(dims, backend, device).scale(arg)
\ No newline at end of file
+#         return MatrixFactory(dims, backend, device).scale(arg)
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index 0452df5bf4..24f2202e7e 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 import re
-from typing import Any, SupportsInt, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index 44a3f008bd..b65ec8cf3f 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -2,7 +2,6 @@
 
 import numpy as np
 
-import matplotlib.pyplot as plt
 
 import torch
 
diff --git a/tests/test_mapping_stack.py b/tests/test_mapping_stack.py
index cd0c72417f..6c6fb1ec11 100644
--- a/tests/test_mapping_stack.py
+++ b/tests/test_mapping_stack.py
@@ -13,7 +13,7 @@
 
 from monai.transforms.utils import TransformBackends
 
-from monai.utils.mapping_stack import MappingStack, Mapping, MatrixFactory
+from monai.utils.mapping_stack import MappingStack, MatrixFactory
 
 
 class MappingStackTest(unittest.TestCase):
 
From 999698e0d44ab5767576f495fcbd694199e16889 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 18 Aug 2022 10:23:42 +0000
Subject: [PATCH 10/30] [pre-commit.ci] auto fixes from pre-commit.com hooks
 
 for more information, see https://pre-commit.ci
---
 monai/transforms/atmostonce/apply.py          | 13 ++-----------
 monai/transforms/atmostonce/array.py          | 11 +++++------
 monai/transforms/atmostonce/compose.py        |  3 +--
 monai/transforms/atmostonce/dictionary.py     |  2 +-
 monai/transforms/atmostonce/functional.py     |  7 +++----
 monai/transforms/atmostonce/lazy_transform.py |  4 +---
 monai/utils/misc.py                           |  2 +-
 monai/utils/type_conversion.py                |  2 +-
 tests/test_atmostonce.py                      |  3 +--
 tests/test_mapping_stack.py                   |  2 +-
 10 files changed, 17 insertions(+), 32 deletions(-)
 
diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py
index d39557dac5..6c9fedc91a 100644
--- a/monai/transforms/atmostonce/apply.py
+++ b/monai/transforms/atmostonce/apply.py
@@ -1,23 +1,16 @@
-from typing import Dict, Hashable, Mapping, Optional, Sequence, Union
+from typing import Sequence, Union
 
 import itertools as it
 
 import numpy as np
 
-import torch
 
-from monai.config import USE_COMPILED, DtypeLike, KeysCollection
-from monai.config.type_definitions import NdarrayOrTensor
+from monai.config import DtypeLike
 from monai.data import MetaTensor
 from monai.transforms.inverse import InvertibleTransform
-from monai.transforms.spatial.array import Resample
 from monai.transforms.transform import MapTransform
-from monai.transforms.utils import create_grid
 from monai.utils import GridSampleMode, GridSamplePadMode
-from monai.utils.enums import TransformBackends
 from monai.utils.misc import get_backend_from_data, get_device_from_data
-from monai.utils.type_conversion import (convert_data_type, convert_to_dst_type,
-                                         expand_scalar_to_tuple)
 from monai.utils.mapping_stack import MatrixFactory
 
 # TODO: This should move to a common place to be shared with dictionary
@@ -79,14 +72,12 @@ class Apply(InvertibleTransform):
 
     def __init__(self):
         super().__init__()
-        pass
 
 
 class Applyd(MapTransform, InvertibleTransform):
 
     def __init__(self):
         super().__init__()
-        pass
 
 
 # class Applyd(MapTransform, InvertibleTransform):
 #
diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
index 420a344a6f..f3b112231c 100644
--- a/monai/transforms/atmostonce/array.py
+++ b/monai/transforms/atmostonce/array.py
@@ -5,18 +5,17 @@
 
 import torch
 
 from monai.config import DtypeLike, NdarrayOrTensor
-from monai.data import get_track_meta
-from monai.transforms import Transform, InvertibleTransform, RandomizableTransform
+from monai.transforms import InvertibleTransform, RandomizableTransform
 from monai.transforms.atmostonce.apply import apply
-from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing
-from monai.transforms.atmostonce.lazy_transform import LazyTransform, push_transform
+from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing
+from monai.transforms.atmostonce.lazy_transform import LazyTransform
 
 from monai.utils import (GridSampleMode, GridSamplePadMode,
-                         InterpolateMode, NumpyPadMode, PytorchPadMode, convert_to_tensor)
-from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
-from monai.utils.misc import get_backend_from_data, get_device_from_data, ensure_tuple
+                         InterpolateMode, NumpyPadMode, PytorchPadMode)
+from monai.utils.mapping_stack import MetaMatrix
+from monai.utils.misc import ensure_tuple
 
 
 # TODO: these transforms are intended to replace array transforms once development is done
diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py
index 8ac6d3beba..6d7464f547 100644
--- a/monai/transforms/atmostonce/compose.py
+++ b/monai/transforms/atmostonce/compose.py
@@ -3,7 +3,6 @@
 
 import numpy as np
 
-import torch
 
 from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_transforms, flatten_sequences
 from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED
@@ -190,4 +189,4 @@ def inverse(self, data):
         # loop backwards over transforms
         for t in reversed(invertible_transforms):
             data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats)
-        return data
\ No newline at end of file
+        return data
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
index 4a77a2f867..c57023a82b 100644
--- a/monai/transforms/atmostonce/dictionary.py
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -5,7 +5,7 @@
 import torch
 
 from monai.transforms.atmostonce.functional import rotate
-from monai.utils import ensure_tuple, ensure_tuple_rep
+from monai.utils import ensure_tuple_rep
 
 from monai.config import KeysCollection, DtypeLike, SequenceStr
 from monai.transforms.atmostonce.apply import apply
diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py
index a5c975a8ad..129e9edefe 100644
--- a/monai/transforms/atmostonce/functional.py
+++ b/monai/transforms/atmostonce/functional.py
@@ -1,17 +1,16 @@
 from typing import Optional, Sequence, Union
 
-import numpy as np
 
 import torch
 
-from monai.transforms import create_rotate, create_translate, GaussianSmooth
+from monai.transforms import create_rotate
 from monai.data import get_track_meta
 
 from monai.transforms.atmostonce.apply import extents_from_shape, shape_from_extents
-from monai.utils import convert_to_tensor, get_equivalent_dtype, ensure_tuple_rep, convert_to_dst_type, look_up_option, \
+from monai.utils import convert_to_tensor, get_equivalent_dtype, ensure_tuple_rep, look_up_option, \
     GridSampleMode, GridSamplePadMode, fall_back_tuple, ensure_tuple_size, ensure_tuple, InterpolateMode, NumpyPadMode
 from monai.config import DtypeLike
-from monai.utils.mapping_stack import MetaMatrix, MatrixFactory
+from monai.utils.mapping_stack import MatrixFactory
 
 
 def spacing(
diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py
index 2731d58f47..fa60a5f0d4 100644
--- a/monai/transforms/atmostonce/lazy_transform.py
+++ b/monai/transforms/atmostonce/lazy_transform.py
@@ -1,9 +1,7 @@
-import itertools as it
 
 from monai.config import NdarrayOrTensor
 from monai.data import MetaTensor
 
-from monai.transforms import Transform
 from monai.transforms.atmostonce.apply import Applyd
 from monai.utils.mapping_stack import MetaMatrix
 
@@ -72,4 +70,4 @@ def __init__(self, lazy_evaluation):
 #             pass
 #         else:
 #             # apply the transform and reset the stack on metatensor
-#             pass
\ No newline at end of file
+#             pass
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index c4ad86cc68..b2d5dd6e32 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -493,4 +493,4 @@ def get_backend_from_data(data):
         return TransformBackends.TORCH
     else:
         msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
-        raise ValueError(msg.format(type(data)))
\ No newline at end of file
+        raise ValueError(msg.format(type(data)))
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index 9a6b488c61..97eca5a7a6 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 import re
-from typing import Any, Optional, Sequence, SupportsInt, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index d0f3ab79aa..f899562f72 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 
-import matplotlib.pyplot as plt
 
 import torch
 
@@ -18,7 +17,7 @@
 from monai.transforms.atmostonce.dictionary import Rotated
 from monai.transforms.compose import Compose
 from monai.utils.enums import GridSampleMode, GridSamplePadMode
-from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
+from monai.utils.mapping_stack import MatrixFactory
 
 
 def enumerate_results_of_op(results):
diff --git a/tests/test_mapping_stack.py b/tests/test_mapping_stack.py
index cd0c72417f..6c6fb1ec11 100644
--- a/tests/test_mapping_stack.py
+++ b/tests/test_mapping_stack.py
@@ -13,7 +13,7 @@
 
 from monai.transforms.utils import TransformBackends
 
-from monai.utils.mapping_stack import MappingStack, Mapping, MatrixFactory
+from monai.utils.mapping_stack import MappingStack, MatrixFactory
 
 
 class MappingStackTest(unittest.TestCase):
 
From 7af9558605988c0df4b589e9eb7f76533b070060 Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Fri, 19 Aug 2022 08:52:35 +0100
Subject: [PATCH 11/30] Simplified array and dictionary transforms; working on
 generic translate
---
 monai/transforms/atmostonce/array.py      |  17 +-
 monai/transforms/atmostonce/dictionary.py | 201 ++++++++++++++++------
 monai/transforms/atmostonce/functional.py |  34 +++-
 tests/test_atmostonce.py                  |  31 +++-
 4 files changed, 209 insertions(+), 74 deletions(-)
 
diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
index f3b112231c..8b568df404 100644
--- a/monai/transforms/atmostonce/array.py
+++ b/monai/transforms/atmostonce/array.py
@@ -32,17 +32,16 @@ def __init__(
         padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
         align_corners: Optional[bool] = False,
         dtype: Optional[DtypeLike] = np.float64,
-        image_only: Optional[bool] = False,
         lazy_evaluation: Optional[bool] = False
     ):
         LazyTransform.__init__(self, lazy_evaluation)
         self.pixdim = pixdim
+        self.src_pixdim = src_pixdim
         self.diagonal = diagonal
         self.mode = mode
         self.padding_mode = padding_mode
         self.align_corners = align_corners
         self.dtype = dtype
-        self.image_only = image_only
 
     def __call__(
         self,
@@ -306,16 +305,4 @@ def inverse(
         raise NotImplementedError()
 
 
-# Snippet of code for pushing transform to metadata - pulled from Rotate
-    # img = self._post_process(img, img.spatial_shape, sp_size, *args)
-    # img.spatial_shape = sp_size  # type: ignore
-    # self.update_meta(img, orig_size, sp_size)
-    # self.push_transform(
-    #     img,
-    #     orig_size=orig_size,
-    #     extra_info={
-    #         "mode": mode,
-    #         "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
-    #         "new_dim": len(orig_size) - ndim,  # additional dims appended
-    #     },
-    # )
+
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
index c57023a82b..3ca1668008 100644
--- a/monai/transforms/atmostonce/dictionary.py
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from monai.transforms.atmostonce.functional import rotate
+from monai.transforms.atmostonce.array import Rotate, Resize, Spacing, Zoom
 from monai.utils import ensure_tuple_rep
 
 from monai.config import KeysCollection, DtypeLike, SequenceStr
@@ -12,7 +12,8 @@
 from monai.transforms.atmostonce.lazy_transform import LazyTransform
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import MapTransform
-from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode
+from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, \
+    PytorchPadMode
 from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
 from monai.utils.type_conversion import expand_scalar_to_tuple
 
@@ -43,36 +44,80 @@ def expand_potential_tuple(keys, value):
     return value
 
 
-class MappingStackTransformd(MapTransform, InvertibleTransform):
+# class MappingStackTransformd(MapTransform, InvertibleTransform):
+#
+#     def __init__(self,
+#                  keys: KeysCollection):
+#         super().__init__(self)
+#         self.keys = keys
+#
+#     def __call__(self,
+#                  d: Mapping,
+#                  *args,
+#                  **kwargs):
+#         mappings = d.get("mappings", dict())
+#         rd = dict()
+#         for k in self.keys:
+#             data = d[k]
+#             dims = len(data.shape)-1
+#             device = get_device_from_data(data)
+#             backend = get_backend_from_data(data)
+#             v = None  # mappings.get(k, MappingStack(MatrixFactory(dims, backend, device)))
+#             v.push(self.get_matrix(dims, backend, device, *args, **kwargs))
+#             mappings[k] = v
+#             rd[k] = data
+#
+#         rd["mappings"] = mappings
+#
+#         return rd
+#
+#     def get_matrix(self, dims, backend, device, *args, **kwargs):
+#         msg = "get_matrix must be implemented in a subclass of MappingStackTransform"
+#         raise NotImplementedError(msg)
+
+
+class Spaced(LazyTransform, MapTransform, InvertibleTransform):
 
     def __init__(self,
-                 keys: KeysCollection):
-        super().__init__(self)
+                 keys: KeysCollection,
+                 pixdim: Union[Sequence[float], float, np.ndarray],
+                 src_pixdim: Optional[Union[Sequence[float], float, np.ndarray]],
+                 diagonal: Optional[bool] = False,
+                 mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+                 padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+                 align_corners: Optional[bool] = False,
+                 dtype: Optional[DtypeLike] = np.float64,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        InvertibleTransform.__init__(self)
         self.keys = keys
+        self.pixdim = pixdim
+        self.src_pixdim = src_pixdim
+        self.diagonal = diagonal
+        self.modes = ensure_tuple_rep(mode, len(keys))
+        self.padding_modes = ensure_tuple_rep(padding_mode, len(keys))
+        self.align_corners = align_corners
+        self.dtypes = ensure_tuple_rep(dtype, len(keys))
+        self.allow_missing_keys = allow_missing_keys
 
-    def __call__(self,
-                 d: Mapping,
-                 *args,
-                 **kwargs):
-        mappings = d.get("mappings", dict())
-        rd = dict()
-        for k in self.keys:
-            data = d[k]
-            dims = len(data.shape)-1
-            device = get_device_from_data(data)
-            backend = get_backend_from_data(data)
-            v = None  # mappings.get(k, MappingStack(MatrixFactory(dims, backend, device)))
-            v.push(self.get_matrix(dims, backend, device, *args, **kwargs))
-            mappings[k] = v
-            rd[k] = data
+    def __call__(self, d: Mapping):
+        rd = dict(d)
+        if self.allow_missing_keys is True:
+            keys_present = {k for k in self.keys if k in d}
+        else:
+            keys_present = self.keys
 
-        rd["mappings"] = mappings
+        for ik, k in enumerate(keys_present):
+            tx = Spacing(self.pixdim, self.src_pixdim, self.diagonal,
+                         self.modes[ik], self.padding_modes[ik],
+                         self.align_corners, self.dtypes[ik])
 
-        return rd
+            rd[k] = tx(d[k])
 
-    def get_matrix(self, dims, backend, device, *args, **kwargs):
-        msg = "get_matrix must be implemented in a subclass of MappingStackTransform"
-        raise NotImplementedError(msg)
+        return rd
 
 
 class Rotated(LazyTransform, MapTransform, InvertibleTransform):
@@ -106,18 +151,51 @@ def __call__(self, d: Mapping):
             keys_present = self.keys
 
         for ik, k in enumerate(keys_present):
-            img = d[k]
+            tx = Rotate(self.angle, self.keep_size,
+                        self.modes[ik], self.padding_modes[ik],
+                        self.align_corners, self.dtypes[ik])
+            rd[k] = tx(d[k])
 
-            img_t, transform, metadata = rotate(img, self.angle, self.keep_size,
-                                                self.modes[ik], self.padding_modes[ik],
-                                                self.align_corners, self.dtypes[ik])
+        return rd
 
-            # TODO: candidate for refactoring into a LazyTransform method
-            img_t.push_pending_transform(MetaMatrix(transform, metadata))
-            if not self.lazy_evaluation:
-                img_t = apply(img_t)
+    def inverse(self, data: Any):
+        raise NotImplementedError()
 
-            rd[k] = img_t
+
+class Resized(LazyTransform, MapTransform, InvertibleTransform):
+
+    def __init__(self,
+                 keys: KeysCollection,
+                 spatial_size: Union[Sequence[int], int],
+                 size_mode: Optional[str] = "all",
+                 mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
+                 align_corners: Optional[bool] = False,
+                 anti_aliasing: Optional[bool] = False,
+                 anti_aliasing_sigma: Optional[Union[Sequence[float], float, None]] = None,
+                 dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.keys = keys
+        self.spatial_size = spatial_size
+        self.size_mode = size_mode
+        self.modes = ensure_tuple_rep(mode, len(keys))
+        self.align_corners = align_corners
+        self.anti_aliasing = anti_aliasing
+        self.anti_aliasing_sigma = anti_aliasing_sigma
+        self.dtype = dtype
+        self.allow_missing_keys = allow_missing_keys
+
+    def __call__(self, d: Mapping):
+        rd = dict(d)
+        if self.allow_missing_keys is True:
+            keys_present = {k for k in self.keys if k in d}
+        else:
+            keys_present = self.keys
 
+        for ik, k in enumerate(keys_present):
+            tx = Resize(self.spatial_size, self.size_mode, self.modes[ik], self.align_corners,
+                        self.anti_aliasing, self.anti_aliasing_sigma, self.dtype)
+            rd[k] = tx(d[k])
 
         return rd
@@ -125,40 +203,53 @@
     def inverse(self, data: Any):
         raise NotImplementedError()
 
 
-class Translated(MapTransform, InvertibleTransform):
+class Zoomd(LazyTransform, MapTransform, InvertibleTransform):
 
     def __init__(self,
                  keys: KeysCollection,
-                 translate: Union[Sequence[float], float]):
-        super().__init__(self)
+                 zoom: Union[Sequence[float], float],
+                 mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
+                 padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.EDGE,
+                 align_corners: Optional[bool] = None,
+                 keep_size: Optional[bool] = True,
+                 dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False,
+                 **kwargs
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
         self.keys = keys
-        self.translate = expand_scalar_to_tuple(translate, len(keys))
+        self.zoom = zoom
+        self.modes = ensure_tuple_rep(mode, len(keys))
+        self.padding_modes = ensure_tuple_rep(padding_mode, len(keys))
+        self.align_corners = align_corners
+        self.keep_size = keep_size
+        self.dtype = dtype
+        self.allow_missing_keys = allow_missing_keys
 
     def __call__(self, d: Mapping):
-        mappings = d.get("mappings", dict())
-        rd = dict()
-        for k in self.keys:
-            data = d[k]
-            dims = len(data.shape)-1
-            device = get_device_from_data(data)
-            backend = get_backend_from_data(data)
-            matrix_factory = MatrixFactory(dims, backend, device)
-            v = None  # mappings.get(k, MappingStack(matrix_factory))
-            v.push(matrix_factory.translate(self.translate))
-            mappings[k] = v
-            rd[k] = data
+        rd = dict(d)
+        if self.allow_missing_keys is True:
+            keys_present = {k for k in self.keys if k in d}
+        else:
+            keys_present = self.keys
+
+        for ik, k in enumerate(keys_present):
+            tx = Zoom(self.zoom, self.modes[ik], self.padding_modes[ik], self.align_corners,
+                      self.keep_size, self.dtype)
+            rd[k] = tx(d[k])
 
         return rd
 
+    def inverse(self, data: Any):
+        raise NotImplementedError()
 
-class Zoomd(MapTransform, InvertibleTransform):
+
+class Translated(MapTransform, InvertibleTransform):
 
     def __init__(self,
                  keys: KeysCollection,
-                 scale: Union[Sequence[float], float]):
+                 translate: Union[Sequence[float], float]):
         super().__init__(self)
         self.keys = keys
-        self.scale = expand_scalar_to_tuple(scale, len(keys))
+        self.translate = expand_scalar_to_tuple(translate, len(keys))
 
     def __call__(self, d: Mapping):
         mappings = d.get("mappings", dict())
@@ -170,10 +261,8 @@ def __call__(self, d: Mapping):
             backend = get_backend_from_data(data)
             matrix_factory = MatrixFactory(dims, backend, device)
             v = None  # mappings.get(k, MappingStack(matrix_factory))
-            v.push(matrix_factory.scale(self.scale))
+            v.push(matrix_factory.translate(self.translate))
             mappings[k] = v
             rd[k] = data
 
-        rd["mappings"] = mappings
-
         return rd
diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py
index 129e9edefe..692a372e94 100644
--- a/monai/transforms/atmostonce/functional.py
+++ b/monai/transforms/atmostonce/functional.py
@@ -284,7 +284,9 @@ def zoom(
         "padding_mode": padding_mode_,
         "align_corners": align_corners,
         "keep_size": keep_size,
-        "dtype": dtype_
+        "dtype": dtype_,
+        "im_extents": im_extents,
+        "spatial_shape": spatial_shape_
     }
     return img_, transform, metadata
 
@@ -293,3 +295,33 @@ def rotate90(
     img: torch.Tensor
 ):
     pass
+
+
+def croppad(
+    img: torch.Tensor,
+    slices: Union[Sequence[slice], slice],
+    pad_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_ndim = len(img.shape) - 1
+    if len(slices) != input_ndim:
+        raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' "
+                         f"spatial dimensions of {input_ndim}")
+
+    img_centers = [i // 2 for i in img.shape[1:]]
+    slice_centers = [s.stop - s.start for s in slices]
+    # img_centers = [0 for _ in img.shape[1:]]
+    # slice_centers = [s.end - s.start for s in slices]
+    deltas = [s - i for i, s in zip(img_centers, slice_centers)]
+    transform = MatrixFactory.from_tensor(img).translate(deltas)
+    im_extents = extents_from_shape([img.shape[0]] + [s.stop - s.start for s in slices])
+    im_extents = [transform.matrix.matrix @ e for e in im_extents]
+    spatial_shape_ = shape_from_extents(im_extents)
+
+    metadata = {
+        "slices": slices,
+        "dtype": img.dtype,
+        "im_extents": im_extents,
+        "spatial_shape": spatial_shape_
+    }
+    return img_, transform, metadata
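`croppad` reduces cropping and padding to a translation: move the image so the centre of the requested window coincides with the image centre, and let the resampler's padding mode fill anything outside the original extent. The centre arithmetic in this first version is off (it subtracts window sizes rather than window centres); PATCH 13 below corrects it, and the corrected calculation works out as follows for one of the cases tested later:

    img_shape = (1, 15, 15)                       # channel-first
    slices = (slice(4, 8), slice(3, 9))

    img_centers = [i / 2 for i in img_shape[1:]]                  # [7.5, 7.5]
    slice_centers = [(s.stop + s.start) / 2 for s in slices]      # [6.0, 6.0]
    deltas = [s - i for i, s in zip(img_centers, slice_centers)]  # [-1.5, -1.5]
    # a (-1.5, -1.5) translation recentres the 4 x 6 window before resampling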
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index f899562f72..9e0841b83f 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -11,8 +11,8 @@
 from monai.transforms.atmostonce.lazy_transform import compile_transforms
 from monai.utils import TransformBackends
 
-from monai.transforms import Affined
-from monai.transforms.atmostonce.functional import resize, rotate, spacing
+from monai.transforms import Affined, Affine
+from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing
 from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents
 from monai.transforms.atmostonce.dictionary import Rotated
 from monai.transforms.compose import Compose
@@ -20,6 +20,19 @@
 from monai.utils.mapping_stack import MatrixFactory
 
 
+def get_img(size):
+    img = torch.zeros(size, dtype=torch.float32)
+    if len(size) == 2:
+        for j in range(size[0]):
+            for i in range(size[1]):
+                img[j, i] = i + j * size[1]
+    else:
+        for k in range(size[-1]):
+            for j in range(size[-2]):
+                img[..., j, k] = j + k * size[0]
+    return np.expand_dims(img, 0)
+
+
 def enumerate_results_of_op(results):
@@ -168,6 +181,20 @@ def test_rotate(self):
                          "border")
         enumerate_results_of_op(results)
 
+    def test_croppad(self):
+        img = get_img((16, 16)).astype(int)
+        results = croppad(img,
+                          (slice(3, 8), slice(3, 9)))
+        enumerate_results_of_op(results)
+        m = results[1].matrix.matrix
+        print(m)
+        result_size = results[2]['spatial_shape']
+        a = Affine(affine=m,
+                   padding_mode=GridSamplePadMode.ZEROS,
+                   spatial_size=[1] + result_size)
+        img_, _ = a(img)
+        print(img_.numpy().astype(int))
+
 
 class TestArrayTransforms(unittest.TestCase):
 
From e47b22108844f5da7821133c83e8761e27d29231 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 19 Aug 2022 08:09:59 +0000
Subject: [PATCH 12/30] [pre-commit.ci] auto fixes from pre-commit.com hooks
 
 for more information, see https://pre-commit.ci
---
 monai/transforms/atmostonce/array.py          | 3 ---
 monai/transforms/atmostonce/dictionary.py     | 3 +--
 monai/transforms/atmostonce/lazy_transform.py | 1 -
 3 files changed, 1 insertion(+), 6 deletions(-)
 
diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
index 8b568df404..d066b1adbc 100644
--- a/monai/transforms/atmostonce/array.py
+++ b/monai/transforms/atmostonce/array.py
@@ -303,6 +303,3 @@ def inverse(
         data: NdarrayOrTensor,
     ):
         raise NotImplementedError()
-
-
-
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
index 3ca1668008..1160250ab5 100644
--- a/monai/transforms/atmostonce/dictionary.py
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -8,13 +8,12 @@
 from monai.utils import ensure_tuple_rep
 
 from monai.config import KeysCollection, DtypeLike, SequenceStr
-from monai.transforms.atmostonce.apply import apply
 from monai.transforms.atmostonce.lazy_transform import LazyTransform
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import MapTransform
 from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, \
     PytorchPadMode
-from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
+from monai.utils.mapping_stack import MatrixFactory
 from monai.utils.type_conversion import expand_scalar_to_tuple
 
 
diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py
index fa60a5f0d4..ddbd0cc1f3 100644
--- a/monai/transforms/atmostonce/lazy_transform.py
+++ b/monai/transforms/atmostonce/lazy_transform.py
@@ -1,4 +1,3 @@
-
 from monai.config import NdarrayOrTensor
 from monai.data import MetaTensor
 
From d53d95d93640a8c4a234869e2f1b92e3541398ed Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Fri, 19 Aug 2022 12:23:18 +0100
Subject: [PATCH 13/30] Base croppad debugged and tests added
---
 monai/transforms/atmostonce/functional.py |  7 +--
 tests/test_atmostonce.py                  | 67 +++++++++++++++++++++--
 2 files changed, 66 insertions(+), 8 deletions(-)
 
diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py
index 692a372e94..b0264d3609 100644
--- a/monai/transforms/atmostonce/functional.py
+++ b/monai/transforms/atmostonce/functional.py
@@ -308,10 +308,8 @@ def croppad(
         raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' "
                          f"spatial dimensions of {input_ndim}")
 
-    img_centers = [i // 2 for i in img.shape[1:]]
-    slice_centers = [s.stop - s.start for s in slices]
-    # img_centers = [0 for _ in img.shape[1:]]
-    # slice_centers = [s.end - s.start for s in slices]
+    img_centers = [i / 2 for i in img.shape[1:]]
+    slice_centers = [(s.stop + s.start) / 2 for s in slices]
     deltas = [s - i for i, s in zip(img_centers, slice_centers)]
     transform = MatrixFactory.from_tensor(img).translate(deltas)
     im_extents = extents_from_shape([img.shape[0]] + [s.stop - s.start for s in slices])
@@ -320,6 +318,7 @@ def croppad(
 
     metadata = {
         "slices": slices,
+        "pad_mode": pad_mode,
         "dtype": img.dtype,
         "im_extents": im_extents,
         "spatial_shape": spatial_shape_
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index 9e0841b83f..cc347e2427 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -181,19 +181,78 @@ def test_rotate(self):
                          "border")
         enumerate_results_of_op(results)
 
-    def test_croppad(self):
+    def test_croppad_identity(self):
         img = get_img((16, 16)).astype(int)
         results = croppad(img,
-                          (slice(3, 8), slice(3, 9)))
+                          (slice(0, 16), slice(0, 16)))
         enumerate_results_of_op(results)
         m = results[1].matrix.matrix
         print(m)
         result_size = results[2]['spatial_shape']
         a = Affine(affine=m,
                    padding_mode=GridSamplePadMode.ZEROS,
-                   spatial_size=[1] + result_size)
+                   spatial_size=result_size)
         img_, _ = a(img)
-        print(img_.numpy().astype(int))
+        print(img_)
+
+    def _croppad_impl(self, img_ext, slices, expected):
+        img = get_img(img_ext).astype(int)
+        results = croppad(img, slices)
+        enumerate_results_of_op(results)
+        m = results[1].matrix.matrix
+        print(m)
+        result_size = results[2]['spatial_shape']
+        a = Affine(affine=m,
+                   padding_mode=GridSamplePadMode.ZEROS,
+                   spatial_size=result_size)
+        img_, _ = a(img)
+        if expected is None:
+            print(img_.numpy())
+        else:
+            self.assertTrue(torch.allclose(img_, expected))
+
+    def test_croppad_img_odd_crop_odd(self):
+        expected = torch.as_tensor([[63., 64., 65., 66., 67., 68., 69.],
+                                    [78., 79., 80., 81., 82., 83., 84.],
+                                    [93., 94., 95., 96., 97., 98., 99.],
+                                    [108., 109., 110., 111., 112., 113., 114.],
+                                    [123., 124., 125., 126., 127., 128., 129.]])
+        self._croppad_impl((15, 15), (slice(4, 9), slice(3, 10)), expected)
+
+    def test_croppad_img_odd_crop_even(self):
+        expected = torch.as_tensor([[63., 64., 65., 66., 67., 68.],
+                                    [78., 79., 80., 81., 82., 83.],
+                                    [93., 94., 95., 96., 97., 98.],
+                                    [108., 109., 110., 111., 112., 113.]])
+        self._croppad_impl((15, 15), (slice(4, 8), slice(3, 9)), expected)
+
+    def test_croppad_img_even_crop_odd(self):
+        expected = torch.as_tensor([[67., 68., 69., 70., 71., 72., 73.],
+                                    [83., 84., 85., 86., 87., 88., 89.],
+                                    [99., 100., 101., 102., 103., 104., 105.],
+                                    [115., 116., 117., 118., 119., 120., 121.],
+                                    [131., 132., 133., 134., 135., 136., 137.]])
+        self._croppad_impl((16, 16), (slice(4, 9), slice(3, 10)), expected)
+
+    def test_croppad_img_even_crop_even(self):
+        expected = torch.as_tensor([[67., 68., 69., 70., 71., 72.],
+                                    [83., 84., 85., 86., 87., 88.],
+                                    [99., 100., 101., 102., 103., 104.],
+                                    [115., 116., 117., 118., 119., 120.]])
+        self._croppad_impl((16, 16), (slice(4, 8), slice(3, 9)), expected)
+
+    def test_croppad(self):
+        img = get_img((15, 15)).astype(int)
+        results = croppad(img, (slice(4, 8), slice(3, 9)))
+        enumerate_results_of_op(results)
+        m = results[1].matrix.matrix
+        print(m)
+        result_size = results[2]['spatial_shape']
+        a = Affine(affine=m,
+                   padding_mode=GridSamplePadMode.ZEROS,
+                   spatial_size=result_size)
+        img_, _ = a(img)
+        print(img_.numpy())
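`_croppad_impl` materializes the crop by handing `croppad`'s matrix and output shape to `Affine`, so the expected tensors above can be checked by hand against `get_img`, which stores `i + j * width` at pixel `(j, i)`. For example:

    # test_croppad_img_odd_crop_odd on a 15 x 15 image:
    rows, cols = slice(4, 9), slice(3, 10)
    first_value = cols.start + rows.start * 15
    print(first_value)  # 63, the top-left entry of the expected tensor

The four odd/even variants matter because a half-pixel offset appears between image centre and window centre whenever the two parities differ.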
From b6a0b05d43de4a067caf8c499ba35a0060381955 Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Fri, 19 Aug 2022 15:59:08 +0100
Subject: [PATCH 14/30] Partial implementation of non-metatensor based lazy
 resampling
---
 monai/data/meta_tensor.py                 |  2 +-
 monai/transforms/atmostonce/apply.py      | 35 +++++++--
 monai/transforms/atmostonce/array.py      | 69 ++++++++++++++---
 monai/transforms/atmostonce/dictionary.py |  1 +
 monai/transforms/atmostonce/functional.py | 92 +++++++++++++----------
 monai/utils/mapping_stack.py              | 12 ++-
 tests/test_atmostonce.py                  | 48 ++++++++----
 7 files changed, 183 insertions(+), 76 deletions(-)
 
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 8c6c7278d7..bde42ff279 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -159,7 +159,7 @@ def push_pending_transform(self, meta_matrix):
         self._pending_transforms.append(meta_matrix)
 
     def peek_pending_transform(self):
-        return copy.deepcopy(self._pending_transforms[0])
+        return copy.deepcopy(self._pending_transforms[-1])
 
     def pop_pending_transform(self):
         transform = self._pending_transforms[0]
diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py
index 6c9fedc91a..d981eadea4 100644
--- a/monai/transforms/atmostonce/apply.py
+++ b/monai/transforms/atmostonce/apply.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 
+import torch
 
 from monai.config import DtypeLike
 from monai.data import MetaTensor
@@ -28,11 +29,29 @@ def extents_from_shape(shape):
 
 
 # TODO: move to mapping_stack.py
-def shape_from_extents(extents):
-    aextents = np.asarray(extents)
-    mins = aextents.min(axis=0)
-    maxes = aextents.max(axis=0)
-    return np.ceil(maxes - mins)[:-1].astype(int)
+def shape_from_extents(
+        src_shape: Sequence,
+        extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
+):
+    if isinstance(extents, (list, tuple)):
+        if isinstance(extents[0], np.ndarray):
+            aextents = np.asarray(extents)
+            aextents = torch.from_numpy(aextents)
+        else:
+            aextents = torch.stack(extents)
+    else:
+        if isinstance(extents, np.ndarray):
+            aextents = torch.from_numpy(extents)
+        else:
+            aextents = extents
+
+    mins = aextents.min(axis=0)[0]
+    maxes = aextents.max(axis=0)[0]
+    values = torch.ceil(maxes - mins).type(torch.IntTensor)[:-1]
+    return torch.cat((torch.as_tensor([src_shape[0]]), values))
+
+    # return [src_shape[0]] + np.ceil(maxes - mins)[:-1].astype(int).tolist()
 
 
 def apply(data: MetaTensor):
@@ -41,29 +60,29 @@ def apply(data: MetaTensor):
     if len(pending) == 0:
         return data
 
-    dim_count = len(data) - 1
+    dim_count = len(data.shape) - 1
     matrix_factory = MatrixFactory(dim_count,
                                    get_backend_from_data(data),
                                    get_device_from_data(data))
 
     # set up the identity matrix and metadata
-    cumulative_matrix = matrix_factory.identity(dim_count)
+    cumulative_matrix = matrix_factory.identity()
     cumulative_extents = extents_from_shape(data.shape)
 
     # pre-translate origin to centre of image
-    translate_to_centre = matrix_factory.translate(dim_count)
+    translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]])
     cumulative_matrix = translate_to_centre @ cumulative_matrix
     cumulative_extents = [e @ translate_to_centre for e in cumulative_extents]
 
     for meta_matrix in pending:
         next_matrix = meta_matrix.matrix
         cumulative_matrix = next_matrix @ cumulative_matrix
         cumulative_extents = [e @ translate_to_centre for e in cumulative_extents]
 
     # TODO: figure out how to propagate extents properly
     # TODO: resampling strategy: augment resample or perform multiple stages if necessary
     # TODO: resampling strategy - antialiasing: can resample just be augmented?
-
+    r = Resample()
 
     data.clear_pending_transforms()
diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
index d066b1adbc..7a88f72789 100644
--- a/monai/transforms/atmostonce/array.py
+++ b/monai/transforms/atmostonce/array.py
@@ -9,7 +9,7 @@
 
 from monai.transforms import InvertibleTransform, RandomizableTransform
 
 from monai.transforms.atmostonce.apply import apply
-from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing
+from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad
 from monai.transforms.atmostonce.lazy_transform import LazyTransform
 
 from monai.utils import (GridSampleMode, GridSamplePadMode,
@@ -20,6 +20,9 @@
 
 # TODO: these transforms are intended to replace array transforms once development is done
 
+# spatial
+# =======
+
 # TODO: why doesn't Spacing have antialiasing options?
 class Spacing(LazyTransform, InvertibleTransform):
 
@@ -32,7 +35,8 @@ def __init__(
         padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
         align_corners: Optional[bool] = False,
         dtype: Optional[DtypeLike] = np.float64,
-        lazy_evaluation: Optional[bool] = False
+        lazy_evaluation: Optional[bool] = False,
+        shape_override: Optional[Sequence] = None
     ):
         LazyTransform.__init__(self, lazy_evaluation)
         self.pixdim = pixdim
@@ -49,7 +53,8 @@ def __call__(
         mode: Optional[Union[GridSampleMode, str]] = None,
         padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
         align_corners: Optional[bool] = None,
-        dtype: DtypeLike = None
+        dtype: DtypeLike = None,
+        shape_override: Optional[Sequence] = None
    ):
 
         mode_ = mode or self.mode
@@ -58,7 +63,8 @@ def __call__(
         dtype_ = dtype or self.dtype
 
         img_t, transform, metadata = spacing(img, self.pixdim, self.src_pixdim, self.diagonal,
-                                             mode_, padding_mode_, align_corners_, dtype_)
+                                             mode_, padding_mode_, align_corners_, dtype_,
+                                             shape_override)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -99,7 +105,8 @@ def __call__(
         mode: Optional[Union[InterpolateMode, str]] = None,
         align_corners: Optional[bool] = None,
         anti_aliasing: Optional[bool] = None,
-        anti_aliasing_sigma: Union[Sequence[float], float, None] = None
+        anti_aliasing_sigma: Union[Sequence[float], float, None] = None,
+        shape_override: Optional[Sequence] = None
     ) -> NdarrayOrTensor:
         mode_ = mode or self.mode
         align_corners_ = align_corners or self.align_corners
@@ -108,7 +115,7 @@ def __call__(
 
         img_t, transform, metadata = resize(img, self.spatial_size, self.size_mode, mode_,
                                             align_corners_, anti_aliasing_, anti_aliasing_sigma_,
-                                            self.dtype)
+                                            self.dtype, shape_override)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -144,6 +151,7 @@ def __call__(
         mode: Optional[Union[InterpolateMode, str]] = None,
         padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
         align_corners: Optional[bool] = None,
+        shape_override: Optional[Sequence] = None
     ) -> NdarrayOrTensor:
         angle = self.angle
         mode = mode or self.mode
@@ -153,7 +161,7 @@ def __call__(
         dtype = self.dtype
 
         img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode,
-                                            align_corners, dtype)
+                                            align_corners, dtype, shape_override)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -195,7 +203,8 @@ def __call__(
         img: NdarrayOrTensor,
         mode: Optional[Union[InterpolateMode, str]] = None,
         padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
-        align_corners: Optional[bool] = None
+        align_corners: Optional[bool] = None,
+        shape_override: Optional[Sequence] = None
     ) -> NdarrayOrTensor:
 
         mode = self.mode or mode
@@ -205,7 +214,7 @@ def __call__(
         dtype = self.dtype
 
         img_t, transform, metadata = zoom(img, self.zoom, mode, padding_mode, align_corners,
-                                          keep_size, dtype)
+                                          keep_size, dtype, shape_override)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -269,7 +278,8 @@ def __call__(
         align_corners: Optional[bool] = None,
         dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
         randomize: Optional[bool] = True,
-        get_matrix: Optional[bool] = False
+        get_matrix: Optional[bool] = False,
+        shape_override: Optional[Sequence] = None
     ) -> NdarrayOrTensor:
 
 
@@ -289,7 +299,7 @@ def __call__(
         dtype = self.dtype
 
         img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode,
-                                            align_corners, dtype)
+                                            align_corners, dtype, shape_override)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -303,3 +313,40 @@ def inverse(
         data: NdarrayOrTensor,
     ):
         raise NotImplementedError()
+
+# croppad
+# =======
+
+
+class CropPad(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        slices: Sequence[slice],
+        padmode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        lazy_evaluation: Optional[bool] = True,
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.slices = slices
+        self.padmode = padmode
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        shape_override: Optional[Sequence] = None
+    ):
+
+        img_t, transform, metadata = croppad(img, self.slices, self.padmode, shape_override)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor
+    ):
+        raise NotImplementedError()
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
index 1160250ab5..453f21cf06 100644
--- a/monai/transforms/atmostonce/dictionary.py
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -241,6 +241,7 @@ def __call__(self, d: Mapping):
     def inverse(self, data: Any):
         raise NotImplementedError()
 
+
 class Translated(MapTransform, InvertibleTransform):
 
     def __init__(self,
diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py
index b0264d3609..4dfb323245 100644
--- a/monai/transforms/atmostonce/functional.py
+++ b/monai/transforms/atmostonce/functional.py
@@ -21,7 +21,8 @@ def spacing(
     mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
     padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE,
     align_corners: Optional[bool] = False,
-    dtype: Optional[Union[DtypeLike, torch.dtype]] = None
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
 ):
     """
     Args:
@@ -49,7 +50,8 @@ def spacing(
     """
 
     img_ = convert_to_tensor(img, track_meta=get_track_meta())
-    input_ndim = len(img.shape) - 1
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
 
     pixdim_ = ensure_tuple_rep(pixdim, input_ndim)
     src_pixdim_ = ensure_tuple_rep(src_pixdim, input_ndim)
@@ -62,10 +64,11 @@ def spacing(
     dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
 
     zoom_factors = [i / j for i, j in zip(src_pixdim_, pixdim_)]
-    transform = MatrixFactory.from_tensor(img).scale(zoom_factors)
-    im_extents = extents_from_shape(img.shape)
-    im_extents = [transform.matrix.matrix @ e for e in im_extents]
-    spatial_shape_ = shape_from_extents(im_extents)
+    # TODO: decide whether we are consistently returning MetaMatrix or concrete transforms
+    transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
 
     metadata = {
         "pixdim": pixdim_,
@@ -76,7 +79,7 @@ def spacing(
         "align_corners": align_corners,
         "dtype": dtype_,
         "im_extents": im_extents,
-        "spatial_shape": spatial_shape_
+        "shape_override": shape_override_
     }
     return img_, transform, metadata
 
@@ -101,7 +104,8 @@ def resize(
     align_corners: Optional[bool] = False,
     anti_aliasing: Optional[bool] = None,
     anti_aliasing_sigma: Optional[Union[Sequence[float], float]] = None,
-    dtype: Optional[Union[DtypeLike, torch.dtype]] = None
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
 ):
     """
     Args:
@@ -129,21 +133,22 @@ def resize(
     """
 
     img_ = convert_to_tensor(img, track_meta=get_track_meta())
-    input_ndim = len(img.shape) - 1
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
 
     if size_mode == "all":
         output_ndim = len(ensure_tuple(spatial_size))
         if output_ndim > input_ndim:
-            input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
+            input_shape = ensure_tuple_size(input_shape, output_ndim + 1, 1)
             img = img.reshape(input_shape)
         elif output_ndim < input_ndim:
             raise ValueError(
                 "len(spatial_size) must be greater or equal to img spatial dimensions, "
                 f"got spatial_size={output_ndim} img={input_ndim}."
             )
-        spatial_size_ = fall_back_tuple(spatial_size, img.shape[1:])
+        spatial_size_ = fall_back_tuple(spatial_size, input_shape[1:])
     else:  # for the "longest" mode
-        img_size = img.shape[1:]
+        img_size = input_shape[1:]
         if not isinstance(spatial_size, int):
             raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
         scale = spatial_size / max(img_size)
@@ -151,11 +156,11 @@ def resize(
 
     mode_ = look_up_option(mode, GridSampleMode)
     dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
 
-    zoom_factors = [i / j for i, j in zip(spatial_size, img.shape[1:])]
-    transform = MatrixFactory.from_tensor(img).scale(zoom_factors)
-    im_extents = extents_from_shape(img.shape)
-    im_extents = [transform.matrix.matrix @ e for e in im_extents]
-    spatial_shape_ = shape_from_extents(im_extents)
+    zoom_factors = [i / j for i, j in zip(spatial_size, input_shape[1:])]
+    transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
 
     metadata = {
         "spatial_size": spatial_size_,
@@ -166,7 +171,7 @@ def resize(
         "anti_aliasing_sigma": anti_aliasing_sigma,
         "dtype": dtype_,
         "im_extents": im_extents,
-        "spatial_shape": spatial_shape_
+        "shape_override": shape_override_
     }
     return img_, transform, metadata
 
@@ -178,7 +183,8 @@ def rotate(
     mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
    padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE,
     align_corners: Optional[bool] = False,
-    dtype: Optional[Union[DtypeLike, torch.dtype]] = None
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
 ):
     """
     Args:
@@ -210,17 +216,19 @@ def rotate(
     mode_ = look_up_option(mode, GridSampleMode)
     padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
     dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
-    input_ndim = len(img_.shape) - 1
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
     if input_ndim not in (2, 3):
         raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].")
+
     angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3)
     transform = create_rotate(input_ndim, angle_)
-    im_extents = extents_from_shape(img.shape)
+    im_extents = extents_from_shape(input_shape)
     if not keep_size:
         im_extents = [transform @ e for e in im_extents]
-        spatial_shape = shape_from_extents(im_extents)
+        spatial_shape = shape_from_extents(input_shape, im_extents)
     else:
-        spatial_shape = img_.shape
+        spatial_shape = input_shape
 
     metadata = {
         "angle": angle_,
@@ -230,7 +238,7 @@ def rotate(
         "align_corners": align_corners,
         "dtype": dtype_,
         "im_extents": im_extents,
-        "spatial_shape": spatial_shape
+        "shape_override": spatial_shape
     }
     return img_, transform, metadata
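Throughout these functions the output geometry falls out of one rule: scale factors are the ratio of source spacing to target spacing, and those factors drive both the matrix and the predicted shape. A worked instance with the numbers used by `test_spacing` earlier in this series:

    src_pixdim = (1.0, 1.0)   # current voxel spacing, mm
    pixdim = (0.5, 0.6)       # requested voxel spacing, mm

    zoom_factors = [s / p for s, p in zip(src_pixdim, pixdim)]
    print(zoom_factors)       # [2.0, 1.666...]; finer spacing -> more voxels
    new_shape = [round(d * z) for d, z in zip((24, 32), zoom_factors)]
    print(new_shape)          # [48, 53] for the (1, 24, 32) test image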
@@ -242,7 +250,8 @@ def zoom(
     padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE,
     align_corners: Optional[bool] = False,
     keep_size: Optional[bool] = True,
-    dtype: Optional[Union[DtypeLike, torch.dtype]] = None
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
 ):
     """
     Args:
@@ -261,7 +270,8 @@ def zoom(
     """
 
     img_ = convert_to_tensor(img, track_meta=get_track_meta())
-    input_ndim = len(img.shape) - 1
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
     zoom_factors = ensure_tuple_rep(zoom, input_ndim)
 
@@ -270,13 +280,13 @@ def zoom(
     padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
     dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
 
-    transform = MatrixFactory.from_tensor(img).scale(zoom_factors)
-    im_extents = extents_from_shape(img.shape)
+    transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
     if keep_size is False:
-        im_extents = [transform.matrix.matrix @ e for e in im_extents]
-        spatial_shape_ = shape_from_extents(im_extents)
+        im_extents = [transform @ e for e in im_extents]
+        shape_override_ = shape_from_extents(input_shape, im_extents)
     else:
-        spatial_shape_ = img_.shape
+        shape_override_ = input_shape
 
     metadata = {
         "zoom": zoom_factors,
@@ -286,7 +296,7 @@ def zoom(
         "keep_size": keep_size,
         "dtype": dtype_,
         "im_extents": im_extents,
-        "spatial_shape": spatial_shape_
+        "shape_override": shape_override_
     }
     return img_, transform, metadata
 
@@ -300,27 +310,29 @@ def rotate90(
 
 def croppad(
     img: torch.Tensor,
     slices: Union[Sequence[slice], slice],
-    pad_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE
+    pad_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    shape_override: Optional[Sequence] = None
 ):
     img_ = convert_to_tensor(img, track_meta=get_track_meta())
-    input_ndim = len(img.shape) - 1
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
     if len(slices) != input_ndim:
         raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' "
                          f"spatial dimensions of {input_ndim}")
 
-    img_centers = [i / 2 for i in img.shape[1:]]
+    img_centers = [i / 2 for i in input_shape[1:]]
     slice_centers = [(s.stop + s.start) / 2 for s in slices]
     deltas = [s - i for i, s in zip(img_centers, slice_centers)]
-    transform = MatrixFactory.from_tensor(img).translate(deltas)
-    im_extents = extents_from_shape([img.shape[0]] + [s.stop - s.start for s in slices])
-    im_extents = [transform.matrix.matrix @ e for e in im_extents]
-    spatial_shape_ = shape_from_extents(im_extents)
+    transform = MatrixFactory.from_tensor(img).translate(deltas).matrix.matrix
+    im_extents = extents_from_shape([input_shape[0]] + [s.stop - s.start for s in slices])
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
 
     metadata = {
         "slices": slices,
         "pad_mode": pad_mode,
         "dtype": img.dtype,
         "im_extents": im_extents,
-        "spatial_shape": spatial_shape_
+        "shape_override": shape_override_
     }
     return img_, transform, metadata
diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py
index 86b56b9e92..2759b35b86 100644
--- a/monai/utils/mapping_stack.py
+++ b/monai/utils/mapping_stack.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
 
+from monai.config import NdarrayOrTensor
 from monai.utils.enums import TransformBackends
 
 from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear,
@@ -21,6 +22,13 @@
 from monai.utils.misc import get_backend_from_data, get_device_from_data
 
 
+def ensure_tensor(data: NdarrayOrTensor):
+    if isinstance(data, torch.Tensor):
+        return data
+
+    return torch.as_tensor(data)
+
+
 class MatrixFactory:
 
     def __init__(self,
@@ -102,8 +110,8 @@ def __rmatmul__(self, other):
 
 class Matrix:
 
-    def __init__(self, matrix):
-        self.matrix = matrix
+    def __init__(self, matrix: NdarrayOrTensor):
+        self.matrix = ensure_tensor(matrix)
 
     def __matmul__(self, other):
         if isinstance(other, Matrix):
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index cc347e2427..bbcbb442ff 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -8,29 +8,31 @@
 import torch
 
 from monai.transforms.atmostonce import array as amoa
+from monai.transforms.atmostonce.array import Rotate, CropPad
 from monai.transforms.atmostonce.lazy_transform import compile_transforms
 from monai.utils import TransformBackends
 
 from monai.transforms import Affined, Affine
 from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing
-from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents
+from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply
 from monai.transforms.atmostonce.dictionary import Rotated
 from monai.transforms.compose import Compose
 from monai.utils.enums import GridSampleMode, GridSamplePadMode
 from monai.utils.mapping_stack import MatrixFactory
 
 
-def get_img(size):
-    img = torch.zeros(size, dtype=torch.float32)
-    if len(size) == 2:
-        for j in range(size[0]):
-            for i in range(size[1]):
-                img[j, i] = i + j * size[1]
-    else:
-        for k in range(size[-1]):
-            for j in range(size[-2]):
-                img[..., j, k] = j + k * size[0]
-    return np.expand_dims(img, 0)
+def get_img(size, offset=0):
+    img = torch.zeros(size, dtype=torch.float32)
+    if len(size) == 2:
+        for j in range(size[0]):
+            for i in range(size[1]):
+                img[j, i] = i + j * size[1] + offset
+    else:
+        for k in range(size[0]):
+            for j in range(size[1]):
+                for i in range(size[2]):
+                    img[..., j, k] = j * size[0] + k * size[0] * size[1] + offset
+    return np.expand_dims(img, 0)
 
 
 def enumerate_results_of_op(results):
@@ -241,18 +243,36 @@ def test_croppad_img_even_crop_even(self):
                                     [115., 116., 117., 118., 119., 120.]])
         self._croppad_impl((16, 16), (slice(4, 8), slice(3, 9)), expected)
 
+    # TODO: amo: add tests for matrix and result size
     def test_croppad(self):
         img = get_img((15, 15)).astype(int)
         results = croppad(img, (slice(4, 8), slice(3, 9)))
         enumerate_results_of_op(results)
         m = results[1].matrix.matrix
-        print(m)
+        # print(m)
         result_size = results[2]['spatial_shape']
         a = Affine(affine=m,
                    padding_mode=GridSamplePadMode.ZEROS,
                    spatial_size=result_size)
         img_, _ = a(img)
-        print(img_.numpy())
+        # print(img_.numpy())
+
+    def test_apply(self):
+        img = get_img((16, 16))
+        r = Rotate(torch.pi / 4,
+                   keep_size=False,
+                   mode="bilinear",
+                   padding_mode="zeros",
+                   lazy_evaluation=True)
+        c = CropPad((slice(4, 12), slice(6, 14)),
+                    lazy_evaluation=True)
+
+        img_r = r(img)
+        cur_op = img_r.peek_pending_transform()
+        img_rc = c(img_r,
+                   shape_override=cur_op.metadata.get("shape_override", None))
+
+        img_rca = apply(img_rc)
 
From 77756c89d0d1302a60a53cb37a0cf7ee3e09ee46 Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Mon, 29 Aug 2022 17:51:34 +0100
Subject: [PATCH 15/30] Local changes on 034
---
 monai/data/meta_tensor.py                     |   2 +-
 monai/transforms/atmostonce/apply.py          |  41 ++++--
 monai/transforms/atmostonce/array.py          |  66 ++++++++--
 monai/transforms/atmostonce/dictionary.py     |   4 +-
 monai/transforms/atmostonce/functional.py     |  97 +++++++-------
 monai/transforms/atmostonce/lazy_transform.py |  11 +-
 monai/utils/mapping_stack.py                  |  12 +-
 tests/test_atmostonce.py                      | 121 +++++++++++++++---
 8 files changed, 262 insertions(+), 92 deletions(-)
 
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 8c6c7278d7..bde42ff279 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -159,7 +159,7 @@ def push_pending_transform(self, meta_matrix):
         self._pending_transforms.append(meta_matrix)
 
     def peek_pending_transform(self):
-        return copy.deepcopy(self._pending_transforms[0])
+        return copy.deepcopy(self._pending_transforms[-1])
 
     def pop_pending_transform(self):
         transform = self._pending_transforms[0]
diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py
index 6c9fedc91a..278fac1930 100644
--- a/monai/transforms/atmostonce/apply.py
+++ b/monai/transforms/atmostonce/apply.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 
+import torch
 
 from monai.config import DtypeLike
 from monai.data import MetaTensor
@@ -28,11 +29,29 @@ def extents_from_shape(shape):
 
 
 # TODO: move to mapping_stack.py
-def shape_from_extents(extents):
-    aextents = np.asarray(extents)
-    mins = aextents.min(axis=0)
-    maxes = aextents.max(axis=0)
-    return np.ceil(maxes - mins)[:-1].astype(int)
+def shape_from_extents(
+        src_shape: Sequence,
+        extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
+):
+    if isinstance(extents, (list, tuple)):
+        if isinstance(extents[0], np.ndarray):
+            aextents = np.asarray(extents)
+            aextents = torch.from_numpy(aextents)
+        else:
+            aextents = torch.stack(extents)
+    else:
+        if isinstance(extents, np.ndarray):
+            aextents = torch.from_numpy(extents)
+        else:
+            aextents = extents
+
+    mins = aextents.min(axis=0)[0]
+    maxes = aextents.max(axis=0)[0]
+    values = torch.ceil(maxes - mins).type(torch.IntTensor)[:-1]
+    return torch.cat((torch.as_tensor([src_shape[0]]), values))
+
+    # return [src_shape[0]] + np.ceil(maxes - mins)[:-1].astype(int).tolist()
 
 
 def apply(data: MetaTensor):
@@ -41,29 +60,29 @@ def apply(data: MetaTensor):
     if len(pending) == 0:
         return data
 
-    dim_count = len(data) - 1
+    dim_count = len(data.shape) - 1
     matrix_factory = MatrixFactory(dim_count,
                                    get_backend_from_data(data),
                                    get_device_from_data(data))
 
     # set up the identity matrix and metadata
-    cumulative_matrix = matrix_factory.identity(dim_count)
+    cumulative_matrix = matrix_factory.identity()
     cumulative_extents = extents_from_shape(data.shape)
 
     # pre-translate origin to centre of image
-    translate_to_centre = matrix_factory.translate(dim_count)
+    translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]])
     cumulative_matrix = translate_to_centre @ cumulative_matrix
-    cumulative_extents = [e @ translate_to_centre for e in cumulative_extents]
+    cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents]
 
     for meta_matrix in pending:
         next_matrix = meta_matrix.matrix
         cumulative_matrix = next_matrix @ cumulative_matrix
-        cumulative_extents = [e @ translate_to_centre for e in cumulative_extents]
+        cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents]
 
     # TODO: figure out how to propagate extents properly
     # TODO: resampling strategy: augment resample or perform multiple stages if necessary
     # TODO: resampling strategy - antialiasing: can resample just be augmented?
-
+    r = Resample()
 
     data.clear_pending_transforms()
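PATCH 15's `apply` composes every pending matrix into a single cumulative matrix, bracketed by a pre-translation that moves the origin to the image centre so rotations and scales pivot correctly. A minimal numpy sketch of that accumulation (a hypothetical helper, not the MONAI function):

    import numpy as np

    def fuse(pending, shape):
        # shape is channel-first, e.g. (1, H, W)
        ndim = len(shape) - 1
        to_centre = np.eye(ndim + 1)
        to_centre[:ndim, -1] = [d / 2 for d in shape[1:]]

        cumulative = to_centre.copy()
        for m in pending:              # later transforms compose on the left
            cumulative = m @ cumulative
        return cumulative              # one matrix for a single resample call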
- + r = Resample() data.clear_pending_transforms() diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 8b568df404..516d97a45a 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -9,7 +9,7 @@ from monai.transforms import InvertibleTransform, RandomizableTransform from monai.transforms.atmostonce.apply import apply -from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing +from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad from monai.transforms.atmostonce.lazy_transform import LazyTransform from monai.utils import (GridSampleMode, GridSamplePadMode, @@ -20,6 +20,9 @@ # TODO: these transforms are intended to replace array transforms once development is done +# spatial +# ======= + # TODO: why doesn't Spacing have antialiasing options? class Spacing(LazyTransform, InvertibleTransform): @@ -32,7 +35,8 @@ def __init__( padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, align_corners: Optional[bool] = False, dtype: Optional[DtypeLike] = np.float64, - lazy_evaluation: Optional[bool] = False + lazy_evaluation: Optional[bool] = False, + shape_override: Optional[Sequence] = None ): LazyTransform.__init__(self, lazy_evaluation) self.pixdim = pixdim @@ -49,7 +53,8 @@ def __call__( mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None + dtype: DtypeLike = None, + shape_override: Optional[Sequence] = None ): mode_ = mode or self.mode @@ -58,7 +63,8 @@ def __call__( dtype_ = dtype or self.dtype img_t, transform, metadata = spacing(img, self.pixdim, self.src_pixdim, self.diagonal, - mode_, padding_mode_, align_corners_, dtype_) + mode_, padding_mode_, align_corners_, dtype_, + shape_override) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -99,7 +105,8 @@ def __call__( mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, anti_aliasing: Optional[bool] = None, - anti_aliasing_sigma: Union[Sequence[float], float, None] = None + anti_aliasing_sigma: Union[Sequence[float], float, None] = None, + shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: mode_ = mode or self.mode align_corners_ = align_corners or self.align_corners @@ -108,7 +115,7 @@ def __call__( img_t, transform, metadata = resize(img, self.spatial_size, self.size_mode, mode_, align_corners_, anti_aliasing_, anti_aliasing_sigma_, - self.dtype) + self.dtype, shape_override) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -144,6 +151,7 @@ def __call__( mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, + shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: angle = self.angle mode = mode or self.mode @@ -153,7 +161,7 @@ def __call__( dtype = self.dtype img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, - align_corners, dtype) + align_corners, dtype, shape_override) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -195,7 +203,8 @@ def __call__( img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, 
padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, - align_corners: Optional[bool] = None + align_corners: Optional[bool] = None, + shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: mode = self.mode or mode @@ -205,7 +214,7 @@ def __call__( dtype = self.dtype img_t, transform, metadata = zoom(img, self.zoom, mode, padding_mode, align_corners, - keep_size, dtype) + keep_size, dtype, shape_override) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -269,7 +278,8 @@ def __call__( align_corners: Optional[bool] = None, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, randomize: Optional[bool] = True, - get_matrix: Optional[bool] = False + get_matrix: Optional[bool] = False, + shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: @@ -289,7 +299,7 @@ def __call__( dtype = self.dtype img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, - align_corners, dtype) + align_corners, dtype, shape_override) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -304,5 +314,39 @@ def inverse( ): raise NotImplementedError() +# croppad +# ======= + + +class CropPad(LazyTransform, InvertibleTransform): + + def __init__( + self, + slices: Sequence[slice], + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + lazy_evaluation: Optional[bool] = True, + ): + LazyTransform.__init__(self, lazy_evaluation) + self.slices = slices + self.padding_mode = padding_mode + + def __call__( + self, + img: NdarrayOrTensor, + shape_override: Optional[Sequence] = None + ): + + img_t, transform, metadata = croppad(img, self.slices, self.padding_mode, shape_override) + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + return img_t + + def inverse( + self, + data: NdarrayOrTensor + ): + raise NotImplementedError() diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py index 3ca1668008..453f21cf06 100644 --- a/monai/transforms/atmostonce/dictionary.py +++ b/monai/transforms/atmostonce/dictionary.py @@ -8,13 +8,12 @@ from monai.utils import ensure_tuple_rep from monai.config import KeysCollection, DtypeLike, SequenceStr -from monai.transforms.atmostonce.apply import apply from monai.transforms.atmostonce.lazy_transform import LazyTransform from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, \ PytorchPadMode -from monai.utils.mapping_stack import MatrixFactory, MetaMatrix +from monai.utils.mapping_stack import MatrixFactory from monai.utils.type_conversion import expand_scalar_to_tuple @@ -242,6 +241,7 @@ def __call__(self, d: Mapping): def inverse(self, data: Any): raise NotImplementedError() + class Translated(MapTransform, InvertibleTransform): def __init__(self, diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index 692a372e94..4dfb323245 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -21,7 +21,8 @@ def spacing( mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, padding_mode: Optional[Union[NumpyPadMode, 
GridSamplePadMode, str]] = NumpyPadMode.EDGE, align_corners: Optional[bool] = False, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None ): """ Args: @@ -49,7 +50,8 @@ def spacing( """ img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_ndim = len(img.shape) - 1 + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 pixdim_ = ensure_tuple_rep(pixdim, input_ndim) src_pixdim_ = ensure_tuple_rep(src_pixdim, input_ndim) @@ -62,10 +64,11 @@ def spacing( dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) zoom_factors = [i / j for i, j in zip(src_pixdim_, pixdim_)] - transform = MatrixFactory.from_tensor(img).scale(zoom_factors) - im_extents = extents_from_shape(img.shape) - im_extents = [transform.matrix.matrix @ e for e in im_extents] - spatial_shape_ = shape_from_extents(im_extents) + # TODO: decide whether we are consistently returning MetaMatrix or concrete transforms + transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix + im_extents = extents_from_shape(input_shape) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) metadata = { "pixdim": pixdim_, @@ -76,7 +79,7 @@ def spacing( "align_corners": align_corners, "dtype": dtype_, "im_extents": im_extents, - "spatial_shape": spatial_shape_ + "shape_override": shape_override_ } return img_, transform, metadata @@ -101,7 +104,8 @@ def resize( align_corners: Optional[bool] = False, anti_aliasing: Optional[bool] = None, anti_aliasing_sigma: Optional[Union[Sequence[float], float]] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None ): """ Args: @@ -129,21 +133,22 @@ def resize( """ img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_ndim = len(img.shape) - 1 + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 if size_mode == "all": output_ndim = len(ensure_tuple(spatial_size)) if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) + input_shape = ensure_tuple_size(input_shape, output_ndim + 1, 1) img = img.reshape(input_shape) elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." 
) - spatial_size_ = fall_back_tuple(spatial_size, img.shape[1:]) + spatial_size_ = fall_back_tuple(spatial_size, input_shape[1:]) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = input_shape[1:] if not isinstance(spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = spatial_size / max(img_size) @@ -151,11 +156,11 @@ def resize( mode_ = look_up_option(mode, GridSampleMode) dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) - zoom_factors = [i / j for i, j in zip(spatial_size, img.shape[1:])] - transform = MatrixFactory.from_tensor(img).scale(zoom_factors) - im_extents = extents_from_shape(img.shape) - im_extents = [transform.matrix.matrix @ e for e in im_extents] - spatial_shape_ = shape_from_extents(im_extents) + zoom_factors = [i / j for i, j in zip(spatial_size, input_shape[1:])] + transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix + im_extents = extents_from_shape(input_shape) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) metadata = { "spatial_size": spatial_size_, @@ -166,7 +171,7 @@ def resize( "anti_aliasing_sigma": anti_aliasing_sigma, "dtype": dtype_, "im_extents": im_extents, - "spatial_shape": spatial_shape_ + "shape_override": shape_override_ } return img_, transform, metadata @@ -178,7 +183,8 @@ def rotate( mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, align_corners: Optional[bool] = False, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None ): """ Args: @@ -210,17 +216,19 @@ def rotate( mode_ = look_up_option(mode, GridSampleMode) padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) - input_ndim = len(img_.shape) - 1 + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") + angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) transform = create_rotate(input_ndim, angle_) - im_extents = extents_from_shape(img.shape) + im_extents = extents_from_shape(input_shape) if not keep_size: im_extents = [transform @ e for e in im_extents] - spatial_shape = shape_from_extents(im_extents) + spatial_shape = shape_from_extents(input_shape, im_extents) else: - spatial_shape = img_.shape + spatial_shape = input_shape metadata = { "angle": angle_, @@ -230,7 +238,7 @@ def rotate( "align_corners": align_corners, "dtype": dtype_, "im_extents": im_extents, - "spatial_shape": spatial_shape + "shape_override": spatial_shape } return img_, transform, metadata @@ -242,7 +250,8 @@ def zoom( padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, align_corners: Optional[bool] = False, keep_size: Optional[bool] = True, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None ): """ Args: @@ -261,7 +270,8 @@ def zoom( """ img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_ndim = len(img.shape) - 1 + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 
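    # note (sketch of the convention assumed throughout these functionals):
    # when earlier transforms are pending, 'shape_override' carries the spatial
    # shape the tensor *will* have once they are applied, so 'input_shape'
    # above is that prospective shape, and the extents computed from it below
    # stay consistent with the accumulated lazy matrices rather than with the
    # tensor's stale concrete shape.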
zoom_factors = ensure_tuple_rep(zoom, input_ndim) @@ -270,13 +280,13 @@ def zoom( padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) - transform = MatrixFactory.from_tensor(img).scale(zoom_factors) - im_extents = extents_from_shape(img.shape) + transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix + im_extents = extents_from_shape(input_shape) if keep_size is False: - im_extents = [transform.matrix.matrix @ e for e in im_extents] - spatial_shape_ = shape_from_extents(im_extents) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) else: - spatial_shape_ = img_.shape + shape_override_ = input_shape metadata = { "zoom": zoom_factors, @@ -286,7 +296,7 @@ def zoom( "keep_size": keep_size, "dtype": dtype_, "im_extents": im_extents, - "spatial_shape": spatial_shape_ + "shape_override": shape_override_ } return img_, transform, metadata @@ -300,28 +310,29 @@ def rotate90( def croppad( img: torch.Tensor, slices: Union[Sequence[slice], slice], - pad_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE + pad_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, + shape_override: Optional[Sequence] = None ): img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_ndim = len(img.shape) - 1 + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 if len(slices) != input_ndim: raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' " f"spatial dimensions of {input_ndim}") - img_centers = [i // 2 for i in img.shape[1:]] - slice_centers = [s.stop - s.start for s in slices] - # img_centers = [0 for _ in img.shape[1:]] - # slice_centers = [s.end - s.start for s in slices] + img_centers = [i / 2 for i in input_shape[1:]] + slice_centers = [(s.stop + s.start) / 2 for s in slices] deltas = [s - i for i, s in zip(img_centers, slice_centers)] - transform = MatrixFactory.from_tensor(img).translate(deltas) - im_extents = extents_from_shape([img.shape[0]] + [s.stop - s.start for s in slices]) - im_extents = [transform.matrix.matrix @ e for e in im_extents] - spatial_shape_ = shape_from_extents(im_extents) + transform = MatrixFactory.from_tensor(img).translate(deltas).matrix.matrix + im_extents = extents_from_shape([input_shape[0]] + [s.stop - s.start for s in slices]) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) metadata = { "slices": slices, + "pad_mode": pad_mode, "dtype": img.dtype, "im_extents": im_extents, - "spatial_shape": spatial_shape_ + "shape_override": shape_override_ } return img_, transform, metadata diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py index fa60a5f0d4..94734cd2f0 100644 --- a/monai/transforms/atmostonce/lazy_transform.py +++ b/monai/transforms/atmostonce/lazy_transform.py @@ -1,3 +1,4 @@ +from monai.transforms import Randomizable from monai.config import NdarrayOrTensor @@ -45,14 +46,22 @@ def transforms_compatible(current, next): raise NotImplementedError() -def compile_transforms(transforms): +def compile_lazy_transforms(transforms): flat = flatten_sequences(transforms) for i in range(len(flat)-1): cur_t, next_t = flat[i], flat[i + 1] if not transforms_compatible(cur_t, next_t): flat.insert(i + 1, Applyd()) + if not isinstance(flat[-1], Applyd): + flat.append(Applyd) return flat +def 
compile_cached_dataloading_transforms(transforms): + flat = flatten_sequences(transforms) + for i in range(len(flat)): + cur_t = flat[i] + if isinstance(cur_t, Randomizable): + flat.insert diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py index 86b56b9e92..2759b35b86 100644 --- a/monai/utils/mapping_stack.py +++ b/monai/utils/mapping_stack.py @@ -14,6 +14,7 @@ import numpy as np import torch +from monai.config import NdarrayOrTensor from monai.utils.enums import TransformBackends from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear, @@ -21,6 +22,13 @@ from monai.utils.misc import get_backend_from_data, get_device_from_data +def ensure_tensor(data: NdarrayOrTensor): + if isinstance(data, torch.Tensor): + return data + + return torch.as_tensor(data) + + class MatrixFactory: def __init__(self, @@ -102,8 +110,8 @@ def __rmatmul__(self, other): class Matrix: - def __init__(self, matrix): - self.matrix = matrix + def __init__(self, matrix: NdarrayOrTensor): + self.matrix = ensure_tensor(matrix) def __matmul__(self, other): if isinstance(other, Matrix): diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index 9e0841b83f..2da1d7a6f6 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -8,44 +8,46 @@ import torch from monai.transforms.atmostonce import array as amoa +from monai.transforms.atmostonce.array import Rotate, CropPad from monai.transforms.atmostonce.lazy_transform import compile_transforms from monai.utils import TransformBackends from monai.transforms import Affined, Affine from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing -from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents +from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply from monai.transforms.atmostonce.dictionary import Rotated from monai.transforms.compose import Compose from monai.utils.enums import GridSampleMode, GridSamplePadMode from monai.utils.mapping_stack import MatrixFactory -def get_img(size): - img = torch.zeros(size, dtype=torch.float32) - if len(size) == 2: - for j in range(size[0]): - for i in range(size[1]): - img[j, i] = i + j * size[1] - else: - for k in range(size[-1]): - for j in range(size[-2]): - img[..., j, k] = j + k * size[0] - return np.expand_dims(img, 0) +def get_img(size, offset = 0): + img = torch.zeros(size, dtype=torch.float32) + if len(size) == 2: + for j in range(size[0]): + for i in range(size[1]): + img[j, i] = i + j * size[1] + offset + else: + for k in range(size[0]): + for j in range(size[1]): + for i in range(size[2]): + img[..., j, k] = j * size[0] + k * size[0] * size[1] + offset + return np.expand_dims(img, 0) def enumerate_results_of_op(results): if isinstance(results, dict): for k, v in results.items(): if isinstance(v, (np.ndarray, torch.Tensor)): - print(k, v.shape, v[tuple(slice(0, 8) for _ in r.shape)]) + print(k, v.shape, v[tuple(slice(0, 8) for _ in v.shape)]) else: print(k, v) else: - for ir, r in enumerate(results): - if isinstance(r, (np.ndarray, torch.Tensor)): - print(ir, r.shape, r[tuple(slice(0, 8) for _ in r.shape)]) + for ir, v in enumerate(results): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(ir, v.shape, v[tuple(slice(0, 8) for _ in v.shape)]) else: - print(ir, r) + print(ir, v) class TestLowLevel(unittest.TestCase): @@ -181,19 +183,96 @@ def test_rotate(self): "border") enumerate_results_of_op(results) - def test_croppad(self): + def 
test_croppad_identity(self): img = get_img((16, 16)).astype(int) results = croppad(img, - (slice(3, 8), slice(3, 9))) + (slice(0, 16), slice(0, 16))) + enumerate_results_of_op(results) + m = results[1].matrix.matrix + print(m) + result_size = results[2]['spatial_shape'] + a = Affine(affine=m, + padding_mode=GridSamplePadMode.ZEROS, + spatial_size=result_size) + img_, _ = a(img) + print(img_) + + def _croppad_impl(self, img_ext, slices, expected): + img = get_img(img_ext).astype(int) + results = croppad(img, slices) enumerate_results_of_op(results) m = results[1].matrix.matrix print(m) result_size = results[2]['spatial_shape'] a = Affine(affine=m, padding_mode=GridSamplePadMode.ZEROS, - spatial_size=[1] + result_size) + spatial_size=result_size) + img_, _ = a(img) + if expected is None: + print(img_.numpy()) + else: + self.assertTrue(torch.allclose(img_, expected)) + + def test_croppad_img_odd_crop_odd(self): + expected = torch.as_tensor([[63., 64., 65., 66., 67., 68., 69.], + [78., 79., 80., 81., 82., 83., 84.], + [93., 94., 95., 96., 97., 98., 99.], + [108., 109., 110., 111., 112., 113., 114.], + [123., 124., 125., 126., 127., 128., 129.]]) + self._croppad_impl((15, 15), (slice(4, 9), slice(3, 10)), expected) + + def test_croppad_img_odd_crop_even(self): + expected = torch.as_tensor([[63., 64., 65., 66., 67., 68.], + [78., 79., 80., 81., 82., 83.], + [93., 94., 95., 96., 97., 98.], + [108., 109., 110., 111., 112., 113.]]) + self._croppad_impl((15, 15), (slice(4, 8), slice(3, 9)), expected) + + def test_croppad_img_even_crop_odd(self): + expected = torch.as_tensor([[67., 68., 69., 70., 71., 72., 73.], + [83., 84., 85., 86., 87., 88., 89.], + [99., 100., 101., 102., 103., 104., 105.], + [115., 116., 117., 118., 119., 120., 121.], + [131., 132., 133., 134., 135., 136., 137.]]) + self._croppad_impl((16, 16), (slice(4, 9), slice(3, 10)), expected) + + def test_croppad_img_even_crop_even(self): + expected = torch.as_tensor([[67., 68., 69., 70., 71., 72.], + [83., 84., 85., 86., 87., 88.], + [99., 100., 101., 102., 103., 104.], + [115., 116., 117., 118., 119., 120.]]) + self._croppad_impl((16, 16), (slice(4, 8), slice(3, 9)), expected) + + # TODO: amo: add tests for matrix and result size + def test_croppad(self): + img = get_img((15, 15)).astype(int) + results = croppad(img, (slice(4, 8), slice(3, 9))) + enumerate_results_of_op(results) + m = results[1].matrix.matrix + # print(m) + result_size = results[2]['spatial_shape'] + a = Affine(affine=m, + padding_mode=GridSamplePadMode.ZEROS, + spatial_size=result_size) img_, _ = a(img) - print(img_.numpy().astype(int)) + # print(img_.numpy()) + + def test_apply(self): + img = get_img((16, 16)) + r = Rotate(torch.pi / 4, + keep_size=False, + mode="bilinear", + padding_mode="zeros", + lazy_evaluation=True) + c = CropPad((slice(4, 12), slice(6, 14)), + lazy_evaluation=True) + + img_r = r(img) + cur_op = img_r.peek_pending_transform() + img_rc = c(img_r, + shape_override=cur_op.metadata.get("shape_override", None)) + + img_rca = apply(img_rc) class TestArrayTransforms(unittest.TestCase): From 886057e1d6a6f96085793115a7e73a4e00b0e545 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Mon, 29 Aug 2022 18:36:53 +0100 Subject: [PATCH 16/30] Minor fix to Apply; minor fix to enumerate_results_of_op --- monai/transforms/atmostonce/apply.py | 4 ++-- tests/test_atmostonce.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index d981eadea4..9a17dfe4b9 100644 --- 
a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -72,12 +72,12 @@ def apply(data: MetaTensor): # pre-translate origin to centre of image translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]]) cumulative_matrix = translate_to_centre @ cumulative_matrix - cumulative_extents = [e @ translate_to_centre for e in cumulative_extents] + cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] for meta_matrix in pending: next_matrix = meta_matrix.matrix cumulative_matrix = next_matrix @ cumulative_matrix - cumulative_extents = [e @ translate_to_centre for e in cumulative_extents] + cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] # TODO: figure out how to propagate extents properly # TODO: resampling strategy: augment resample or perform multiple stages if necessary diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index bbcbb442ff..2da1d7a6f6 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -39,15 +39,15 @@ def enumerate_results_of_op(results): if isinstance(results, dict): for k, v in results.items(): if isinstance(v, (np.ndarray, torch.Tensor)): - print(k, v.shape, v[tuple(slice(0, 8) for _ in r.shape)]) + print(k, v.shape, v[tuple(slice(0, 8) for _ in v.shape)]) else: print(k, v) else: - for ir, r in enumerate(results): - if isinstance(r, (np.ndarray, torch.Tensor)): - print(ir, r.shape, r[tuple(slice(0, 8) for _ in r.shape)]) + for ir, v in enumerate(results): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(ir, v.shape, v[tuple(slice(0, 8) for _ in v.shape)]) else: - print(ir, r) + print(ir, v) class TestLowLevel(unittest.TestCase): From 142177325b22ebee6b34b40d155a35f20a068386 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Tue, 30 Aug 2022 16:07:29 +0100 Subject: [PATCH 17/30] Working on transform based compose compilers --- monai/transforms/atmostonce/apply.py | 7 +- monai/transforms/atmostonce/compose.py | 67 +++++++++++++++++++ monai/transforms/atmostonce/lazy_transform.py | 1 + tests/test_atmostonce.py | 2 +- 4 files changed, 74 insertions(+), 3 deletions(-) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index 278fac1930..670ab8ebd8 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -5,6 +5,7 @@ import numpy as np import torch +from monai.transforms import Affine from monai.config import DtypeLike from monai.data import MetaTensor @@ -82,8 +83,10 @@ def apply(data: MetaTensor): # TODO: figure out how to propagate extents properly # TODO: resampling strategy: augment resample or perform multiple stages if necessary # TODO: resampling strategy - antialiasing: can resample just be augmented? 
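The change below replaces the placeholder `Resample` with a single `Affine` driven by the accumulated matrix. As a rough standalone illustration of that fused-resample idea (it mirrors how the tests above invoke `Affine`; the two matrices are invented for the example and are not taken from the patch):

import numpy as np
from monai.transforms import Affine, create_rotate, create_translate
from monai.utils import GridSamplePadMode

img = np.arange(16 * 16, dtype=np.float32).reshape(1, 16, 16)
# two pending operations fused into one homogeneous matrix...
fused = create_translate(2, (2, 2)) @ create_rotate(2, np.pi / 4)
# ...then realised with a single resample rather than one per transform
a = Affine(affine=fused, padding_mode=GridSamplePadMode.ZEROS, spatial_size=(16, 16))
img_, _ = a(img)

Resampling once from the fused matrix is the payoff of keeping transforms pending: each intermediate resample would otherwise degrade the data a little more.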
- r = Resample() - + # r = Resample() + a = Affine(affine=cumulative_matrix.matrix.matrix, + padding_mode=cur_padding_mode, + spatial_size=cur_spatial_size) data.clear_pending_transforms() diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py index 6d7464f547..9b1851a59f 100644 --- a/monai/transforms/atmostonce/compose.py +++ b/monai/transforms/atmostonce/compose.py @@ -12,6 +12,73 @@ # TODO: this is intended to replace Compose once development is done +class ComposeCompiler: + """ + Args: + transforms: A sequence of callable transforms + lazy_resampling: Whether to resample the data after each transform or accumulate + changes and then resample according to the accumulated changes as few times as + possible. Defaults to True as this nearly always improves speed and quality + caching_policy: Whether to cache deterministic transforms before the first + randomised transforms. This can be one of "off", "drive", "memory" + caching_favor: Whether to cache primarily for "speed" or for "quality". "speed" will + favor doing more work before caching, whereas "quality" will favour delaying + resampling until after caching + """ + def __init__( + self, + transforms: Union[Sequence[Callable], Callable], + lazy_resampling: Optional[bool] = True, + caching_policy: Optional[str] = "off", + caching_favor: Optional[str] = "quality" + ): + valid_caching_policies = ("off", "drive", "memory") + if caching_policy not in valid_caching_policies: + raise ValueError("parameter 'caching_policy' must be one of " + f"{valid_caching_policies} but is '{caching_policy}'") + + dest_transforms = None + + if caching_policy == "off": + if lazy_resampling is False: + dest_transforms = [t for t in transforms] + else: + dest_transforms = ComposeCompiler.lazy_no_cache() + else: + if caching_policy == "drive": + raise NotImplementedError() + elif caching_policy == "memory": + raise NotImplementedError() + + self.src_transforms = [t for t in transforms] + self.dest_transforms = dest_transforms + + def __getitem__( + self, + index + ): + return self.dest_transforms[index] + + def __len__(self): + return len(self.dest_transforms) + + @staticmethod + def lazy_no_cache(transforms): + dest_transforms = [] + # TODO: replace with lazy transform + cur_lazy = [] + for i_t in range(1, len(transforms)): + if isinstance(transforms[i_t], LazyTransform): + # add this to the stack of transforms to be handled lazily + cur_lazy.append(transforms[i_t]) + else: + if len(cur_lazy) > 0: + dest_transforms.append(cur_lazy) + # TODO: replace with lazy transform + cur_lazy = [] + dest_transforms.append(transforms[i_t]) + return dest_transforms + class Compose(Randomizable, InvertibleTransform): """ diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py index 9d97d2bff3..a027fc978e 100644 --- a/monai/transforms/atmostonce/lazy_transform.py +++ b/monai/transforms/atmostonce/lazy_transform.py @@ -54,6 +54,7 @@ def compile_lazy_transforms(transforms): flat.append(Applyd) return flat + def compile_cached_dataloading_transforms(transforms): flat = flatten_sequences(transforms) for i in range(len(flat)): diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index 2da1d7a6f6..f89b17da34 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -9,7 +9,7 @@ from monai.transforms.atmostonce import array as amoa from monai.transforms.atmostonce.array import Rotate, CropPad -from monai.transforms.atmostonce.lazy_transform import compile_transforms 
+from monai.transforms.atmostonce.lazy_transform import compile_lazy_transforms from monai.utils import TransformBackends from monai.transforms import Affined, Affine From 57ae027f1a7fe032e98b9fa8a13959ffc0d86e0d Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 31 Aug 2022 15:21:48 +0100 Subject: [PATCH 18/30] Work on rotate_90 functional, and associated tests --- monai/transforms/atmostonce/array.py | 117 ++++++++++++---------- monai/transforms/atmostonce/functional.py | 26 +++-- tests/test_atmostonce.py | 53 ++++++++++ 3 files changed, 134 insertions(+), 62 deletions(-) diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index b90f5f2dca..91d4b4a226 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -27,16 +27,16 @@ class Spacing(LazyTransform, InvertibleTransform): def __init__( - self, - pixdim: Union[Sequence[float], float, np.ndarray], - src_pixdim: Optional[Union[Sequence[float], float, np.ndarray]], - diagonal: Optional[bool] = False, - mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, - padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, - align_corners: Optional[bool] = False, - dtype: Optional[DtypeLike] = np.float64, - lazy_evaluation: Optional[bool] = False, - shape_override: Optional[Sequence] = None + self, + pixdim: Union[Sequence[float], float, np.ndarray], + src_pixdim: Optional[Union[Sequence[float], float, np.ndarray]], + diagonal: Optional[bool] = False, + mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + align_corners: Optional[bool] = False, + dtype: Optional[DtypeLike] = np.float64, + lazy_evaluation: Optional[bool] = False, + shape_override: Optional[Sequence] = None ): LazyTransform.__init__(self, lazy_evaluation) self.pixdim = pixdim @@ -80,15 +80,15 @@ def inverse(self, data): class Resize(LazyTransform, InvertibleTransform): def __init__( - self, - spatial_size: Union[Sequence[int], int], - size_mode: Optional[str] = "all", - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - align_corners: Optional[bool] = False, - anti_aliasing: Optional[bool] = False, - anti_aliasing_sigma: Optional[Union[Sequence[float], float, None]] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, - lazy_evaluation: Optional[bool] = False + self, + spatial_size: Union[Sequence[int], int], + size_mode: Optional[str] = "all", + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = False, + anti_aliasing: Optional[bool] = False, + anti_aliasing_sigma: Optional[Union[Sequence[float], float, None]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, + lazy_evaluation: Optional[bool] = False ): LazyTransform.__init__(self, lazy_evaluation) self.spatial_size = spatial_size @@ -128,14 +128,14 @@ def __call__( class Rotate(LazyTransform, InvertibleTransform): def __init__( - self, - angle: Union[Sequence[float], float], - keep_size: bool = True, - mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, - padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = np.float32, - lazy_evaluation: Optional[bool] = False + self, + angle: Union[Sequence[float], float], + keep_size: bool = True, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: 
Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: Union[DtypeLike, torch.dtype] = np.float32, + lazy_evaluation: Optional[bool] = False ): LazyTransform.__init__(self, lazy_evaluation) self.angle = angle @@ -181,14 +181,14 @@ class Zoom(LazyTransform, InvertibleTransform): """ def __init__( - self, - zoom: Union[Sequence[float], float], - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, - align_corners: Optional[bool] = None, - keep_size: bool = True, - dtype: Union[DtypeLike, torch.dtype] = np.float32, - **kwargs + self, + zoom: Union[Sequence[float], float], + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = None, + keep_size: bool = True, + dtype: Union[DtypeLike, torch.dtype] = np.float32, + **kwargs ): self.zoom = zoom self.mode: InterpolateMode = InterpolateMode(mode) @@ -227,19 +227,27 @@ def inverse(self, data): raise NotImplementedError() +# class Rotate90(InvertibleTransform, LazyTransform): +# +# def __init__( +# self, +# +# ): +# pass + class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): def __init__( - self, - range_x: Optional[Union[Tuple[float, float], float]] = 0.0, - range_y: Optional[Union[Tuple[float, float], float]] = 0.0, - range_z: Optional[Union[Tuple[float, float], float]] = 0.0, - prob: Optional[float] = 0.1, - keep_size: bool = True, - mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, - padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = np.float32 + self, + range_x: Optional[Union[Tuple[float, float], float]] = 0.0, + range_y: Optional[Union[Tuple[float, float], float]] = 0.0, + range_z: Optional[Union[Tuple[float, float], float]] = 0.0, + prob: Optional[float] = 0.1, + keep_size: bool = True, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: Union[DtypeLike, torch.dtype] = np.float32 ): RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -271,16 +279,15 @@ def randomize(self, data: Optional[Any] = None) -> None: self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) def __call__( - self, - img: NdarrayOrTensor, - mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, - align_corners: Optional[bool] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, - randomize: Optional[bool] = True, - get_matrix: Optional[bool] = False, - shape_override: Optional[Sequence] = None - + self, + img: NdarrayOrTensor, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + align_corners: Optional[bool] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + randomize: Optional[bool] = True, + get_matrix: Optional[bool] = False, + shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: if randomize: diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index 4dfb323245..bf74e50d20 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -1,8 +1,8 @@ -from typing import 
Optional, Sequence, Union +from typing import Optional, Sequence, Tuple, Union import torch -from monai.transforms import create_rotate +from monai.transforms import create_rotate, map_spatial_axes from monai.data import get_track_meta from monai.transforms.atmostonce.apply import extents_from_shape, shape_from_extents @@ -275,7 +275,6 @@ def zoom( zoom_factors = ensure_tuple_rep(zoom, input_ndim) - mode_ = look_up_option(mode, GridSampleMode) padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) @@ -301,10 +300,23 @@ def zoom( return img_, transform, metadata -def rotate90( - img: torch.Tensor -): - pass +# def rotate90( +# img: torch.Tensor, +# k: Optional[int] = 1, +# spatial_axes: Optional[Tuple[int, int]] = (0, 1), +# ): +# if len(spatial_axes) != 2: +# raise ValueError("'spatial_axes' must be a tuple of two integers indicating") +# +# img = convert_to_tensor(img, track_meta=get_track_meta()) +# axes = map_spatial_axes(img.ndim, spatial_axes) +# ori_shape = img.shape[1:] +# +# metadata = { +# "k": k, +# "spatial_axes": spatial_axes, +# "shape_override": shape_override +# } def croppad( diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index f89b17da34..c83cffd823 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -183,6 +183,59 @@ def test_rotate(self): "border") enumerate_results_of_op(results) + def _check_matrix(self, actual, expected): + np.allclose(actual, expected) + def _test_rotate_90_impl(self, values, keep_dims, expected): + results = rotate(np.zeros((1, 64, 64, 32), dtype=np.float32), + values, + keep_dims, + "bilinear", + "border") + # enumerate_results_of_op(results) + self._check_matrix(results[1], expected) + + def test_rotate_d0_r1(self): + expected = np.asarray([[1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) + self._test_rotate_90_impl((torch.pi / 2, 0, 0), True, expected) + + def test_rotate_d0_r2(self): + expected = np.asarray([[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1]]) + self._test_rotate_90_impl((torch.pi, 0, 0), True, expected) + + def test_rotate_d0_r3(self): + expected = np.asarray([[1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1]]) + self._test_rotate_90_impl((3 * torch.pi / 2, 0, 0), True, expected) + + def test_rotate_d2_r1(self): + expected = np.asarray([[0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + self._test_rotate_90_impl((0, 0, torch.pi / 2), True, expected) + + def test_rotate_d2_r2(self): + expected = np.asarray([[-1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + self._test_rotate_90_impl((0, 0, torch.pi), True, expected) + + def test_rotate_d2_r3(self): + expected = np.asarray([[0, 1, 0, 0], + [-1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + self._test_rotate_90_impl((0, 0, 3 * torch.pi / 2), True, expected) + def test_croppad_identity(self): img = get_img((16, 16)).astype(int) results = croppad(img, From 2f1bf91c8ec1b356030a0d44b417b7bb9db159f4 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 1 Sep 2022 15:41:09 +0100 Subject: [PATCH 19/30] Work on apply --- monai/transforms/atmostonce/apply.py | 112 +++++++++++++++++++++++---- tests/test_atmostonce.py | 21 +++-- 2 files changed, 114 insertions(+), 19 deletions(-) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index 278fac1930..f1e46db2e4 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -1,10 +1,11 @@ -from 
typing import Sequence, Union +from typing import Optional, Sequence, Union import itertools as it import numpy as np import torch +from monai.transforms import Resample, Affine from monai.config import DtypeLike from monai.data import MetaTensor @@ -53,18 +54,16 @@ def shape_from_extents( # return [src_shape[0]] + np.ceil(maxes - mins)[:-1].astype(int).tolist() +def metadata_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + return value_1 == value_2 -def apply(data: MetaTensor): - pending = data.pending_transforms - - if len(pending) == 0: - return data - - dim_count = len(data.shape) - 1 - matrix_factory = MatrixFactory(dim_count, - get_backend_from_data(data), - get_device_from_data(data)) +def starting_matrix_and_extents(matrix_factory, data): # set up the identity matrix and metadata cumulative_matrix = matrix_factory.identity() cumulative_extents = extents_from_shape(data.shape) @@ -73,16 +72,101 @@ def apply(data: MetaTensor): translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]]) cumulative_matrix = translate_to_centre @ cumulative_matrix cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] + return cumulative_matrix, cumulative_extents + + +def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): + kwargs = {} + if cur_mode is not None: + kwargs['mode'] = cur_mode + if cur_padding_mode is not None: + kwargs['padding_mode'] = cur_padding_mode + if cur_device is not None: + kwargs['device'] = cur_device + if cur_dtype is not None: + kwargs['dtype'] = cur_dtype + + return kwargs - for meta_matrix in pending: + +def apply(data: Union[torch.Tensor, MetaTensor], + pending: Optional[dict] = None): + pending_ = pending + pending_ = data.pending_transforms + + if len(pending) == 0: + return data + + dim_count = len(data.shape) - 1 + matrix_factory = MatrixFactory(dim_count, + get_backend_from_data(data), + get_device_from_data(data)) + + # # set up the identity matrix and metadata + # cumulative_matrix = matrix_factory.identity() + # cumulative_extents = extents_from_shape(data.shape) + # + # # pre-translate origin to centre of image + # translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]]) + # cumulative_matrix = translate_to_centre @ cumulative_matrix + # cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) + + # set the various resampling parameters to an initial state + cur_mode = None + cur_padding_mode = None + cur_device = None + cur_dtype = None + + for meta_matrix in pending_: next_matrix = meta_matrix.matrix cumulative_matrix = next_matrix @ cumulative_matrix - cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] + # cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] + cumulative_extents = [e @ cumulative_matrix.matrix.matrix for e in cumulative_extents] + + new_mode = meta_matrix.metadata.get('mode', None) + new_padding_mode = meta_matrix.metadata.get('padding_mode', None) + new_device = meta_matrix.metadata.get('device', None) + new_dtype = meta_matrix.metadata.get('dtype', None) + + mode_compat = metadata_is_compatible(cur_mode, new_mode) + padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) + device_compat = metadata_is_compatible(cur_device, new_device) + dtype_compat = 
metadata_is_compatible(cur_dtype, new_dtype) + + if (mode_compat is False or padding_mode_compat is False or + device_compat is False or dtype_compat is False): + print("intermediate apply required") + # carry out an intermediate resample here due to incompatibility between arguments + # kwargs = {} + # if cur_mode is not None: + # kwargs['mode'] = cur_mode + # if cur_padding_mode is not None: + # kwargs['padding_mode'] = cur_padding_mode + # if cur_device is not None: + # kwargs['device'] = cur_device + # if cur_dtype is not None: + # kwargs['dtype'] = cur_dtype + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + a = Affine(norm_coords=False, + affine=cumulative_matrix.matrix.matrix, + **kwargs) + data = a(img=data) + + cur_mode = new_mode + cur_padding_mode = new_padding_mode + cur_device = new_device + cur_dtype = new_dtype # TODO: figure out how to propagate extents properly # TODO: resampling strategy: augment resample or perform multiple stages if necessary # TODO: resampling strategy - antialiasing: can resample just be augmented? - r = Resample() + + a = Affine(norm_coords=False, + affine=cumulative_matrix.matrix.matrix, + **kwargs) + data = a(img=data) data.clear_pending_transforms() diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index 2da1d7a6f6..35ceb6e221 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -9,7 +9,7 @@ from monai.transforms.atmostonce import array as amoa from monai.transforms.atmostonce.array import Rotate, CropPad -from monai.transforms.atmostonce.lazy_transform import compile_transforms +from monai.transforms.atmostonce.lazy_transform import compile_lazy_transforms from monai.utils import TransformBackends from monai.transforms import Affined, Affine @@ -243,6 +243,9 @@ def test_croppad_img_even_crop_even(self): [115., 116., 117., 118., 119., 120.]]) self._croppad_impl((16, 16), (slice(4, 8), slice(3, 9)), expected) + +class TestArrayTransforms(unittest.TestCase): + # TODO: amo: add tests for matrix and result size def test_croppad(self): img = get_img((15, 15)).astype(int) @@ -274,9 +277,6 @@ def test_apply(self): img_rca = apply(img_rc) - -class TestArrayTransforms(unittest.TestCase): - def test_rand_rotate(self): r = amoa.RandRotate((-torch.pi / 4, torch.pi / 4), prob=0.0, @@ -289,8 +289,19 @@ def test_rand_rotate(self): enumerate_results_of_op(results) enumerate_results_of_op(results.pending_transforms[-1].metadata) + def test_rotate_apply(self): + r = amoa.Rotate(-torch.pi / 4, + mode="bilinear", + padding_mode="border", + keep_size=False) + data = get_img((32, 32)) + data = r(data) + data = apply(data) + print(data) + + -class TestRotateEulerd(unittest.TestCase): +class TestDictionaryTransforms(unittest.TestCase): def test_rotate_numpy(self): r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) From d0b490bbdd0da3a4b774ca17654392d37a10ce37 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Tue, 6 Sep 2022 10:56:55 +0100 Subject: [PATCH 20/30] bug fixes --- monai/data/meta_tensor.py | 3 + monai/transforms/atmostonce/apply.py | 93 +++++++++++++++-------- monai/transforms/atmostonce/array.py | 66 ++++++++++++++-- monai/transforms/atmostonce/compose.py | 4 +- monai/transforms/atmostonce/functional.py | 47 ++++++++++-- monai/transforms/atmostonce/utils.py | 44 +++++++++++ monai/utils/mapping_stack.py | 12 ++- monai/utils/type_conversion.py | 31 ++++++++ tests/test_atmostonce.py | 36 ++++++++- 9 files changed, 284 insertions(+), 52 deletions(-) create mode 100644 
monai/transforms/atmostonce/utils.py diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index bde42ff279..deda678ef9 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -158,6 +158,9 @@ def __init__( def push_pending_transform(self, meta_matrix): self._pending_transforms.append(meta_matrix) + def has_pending_transforms(self): + return len(self._pending_transforms) + def peek_pending_transform(self): return copy.deepcopy(self._pending_transforms[-1]) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index a857e5952b..e6e954a5a4 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -11,22 +11,25 @@ from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform +from monai.transforms.atmostonce.utils import matmul from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils.misc import get_backend_from_data, get_device_from_data -from monai.utils.mapping_stack import MatrixFactory +from monai.utils.mapping_stack import MatrixFactory, MetaMatrix, Matrix # TODO: This should move to a common place to be shared with dictionary +from monai.utils.type_conversion import dtypes_to_str_or_identity + GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] # TODO: move to mapping_stack.py -def extents_from_shape(shape): +def extents_from_shape(shape, dtype=np.float64): extents = [[0, shape[i]] for i in range(1, len(shape))] extents = it.product(*extents) - return list(np.asarray(e + (1,)) for e in extents) + return list(np.asarray(e + (1,), dtype=dtype) for e in extents) # TODO: move to mapping_stack.py @@ -48,7 +51,8 @@ def shape_from_extents( mins = aextents.min(axis=0)[0] maxes = aextents.max(axis=0)[0] - values = torch.ceil(maxes - mins).type(torch.IntTensor)[:-1] + values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] + #values = torch.ceil(maxes - mins).type(torch.IntTensor)[:-1] return torch.cat((torch.as_tensor([src_shape[0]]), values)) # return [src_shape[0]] + np.ceil(maxes - mins)[:-1].astype(int).tolist() @@ -62,6 +66,19 @@ def metadata_is_compatible(value_1, value_2): return True return value_1 == value_2 +def metadata_dtype_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + + # if we are here, value_1 and value_2 are both set + # TODO: this is not a good enough solution + value_1_ = dtypes_to_str_or_identity(value_1) + value_2_ = dtypes_to_str_or_identity(value_2) + return value_1_ == value_2_ + def starting_matrix_and_extents(matrix_factory, data): # set up the identity matrix and metadata @@ -69,9 +86,9 @@ def starting_matrix_and_extents(matrix_factory, data): cumulative_extents = extents_from_shape(data.shape) # pre-translate origin to centre of image - translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]]) - cumulative_matrix = translate_to_centre @ cumulative_matrix - cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] + # translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]]) + # cumulative_matrix = translate_to_centre @ cumulative_matrix + # cumulative_extents = [matmul(e, translate_to_centre.matrix.matrix) for e in 
cumulative_extents] return cumulative_matrix, cumulative_extents @@ -89,12 +106,21 @@ def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtyp return kwargs +def matrix_from_matrix_container(matrix): + if isinstance(matrix, MetaMatrix): + return matrix.matrix.matrix + elif isinstance(matrix, Matrix): + return matrix.matrix + else: + return matrix + + def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[dict] = None): pending_ = pending pending_ = data.pending_transforms - if len(pending) == 0: + if len(pending_) == 0: return data dim_count = len(data.shape) - 1 @@ -117,60 +143,67 @@ def apply(data: Union[torch.Tensor, MetaTensor], cur_padding_mode = None cur_device = None cur_dtype = None + cur_shape = data.shape for meta_matrix in pending_: next_matrix = meta_matrix.matrix - cumulative_matrix = next_matrix @ cumulative_matrix + print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) + # cumulative_matrix = matmul(next_matrix, cumulative_matrix) + cumulative_matrix = matmul(cumulative_matrix, next_matrix) # cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] - cumulative_extents = [e @ cumulative_matrix.matrix.matrix for e in cumulative_extents] + cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] new_mode = meta_matrix.metadata.get('mode', None) new_padding_mode = meta_matrix.metadata.get('padding_mode', None) new_device = meta_matrix.metadata.get('device', None) new_dtype = meta_matrix.metadata.get('dtype', None) + new_shape = meta_matrix.metadata.get('shape_override', None) mode_compat = metadata_is_compatible(cur_mode, new_mode) padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) device_compat = metadata_is_compatible(cur_device, new_device) - dtype_compat = metadata_is_compatible(cur_dtype, new_dtype) + dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) if (mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False): print("intermediate apply required") # carry out an intermediate resample here due to incompatibility between arguments - # kwargs = {} - # if cur_mode is not None: - # kwargs['mode'] = cur_mode - # if cur_padding_mode is not None: - # kwargs['padding_mode'] = cur_padding_mode - # if cur_device is not None: - # kwargs['device'] = cur_device - # if cur_dtype is not None: - # kwargs['dtype'] = cur_dtype kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + print(f"intermediate applying with cumulative matrix\n {cumulative_matrix_}") a = Affine(norm_coords=False, - affine=cumulative_matrix.matrix.matrix, + affine=cumulative_matrix_, **kwargs) - data = a(img=data) + data, _ = a(img=data) - cur_mode = new_mode - cur_padding_mode = new_padding_mode - cur_device = new_device - cur_dtype = new_dtype + cumulative_matrix, cumulative_extents =\ + starting_matrix_and_extents(matrix_factory, data) + cur_mode = cur_mode if new_mode is None else new_mode + cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode + cur_device = cur_device if new_device is None else new_device + cur_dtype = cur_dtype if new_dtype is None else new_dtype + cur_shape = cur_shape if new_shape is None else new_shape # TODO: figure out how to propagate extents properly # TODO: resampling strategy: augment resample or perform multiple stages if necessary # TODO: 
resampling strategy - antialiasing: can resample just be augmented? + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + + print(f"applying with cumulative matrix\n {cumulative_matrix_}") a = Affine(norm_coords=False, - affine=cumulative_matrix.matrix.matrix, - # spatial_size=cur_spatial_size, + affine=cumulative_matrix_, + spatial_size=cur_shape[1:], + normalized=False, **kwargs) - data = a(img=data) - + data, tx = a(img=data) data.clear_pending_transforms() + return data + class Apply(InvertibleTransform): diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 91d4b4a226..2fe891a019 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -5,11 +5,12 @@ import torch from monai.config import DtypeLike, NdarrayOrTensor +from monai.data import MetaTensor from monai.transforms import InvertibleTransform, RandomizableTransform from monai.transforms.atmostonce.apply import apply -from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad +from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate from monai.transforms.atmostonce.lazy_transform import LazyTransform from monai.utils import (GridSampleMode, GridSamplePadMode, @@ -160,8 +161,12 @@ def __call__( keep_size = self.keep_size dtype = self.dtype + shape_override_ = shape_override + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) + img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, - align_corners, dtype, shape_override) + align_corners, dtype, shape_override_) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -186,10 +191,12 @@ def __init__( mode: Union[InterpolateMode, str] = InterpolateMode.AREA, padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, - keep_size: bool = True, + keep_size: Optional[bool] = True, dtype: Union[DtypeLike, torch.dtype] = np.float32, + lazy_evaluation: Optional[bool] = True, **kwargs ): + LazyTransform.__init__(self, lazy_evaluation) self.zoom = zoom self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode = padding_mode @@ -321,6 +328,52 @@ def inverse( ): raise NotImplementedError() + +class Translate(LazyTransform, InvertibleTransform): + def __init__( + self, + translation: Union[Sequence[float], float], + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, + dtype: Union[DtypeLike, torch.dtype] = np.float32, + lazy_evaluation: Optional[bool] = True, + **kwargs + ): + LazyTransform.__init__(self, lazy_evaluation) + self.translation = translation + self.mode: InterpolateMode = InterpolateMode(mode) + self.padding_mode = padding_mode + self.dtype = dtype + self.kwargs = kwargs + + def __call__( + self, + img: NdarrayOrTensor, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + ) -> NdarrayOrTensor: + mode = self.mode or mode + padding_mode = self.padding_mode or padding_mode + dtype = self.dtype + + shape_override_ = None + if shape_override_ is None and isinstance(img, MetaTensor) and 
img.has_pending_transforms(): + shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) + + img_t, transform, metadata = translate(img, self.translation, + mode, padding_mode, dtype, shape_override_) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse(self, data): + raise NotImplementedError() + + # croppad # ======= @@ -328,7 +381,7 @@ class CropPad(LazyTransform, InvertibleTransform): def __init__( self, - slices: Sequence[slice], + slices: Optional[Sequence[slice]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, lazy_evaluation: Optional[bool] = True, ): @@ -339,10 +392,11 @@ def __init__( def __call__( self, img: NdarrayOrTensor, + slices: Optional[Sequence[slice]] = None, shape_override: Optional[Sequence] = None ): - - img_t, transform, metadata = croppad(img, self.slices, self.padding_mode, shape_override) + slices_ = slices if self.slices is None else self.slices + img_t, transform, metadata = croppad(img, slices_, self.padding_mode, shape_override) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py index 9b1851a59f..0ec0367f8a 100644 --- a/monai/transforms/atmostonce/compose.py +++ b/monai/transforms/atmostonce/compose.py @@ -4,7 +4,7 @@ import numpy as np -from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_transforms, flatten_sequences +from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_lazy_transforms, flatten_sequences from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED from monai.transforms import Randomizable, InvertibleTransform, OneOf, apply_transform @@ -185,7 +185,7 @@ def __init__( self.transforms = ensure_tuple(transforms) if lazy_evaluation is True: - self.dst_transforms = compile_transforms(self.transforms) + self.dst_transforms = compile_lazy_transforms(self.transforms) else: self.dst_transforms = flatten_sequences(self.transforms) diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index bf74e50d20..f2b0bea5d9 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -1,8 +1,10 @@ from typing import Optional, Sequence, Tuple, Union +import numpy as np import torch -from monai.transforms import create_rotate, map_spatial_axes + +from monai.transforms import create_rotate, create_translate, map_spatial_axes from monai.data import get_track_meta from monai.transforms.atmostonce.apply import extents_from_shape, shape_from_extents @@ -222,14 +224,17 @@ def rotate( raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) - transform = create_rotate(input_ndim, angle_) + to_center_tx = create_translate(input_ndim, [d / 2 for d in input_shape[1:]]) + rotate_tx = create_rotate(input_ndim, angle_) im_extents = extents_from_shape(input_shape) if not keep_size: - im_extents = [transform @ e for e in im_extents] + im_extents = [rotate_tx @ e for e in im_extents] spatial_shape = shape_from_extents(input_shape, im_extents) else: spatial_shape = input_shape - + from_center_tx = create_translate(input_ndim, 
[-d / 2 for d in input_shape[1:]]) + # transform = from_center_tx @ rotate_tx @ to_center_tx + transform = rotate_tx metadata = { "angle": angle_, "keep_size": keep_size, @@ -319,10 +324,40 @@ def zoom( # } +def translate( + img: torch.Tensor, + translation: Sequence[float], + mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, + dtype: Union[DtypeLike, torch.dtype] = np.float32, + shape_override: Optional[Sequence] = None +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + if len(translation) != input_ndim: + raise ValueError(f"'translate' length {len(translation)} must be equal to 'img' " + f"spatial dimensions of {input_ndim}") + + transform = MatrixFactory.from_tensor(img).translate(translation).matrix.matrix + im_extents = extents_from_shape(input_shape) + im_extents = [transform @ e for e in im_extents] + # shape_override_ = shape_from_extents(input_shape, im_extents) + + metadata = { + "translation": translation, + "padding_mode": padding_mode, + "dtype": img.dtype, + "im_extents": im_extents, + # "shape_override": shape_override_ + } + return img_, transform, metadata + + def croppad( img: torch.Tensor, slices: Union[Sequence[slice], slice], - pad_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, + padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, shape_override: Optional[Sequence] = None ): img_ = convert_to_tensor(img, track_meta=get_track_meta()) @@ -342,7 +377,7 @@ def croppad( metadata = { "slices": slices, - "pad_mode": pad_mode, + "padding_mode": padding_mode, "dtype": img.dtype, "im_extents": im_extents, "shape_override": shape_override_ diff --git a/monai/transforms/atmostonce/utils.py b/monai/transforms/atmostonce/utils.py new file mode 100644 index 0000000000..09b122460b --- /dev/null +++ b/monai/transforms/atmostonce/utils.py @@ -0,0 +1,44 @@ +from typing import Union + +import numpy as np + +import torch + +from monai.config import NdarrayOrTensor +from monai.utils.mapping_stack import Matrix, MetaMatrix + + +def matmul( + first: Union[MetaMatrix, Matrix, NdarrayOrTensor], + second: Union[MetaMatrix, Matrix, NdarrayOrTensor] +): + matrix_types = (MetaMatrix, Matrix, torch.Tensor, np.ndarray) + + if not isinstance(first, matrix_types): + raise TypeError(f"'first' must be one of {matrix_types} but is {type(first)}") + if not isinstance(second, matrix_types): + raise TypeError(f"'second' must be one of {matrix_types} but is {type(second)}") + + first_ = first + if isinstance(first_, MetaMatrix): + first_ = first_.matrix.matrix + elif isinstance(first_, Matrix): + first_ = first_.matrix + + second_ = second + if isinstance(second_, MetaMatrix): + second_ = second_.matrix.matrix + elif isinstance(second_, Matrix): + second_ = second_.matrix + + + if isinstance(first_, np.ndarray): + if isinstance(second_, np.ndarray): + return first_ @ second_ + else: + return torch.from_numpy(first_) @ second_ + else: + if isinstance(second_, np.ndarray): + return first_ @ torch.from_numpy(second_) + else: + return first_ @ second_ diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py index 2759b35b86..b3639bb838 100644 --- a/monai/utils/mapping_stack.py +++ b/monai/utils/mapping_stack.py @@ -49,13 +49,17 @@ def __init__(self, raise ValueError("'device' must be set with TransformBackends.TORCH") self._device = 
device self._sin = lambda th: torch.sin(torch.as_tensor(th, - dtype=torch.float32, + dtype=torch.float64, device=self._device)) self._cos = lambda th: torch.cos(torch.as_tensor(th, - dtype=torch.float32, + dtype=torch.float64, device=self._device)) - self._eye = lambda rank: torch.eye(rank, device=self._device); - self._diag = lambda size: torch.diag(torch.as_tensor(size, device=self._device)) + self._eye = lambda rank: torch.eye(rank, + device=self._device, + dtype=torch.float64); + self._diag = lambda size: torch.diag(torch.as_tensor(size, + device=self._device, + dtype=torch.float64)) self._backend = backend self._dims = dims diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 97eca5a7a6..7e12639b76 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -37,6 +37,37 @@ "expand_scalar_to_tuple" ] +__dtype_dict = { + np.int8: 'int8', + torch.int8: 'int8', + np.int16: 'int16', + torch.int16: 'int16', + int: 'int32', + np.int32: 'int32', + torch.int32: 'int32', + np.int64: 'int64', + torch.int64: 'int64', + np.uint8: 'uint8', + torch.uint8: 'uint8', + np.uint16: 'uint16', + np.uint32: 'uint32', + np.uint64: 'uint64', + float: 'float32', + np.float16: 'float16', + np.float: 'float32', + np.float32: 'float32', + np.float64: 'float64', + torch.float16: 'float16', + torch.float: 'float32', + torch.float32: 'float32', + torch.double: 'float64', + torch.float64: 'float64' +} + +def dtypes_to_str_or_identity(dtype: Any) -> Any: + return __dtype_dict.get(dtype, dtype) + + def get_numpy_dtype_from_string(dtype: str) -> np.dtype: """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`).""" diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index 65280eec17..b1ea560bfb 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -21,8 +21,8 @@ from monai.utils.mapping_stack import MatrixFactory -def get_img(size, offset = 0): - img = torch.zeros(size, dtype=torch.float32) +def get_img(size, dtype=torch.float32, offset=0): + img = torch.zeros(size, dtype=dtype) if len(size) == 2: for j in range(size[0]): for i in range(size[1]): @@ -342,17 +342,45 @@ def test_rand_rotate(self): enumerate_results_of_op(results) enumerate_results_of_op(results.pending_transforms[-1].metadata) - def test_rotate_apply(self): + def test_rotate_apply_not_lazy(self): r = amoa.Rotate(-torch.pi / 4, mode="bilinear", padding_mode="border", keep_size=False) data = get_img((32, 32)) data = r(data) - data = apply(data) + # data = apply(data) + print(data.shape) print(data) + def test_rotate_apply_lazy(self): + r = amoa.Rotate(-torch.pi / 4, + mode="bilinear", + padding_mode="border", + keep_size=False) + r.lazy_evaluation = True + data = get_img((32, 32)) + data = r(data) + data = apply(data) + print(data.shape) + print(data) + def test_crop_then_rotate_apply_lazy(self): + data = get_img((32, 32)) + print(data.shape) + + lc1 = amoa.CropPad(lazy_evaluation=True, + padding_mode="zeros") + lr1 = amoa.Rotate(torch.pi / 4, + keep_size=False, + padding_mode="zeros", + lazy_evaluation=False) + datas = [] + datas.append(data) + data1 = lc1(data, slices=(slice(0, 16), slice(0, 16))) + datas.append(data1) + data2 = lr1(data1) + datas.append(data2) class TestDictionaryTransforms(unittest.TestCase): From 97216af9103d0104c73d3c18b9199654068dda8e Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 7 Sep 2022 09:47:12 +0100 Subject: [PATCH 21/30] Removing dead code from apply --- monai/transforms/atmostonce/apply.py | 18 ++---------------- 1 
file changed, 2 insertions(+), 16 deletions(-)

diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py
index e6e954a5a4..aa1993f155 100644
--- a/monai/transforms/atmostonce/apply.py
+++ b/monai/transforms/atmostonce/apply.py
@@ -52,11 +52,8 @@ def shape_from_extents(
     mins = aextents.min(axis=0)[0]
     maxes = aextents.max(axis=0)[0]
     values = torch.round(maxes - mins).type(torch.IntTensor)[:-1]
-    #values = torch.ceil(maxes - mins).type(torch.IntTensor)[:-1]
     return torch.cat((torch.as_tensor([src_shape[0]]), values))
 
-    # return [src_shape[0]] + np.ceil(maxes - mins)[:-1].astype(int).tolist()
-
 
 def metadata_is_compatible(value_1, value_2):
     if value_1 is None:
@@ -66,6 +63,7 @@ def metadata_is_compatible(value_1, value_2):
         return True
     return value_1 == value_2
 
+
 def metadata_dtype_is_compatible(value_1, value_2):
     if value_1 is None:
         return True
@@ -84,11 +82,6 @@ def starting_matrix_and_extents(matrix_factory, data):
     # set up the identity matrix and metadata
     cumulative_matrix = matrix_factory.identity()
     cumulative_extents = extents_from_shape(data.shape)
-
-    # pre-translate origin to centre of image
-    # translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]])
-    # cumulative_matrix = translate_to_centre @ cumulative_matrix
-    # cumulative_extents = [matmul(e, translate_to_centre.matrix.matrix) for e in cumulative_extents]
     return cumulative_matrix, cumulative_extents
 
 
@@ -128,14 +121,7 @@ def apply(data: Union[torch.Tensor, MetaTensor],
                                    get_backend_from_data(data),
                                    get_device_from_data(data))
 
-    # # set up the identity matrix and metadata
-    # cumulative_matrix = matrix_factory.identity()
-    # cumulative_extents = extents_from_shape(data.shape)
-    #
-    # # pre-translate origin to centre of image
-    # translate_to_centre = matrix_factory.translate([d / 2 for d in data.shape[1:]])
-    # cumulative_matrix = translate_to_centre @ cumulative_matrix
-    # cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents]
+    # set up the identity matrix and metadata
     cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
 
     # set the various resampling parameters to an initial state

From bb3a60e83bb96ba680b03a011d2ea1be068bd94b Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Wed, 7 Sep 2022 09:49:10 +0100
Subject: [PATCH 22/30] Additional work on array

---
 monai/transforms/atmostonce/array.py | 207 ++++++++++++++++++------
 monai/transforms/atmostonce/utils.py |  13 +-
 monai/transforms/utils.py            |  74 ++++++++++
 monai/utils/mapping_stack.py         |   4 +
 tests/test_atmostonce.py             |  14 ++
 tests/test_create_grid_and_affine.py |  16 +++
 6 files changed, 272 insertions(+), 56 deletions(-)

diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
index 2fe891a019..11894b0333 100644
--- a/monai/transforms/atmostonce/array.py
+++ b/monai/transforms/atmostonce/array.py
@@ -12,6 +12,7 @@
 from monai.transforms.atmostonce.apply import apply
 from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate
 from monai.transforms.atmostonce.lazy_transform import LazyTransform
+from monai.transforms.atmostonce.utils import value_to_tuple_range
 from monai.utils import (GridSampleMode, GridSamplePadMode,
                          InterpolateMode, NumpyPadMode, PytorchPadMode)
 
@@ -63,9 +64,13 @@ def __call__(
         align_corners_ = align_corners or self.align_corners
         dtype_ = dtype or self.dtype
 
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms():
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
         img_t, transform, metadata = spacing(img, self.pixdim, self.src_pixdim, self.diagonal,
                                              mode_, padding_mode_, align_corners_, dtype_,
-                                             shape_override)
+                                             shape_override_)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -114,9 +119,13 @@ def __call__(
         anti_aliasing_ = anti_aliasing or self.anti_aliasing
         anti_aliasing_sigma_ = anti_aliasing_sigma or self.anti_aliasing_sigma
 
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms():
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
         img_t, transform, metadata = resize(img, self.spatial_size, self.size_mode, mode_,
                                             align_corners_, anti_aliasing_, anti_aliasing_sigma_,
-                                            self.dtype, shape_override)
+                                            self.dtype, shape_override_)
 
         # TODO: candidate for refactoring into a LazyTransform method
         img_t.push_pending_transform(MetaMatrix(transform, metadata))
@@ -149,6 +158,7 @@ def __init__(
     def __call__(
         self,
         img: NdarrayOrTensor,
+        angle: Optional[Union[Sequence[float], float]] = None,
         mode: Optional[Union[InterpolateMode, str]] = None,
         padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = 
None, align_corners: Optional[bool] = None, @@ -208,6 +218,7 @@ def __init__( def __call__( self, img: NdarrayOrTensor, + zoom: Optional[Union[Sequence[float], float]] = None, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, @@ -220,8 +231,12 @@ def __call__( keep_size = self.keep_size dtype = self.dtype + shape_override_ = shape_override + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) + img_t, transform, metadata = zoom(img, self.zoom, mode, padding_mode, align_corners, - keep_size, dtype, shape_override) + keep_size, dtype, shape_override_) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -234,13 +249,22 @@ def inverse(self, data): raise NotImplementedError() -# class Rotate90(InvertibleTransform, LazyTransform): -# -# def __init__( -# self, -# -# ): -# pass +class Rotate90(InvertibleTransform, LazyTransform): + + def __init__( + self, + k: Optional[int] = 1, + spatial_axes: Optional[Tuple[int, int]] = (0, 1) + ) -> None: + self.k = k + self.spatial_axes = spatial_axes + + def __call__( + self, + img: torch.Tensor + ) -> torch.Tensor: + + class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): @@ -250,40 +274,31 @@ def __init__( range_y: Optional[Union[Tuple[float, float], float]] = 0.0, range_z: Optional[Union[Tuple[float, float], float]] = 0.0, prob: Optional[float] = 0.1, - keep_size: bool = True, - mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, - padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = np.float32 + keep_size: Optional[bool] = True, + mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + align_corners: Optional[bool] = False, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, + lazy_evaluation: Optional[bool] = True ): RandomizableTransform.__init__(self, prob) - self.range_x = ensure_tuple(range_x) - if len(self.range_x) == 1: - self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) - self.range_y = ensure_tuple(range_y) - if len(self.range_y) == 1: - self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) - self.range_z = ensure_tuple(range_z) - if len(self.range_z) == 1: - self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) + self.range_x = value_to_tuple_range(range_x) + self.range_y = value_to_tuple_range(range_y) + self.range_z = value_to_tuple_range(range_z) - self.keep_size = keep_size - self.mode = mode - self.padding_mode = padding_mode - self.align_corners = align_corners - self.dtype = dtype + self.x, self.y, self.z = 0.0, 0.0, 0.0 - self.x = 0.0 - self.y = 0.0 - self.z = 0.0 + self.op = Rotate(0, keep_size, mode, padding_mode, align_corners, dtype, lazy_evaluation) def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) - if not self._do_transform: - return None - self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) - self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) - self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) + + self.x, self.y, self.z = 0.0, 0.0, 0.0 + + if self._do_transform: + self.x = 
self.R.uniform(low=self.range_x[0], high=self.range_x[1]) + self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) + self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) def __call__( self, @@ -293,7 +308,6 @@ def __call__( align_corners: Optional[bool] = None, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, randomize: Optional[bool] = True, - get_matrix: Optional[bool] = False, shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: @@ -306,21 +320,7 @@ def __call__( else: angle = 0 if img_dims == 2 else (0, 0, 0) - mode = self.mode or mode - padding_mode = self.padding_mode or padding_mode - align_corners = self.align_corners or align_corners - keep_size = self.keep_size - dtype = self.dtype - - img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, - align_corners, dtype, shape_override) - - # TODO: candidate for refactoring into a LazyTransform method - img_t.push_pending_transform(MetaMatrix(transform, metadata)) - if not self.lazy_evaluation: - img_t = apply(img_t) - - return img_t + return self.op(img, angle, mode, padding_mode, align_corners, shape_override) def inverse( self, @@ -329,6 +329,97 @@ def inverse( raise NotImplementedError() +# class RandRotateOld(RandomizableTransform, InvertibleTransform, LazyTransform): +# +# def __init__( +# self, +# range_x: Optional[Union[Tuple[float, float], float]] = 0.0, +# range_y: Optional[Union[Tuple[float, float], float]] = 0.0, +# range_z: Optional[Union[Tuple[float, float], float]] = 0.0, +# prob: Optional[float] = 0.1, +# keep_size: bool = True, +# mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, +# padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, +# align_corners: bool = False, +# dtype: Union[DtypeLike, torch.dtype] = np.float32 +# ): +# RandomizableTransform.__init__(self, prob) +# self.range_x = ensure_tuple(range_x) +# if len(self.range_x) == 1: +# self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) +# self.range_y = ensure_tuple(range_y) +# if len(self.range_y) == 1: +# self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) +# self.range_z = ensure_tuple(range_z) +# if len(self.range_z) == 1: +# self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) +# +# self.keep_size = keep_size +# self.mode = mode +# self.padding_mode = padding_mode +# self.align_corners = align_corners +# self.dtype = dtype +# +# self.x = 0.0 +# self.y = 0.0 +# self.z = 0.0 +# +# def randomize(self, data: Optional[Any] = None) -> None: +# super().randomize(None) +# if not self._do_transform: +# return None +# self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) +# self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) +# self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) +# +# def __call__( +# self, +# img: NdarrayOrTensor, +# mode: Optional[Union[InterpolateMode, str]] = None, +# padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, +# align_corners: Optional[bool] = None, +# dtype: Optional[Union[DtypeLike, torch.dtype]] = None, +# randomize: Optional[bool] = True, +# get_matrix: Optional[bool] = False, +# shape_override: Optional[Sequence] = None +# ) -> NdarrayOrTensor: +# +# if randomize: +# self.randomize() +# +# img_dims = len(img.shape) - 1 +# if self._do_transform: +# angle = self.x if img_dims == 2 else (self.x, self.y, self.z) +# else: +# angle = 0 if img_dims == 2 else (0, 0, 0) +# +# mode = self.mode or mode +# padding_mode = self.padding_mode or 
padding_mode +# align_corners = self.align_corners or align_corners +# keep_size = self.keep_size +# dtype = self.dtype +# +# shape_override_ = shape_override +# if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): +# shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) +# +# img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, +# align_corners, dtype, shape_override_) +# +# # TODO: candidate for refactoring into a LazyTransform method +# img_t.push_pending_transform(MetaMatrix(transform, metadata)) +# if not self.lazy_evaluation: +# img_t = apply(img_t) +# +# return img_t +# +# def inverse( +# self, +# data: NdarrayOrTensor, +# ): +# raise NotImplementedError() + + class Translate(LazyTransform, InvertibleTransform): def __init__( self, @@ -351,12 +442,13 @@ def __call__( img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, + shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: mode = self.mode or mode padding_mode = self.padding_mode or padding_mode dtype = self.dtype - shape_override_ = None + shape_override_ = shape_override if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) @@ -396,7 +488,12 @@ def __call__( shape_override: Optional[Sequence] = None ): slices_ = slices if self.slices is None else self.slices - img_t, transform, metadata = croppad(img, slices_, self.padding_mode, shape_override) + + shape_override_ = shape_override + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) + + img_t, transform, metadata = croppad(img, slices_, self.padding_mode, shape_override_) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) diff --git a/monai/transforms/atmostonce/utils.py b/monai/transforms/atmostonce/utils.py index 09b122460b..ac8a1f1f68 100644 --- a/monai/transforms/atmostonce/utils.py +++ b/monai/transforms/atmostonce/utils.py @@ -31,7 +31,6 @@ def matmul( elif isinstance(second_, Matrix): second_ = second_.matrix - if isinstance(first_, np.ndarray): if isinstance(second_, np.ndarray): return first_ @ second_ @@ -42,3 +41,15 @@ def matmul( return first_ @ torch.from_numpy(second_) else: return first_ @ second_ + + +def value_to_tuple_range(value): + if isinstance(value, (tuple, list)): + if len(value) == 2: + return (value[0], value[1]) if value[0] <= value[1] else (value[1], value[0]) + elif len(value) == 1: + return -value[0], value[0] + else: + raise ValueError(f"parameter 'value' must be of length 1 or 2 but is {len(value)}") + else: + return -value, value diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index ae550e7ce6..52a620312c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -758,6 +758,80 @@ def _create_rotate( raise ValueError(f"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].") +def create_rotate_90( + spatial_dims: int, + axis: int, + steps: Optional[int] = 1, + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + """ + create a 2D or 3D rotation matrix + + Args: + spatial_dims: {``2``, ``3``} spatial rank + radians: rotation radians 
+ when spatial_dims == 3, the `radians` sequence corresponds to + rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + Raises: + ValueError: When ``radians`` is empty. + ValueError: When ``spatial_dims`` is not one of [2, 3]. + + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + +def _create_rotate_90( + spatial_dims: int, + axis: int, + steps: Optional[int] = 1, + eye_func: Callable = np.eye +) -> NdarrayOrTensor: + + values = [(1, 0, 0, 1), + (0, -1, 1, 0), + (-1, 0, 0, -1), + (0, 1, -1, 0)] + + if spatial_dims == 2: + if axis != 0: + raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be 0 but is {axis}") + elif spatial_dims == 3: + if axis < 0 or axis > 2: + raise ValueError("if 'spatial_dims' is 3, 'axis' must be between 0 and 2 inclusive ", + f"but is {axis}") + else: + raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}") + + steps_ = steps % 4 + + affine = eye_func(spatial_dims + 1) + + if spatial_dims == 2: + a, b = 0, 1 + else: + a, b = 0 if axis > 0 else 1, 2 if axis < 2 else 1 + + affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps] + return affine + + def create_shear( spatial_dims: int, coefs: Union[Sequence[float], float], diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py index b3639bb838..6bf2708915 100644 --- a/monai/utils/mapping_stack.py +++ b/monai/utils/mapping_stack.py @@ -78,6 +78,10 @@ def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args): matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) return MetaMatrix(matrix, extra_args) + def rotate_90(self, rotations, axis, **extra_args): + matrix = _create_rotate_90(self._dims, rotations, axis) + return MetaMatrix(matrix, extra_args) + def shear(self, coefs: Union[Sequence[float], float], **extra_args): matrix = _create_shear(self._dims, coefs, self._eye) return MetaMatrix(matrix, extra_args) diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index b1ea560bfb..661b17073b 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -10,6 +10,7 @@ from monai.transforms.atmostonce import array as amoa from monai.transforms.atmostonce.array import Rotate, CropPad from monai.transforms.atmostonce.lazy_transform import compile_lazy_transforms +from monai.transforms.atmostonce.utils import value_to_tuple_range from monai.utils import TransformBackends from monai.transforms import Affined, Affine @@ -448,3 +449,16 @@ def test_old_affine(self): } d = c(d) print(d['image'].shape) + + +class TestUtils(unittest.TestCase): + + def test_value_to_tuple_range(self): + self.assertTupleEqual(value_to_tuple_range(5), (-5, 5)) + self.assertTupleEqual(value_to_tuple_range([5]), (-5, 5)) + self.assertTupleEqual(value_to_tuple_range((5,)), (-5, 5)) + self.assertTupleEqual(value_to_tuple_range([-2.1, 4.3]), (-2.1, 4.3)) + self.assertTupleEqual(value_to_tuple_range((-2.1, 4.3)), (-2.1, 4.3)) + self.assertTupleEqual(value_to_tuple_range([4.3, -2.1]), (-2.1, 4.3)) + 
self.assertTupleEqual(value_to_tuple_range((4.3, -2.1)), (-2.1, 4.3))
+
diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py
index d70db45468..87901494cc 100644
--- a/tests/test_create_grid_and_affine.py
+++ b/tests/test_create_grid_and_affine.py
@@ -22,6 +22,7 @@
     create_shear,
     create_translate,
 )
+from monai.transforms.utils import create_rotate_90
 from tests.utils import assert_allclose, is_tf32_env
 
 
@@ -219,6 +220,21 @@ def test_create_rotate(self):
             (3, (0, 0, np.pi / 2)),
             np.array([[0.0, -1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
         )
+    def test_create_rotate_90(self):
+        expected = np.eye(3)
+        test_assert(create_rotate_90, (2, 0, 0), expected)
+
+        expected = np.eye(3)
+        expected[0:2, 0:2] = [[0, -1], [1, 0]]
+        test_assert(create_rotate_90, (2, 0, 1), expected)
+
+        expected = np.eye(3)
+        expected[0:2, 0:2] = [[-1, 0], [0, -1]]
+        test_assert(create_rotate_90, (2, 0, 2), expected)
+
+        expected = np.eye(3)
+        expected[0:2, 0:2] = [[0, 1], [-1, 0]]
+        test_assert(create_rotate_90, (2, 0, 3), expected)
 
     def test_create_shear(self):
         test_assert(create_shear, (2, 1.0), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))

From f53a56e15f625d5e8c723756be6bdea2beddddfe Mon Sep 17 00:00:00 2001
From: Ben Murray
Date: Fri, 9 Sep 2022 10:18:50 +0100
Subject: [PATCH 23/30] More lazy transforms

---
 monai/transforms/atmostonce/array.py      | 115 +++++++++++++++-
 monai/transforms/atmostonce/functional.py |  66 +++++++----
 monai/transforms/utils.py                 |  59 +++++++--
 monai/utils/mapping_stack.py              |   6 +-
 tests/test_atmostonce.py                  |  55 ++++++++-
 tests/test_create_grid_and_affine.py      |   9 +-
 6 files changed, 271 insertions(+), 39 deletions(-)

diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
index 11894b0333..9b164c6f6c 100644
--- a/monai/transforms/atmostonce/array.py
+++ b/monai/transforms/atmostonce/array.py
@@ -10,7 +10,7 @@
 
 from monai.transforms import InvertibleTransform, RandomizableTransform
 from monai.transforms.atmostonce.apply import apply
-from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate
+from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate, rotate90, flip
 from monai.transforms.atmostonce.lazy_transform import LazyTransform
 from monai.transforms.atmostonce.utils import value_to_tuple_range
 
@@ -83,6 +83,38 @@ def inverse(self, data):
         raise NotImplementedError()
 
 
+class Flip(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        spatial_axis: Optional[Union[Sequence[int], int]] = None,
+        lazy_evaluation: Optional[bool] = True
+    ) -> None:
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.spatial_axis = spatial_axis
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        spatial_axis: Optional[Union[Sequence[int], int]] = None,
+        shape_override: Optional[Sequence] = None
+    ):
+        spatial_axis_ = self.spatial_axis if spatial_axis is None else spatial_axis
+        shape_override_ = shape_override
+        if (shape_override_ is None and
+                isinstance(img, MetaTensor) and img.has_pending_transforms()):
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = flip(img, spatial_axis_, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+
 class Resize(LazyTransform, InvertibleTransform):
 
     def 
__init__( @@ -254,16 +286,80 @@ class Rotate90(InvertibleTransform, LazyTransform): def __init__( self, k: Optional[int] = 1, - spatial_axes: Optional[Tuple[int, int]] = (0, 1) + spatial_axes: Optional[Tuple[int, int]] = (0, 1), + lazy_evaluation: Optional[bool] = True, ) -> None: + LazyTransform.__init__(self, lazy_evaluation) self.k = k self.spatial_axes = spatial_axes def __call__( self, - img: torch.Tensor + img: torch.Tensor, + k: Optional[int] = None, + spatial_axes: Optional[Tuple[int, int]] = None, + shape_override: Optional[Sequence[int]] = None + ) -> torch.Tensor: + k_ = k or self.k + spatial_axes_ = spatial_axes or self.spatial_axes + + shape_override_ = shape_override + if (shape_override_ is None and + isinstance(img, MetaTensor) and img.has_pending_transforms()): + shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) + + img_t, transform, metadata = rotate90(img, k_, spatial_axes_, shape_override_) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + +class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): + + def __init__( + self, + prob: float = 0.1, + max_k: int = 3, + spatial_axes: Tuple[int, int] = (0, 1), + lazy_evaluation: Optional[bool] = True + ) -> None: + RandomizableTransform.__init__(self, prob) + self.max_k = max_k + self.spatial_axes = spatial_axes + + self.k = 0 + + self.op = Rotate90(0, spatial_axes, lazy_evaluation) + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + if not self._do_transform: + return None + self.k = self.R.randint(self.max_k) + 1 + + def __call__( + self, + img: torch.Tensor, + randomize: bool = True, + shape_override: Optional[Sequence] = None ) -> torch.Tensor: + if randomize: + self.randomize() + + k = self.k if self._do_transform else 0 + + return self.op(img, k, shape_override=shape_override) + + def inverse( + self, + data: NdarrayOrTensor, + ): + raise NotImplementedError() class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): @@ -329,6 +425,19 @@ def inverse( raise NotImplementedError() +class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): + + def __init__( + self, + prob: float = 0.1, + spatial_axis: Optional[Union[Sequence[int], int]] = None + ) -> None: + RandomizableTransform.__init__(self, prob) + self.spatial_axis = spatial_axis + + self.op = Flip(0, spatial_axis) + + # class RandRotateOld(RandomizableTransform, InvertibleTransform, LazyTransform): # # def __init__( diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index f2b0bea5d9..8bd9e0cf3d 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -93,9 +93,28 @@ def orientation( def flip( - img: torch.Tensor + img: torch.Tensor, + spatial_axis: Union[Sequence[int], int], + shape_override: Optional[Sequence] = None ): - pass + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + + spatial_axis_ = spatial_axis + if spatial_axis_ is None: + spatial_axis_ = tuple(i for i in range(len(input_shape[1:]))) + transform = MatrixFactory.from_tensor(img).flip(spatial_axis_).matrix.matrix + im_extents = extents_from_shape(input_shape) + im_extents = [transform @ e for e in im_extents] + + shape_override_ = 
shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "spatial_axes": spatial_axis,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
 
 
 def resize(
@@ -282,9 +301,9 @@ def zoom(
     mode_ = look_up_option(mode, GridSampleMode)
     padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
 
-    dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
+    dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor)
 
-    transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix
+    transform = MatrixFactory.from_tensor(img_).scale(zoom_factors).matrix.matrix
     im_extents = extents_from_shape(input_shape)
     if keep_size is False:
         im_extents = [transform @ e for e in im_extents]
@@ -305,23 +324,31 @@ def zoom(
     return img_, transform, metadata
 
 
-# def rotate90(
-#     img: torch.Tensor,
-#     k: Optional[int] = 1,
-#     spatial_axes: Optional[Tuple[int, int]] = (0, 1),
-# ):
-#     if len(spatial_axes) != 2:
-#         raise ValueError("'spatial_axes' must be a tuple of two integers indicating")
-#
-#     img = convert_to_tensor(img, track_meta=get_track_meta())
-#     axes = map_spatial_axes(img.ndim, spatial_axes)
-#     ori_shape = img.shape[1:]
-#
-#     metadata = {
-#         "k": k,
-#         "spatial_axes": spatial_axes,
-#         "shape_override": shape_override
-#     }
+def rotate90(
+    img: torch.Tensor,
+    k: Optional[int] = 1,
+    spatial_axes: Optional[Tuple[int, int]] = (0, 1),
+    shape_override: Optional[Sequence] = None
+):
+    if len(spatial_axes) != 2:
+        raise ValueError("'spatial_axes' must be a tuple of two integers indicating "
+                         "the plane in which to rotate")
+
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+
+    transform = MatrixFactory.from_tensor(img_).rotate_90(k, spatial_axes).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "k": k,
+        "spatial_axes": spatial_axes,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
 
 
 def translate(
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 52a620312c..93939a98c4 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -797,9 +797,10 @@ def create_rotate_90(
     )
     raise ValueError(f"backend {backend} is not supported")
 
+
 def _create_rotate_90(
     spatial_dims: int,
-    axis: int,
+    axis: Tuple[int, int],
     steps: Optional[int] = 1,
     eye_func: Callable = np.eye
 ) -> NdarrayOrTensor:
@@ -810,11 +811,11 @@ def _create_rotate_90(
               (0, 1, -1, 0)]
 
     if spatial_dims == 2:
-        if axis != 0:
-            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be 0 but is {axis}")
+        if axis != (0, 1):
+            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}")
     elif spatial_dims == 3:
-        if axis < 0 or axis > 2:
-            raise ValueError("if 'spatial_dims' is 3, 'axis' must be between 0 and 2 inclusive ",
+        if axis not in ((0, 1), (0, 2), (1, 2)):
+            raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0, 1), (0, 2), or (1, 2) "
                              f"but is {axis}")
     else:
         raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
@@ -826,12 +827,58 @@ def _create_rotate_90(
     if spatial_dims == 2:
         a, b = 0, 1
     else:
-        a, b = 0 if axis > 0 else 1, 2 if axis < 2 else 1
+        a, b = axis
 
-    affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps]
+    affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps_]
     return affine
 
 
+def create_flip(
+    spatial_dims: int,
+    spatial_axis: Union[Sequence[int], int],
+    device: Optional[torch.device] = None,
+    backend: str = TransformBackends.NUMPY,
+) -> NdarrayOrTensor:
+    _backend = look_up_option(backend, TransformBackends)
+    if 
_backend == TransformBackends.NUMPY: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + eye_func: Callable = np.eye +): + affine = eye_func(spatial_dims + 1) + if isinstance(spatial_axis, int): + if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims: + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + affine[spatial_axis, spatial_axis] = -1 + else: + if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis): + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + + for i in range(spatial_dims): + if i in spatial_axis: + affine[i, i] = -1 + + return affine + + def create_shear( spatial_dims: int, coefs: Union[Sequence[float], float], diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py index 6bf2708915..78f6bb3e8a 100644 --- a/monai/utils/mapping_stack.py +++ b/monai/utils/mapping_stack.py @@ -18,7 +18,7 @@ from monai.utils.enums import TransformBackends from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear, - _create_translate) + _create_translate, _create_rotate_90, _create_flip) from monai.utils.misc import get_backend_from_data, get_device_from_data @@ -82,6 +82,10 @@ def rotate_90(self, rotations, axis, **extra_args): matrix = _create_rotate_90(self._dims, rotations, axis) return MetaMatrix(matrix, extra_args) + def flip(self, axis, **extra_args): + matrix = _create_flip(self._dims, axis, self._eye) + return MetaMatrix(matrix, extra_args) + def shear(self, coefs: Union[Sequence[float], float], **extra_args): matrix = _create_shear(self._dims, coefs, self._eye) return MetaMatrix(matrix, extra_args) diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index 661b17073b..e0de31c96c 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -13,8 +13,8 @@ from monai.transforms.atmostonce.utils import value_to_tuple_range from monai.utils import TransformBackends -from monai.transforms import Affined, Affine -from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing +from monai.transforms import Affined, Affine, Flip +from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing, flip from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply from monai.transforms.atmostonce.dictionary import Rotated from monai.transforms.compose import Compose @@ -27,12 +27,12 @@ def get_img(size, dtype=torch.float32, offset=0): if len(size) == 2: for j in range(size[0]): for i in range(size[1]): - img[j, i] = i + j * size[1] + offset + img[j, i] = i + j * size[0] + offset else: for k in range(size[0]): for j in range(size[1]): for i in range(size[2]): - img[..., j, k] = j * size[0] + k * size[0] * size[1] + offset + img[k, j, i] = i + j * size[0] + k * size[0] * size[1] return np.expand_dims(img, 0) @@ -51,6 +51,12 @@ def enumerate_results_of_op(results): print(ir, v) +def matrices_nearly_equal(actual, expected): + if actual.shape != expected.shape: + 
raise ValueError("actual matrix does not match expected matrix size; " + f"{actual} vs {expected} respectively") + + class TestLowLevel(unittest.TestCase): def test_extents_2(self): @@ -297,6 +303,47 @@ def test_croppad_img_even_crop_even(self): [115., 116., 117., 118., 119., 120.]]) self._croppad_impl((16, 16), (slice(4, 8), slice(3, 9)), expected) + def _test_flip_impl(self, dims, spatial_axis, expected, verbose=False): + if dims == 2: + img = get_img((32, 32)) + else: + img = get_img((32, 32, 8)) + + actual = flip(img, spatial_axis=spatial_axis) + if verbose: + print("expected\n", expected) + print("actual\n", actual[1]) + self.assertTrue(np.allclose(expected, actual[1])) + + def test_flip(self): + + tests = [ + (2, None, {(0, 0): -1, (1, 1): -1}), + (2, 0, {(0, 0): -1}), + (2, 1, {(1, 1): -1}), + (2, (0,), {(0, 0): -1}), + (2, (1,), {(1, 1): -1}), + (2, (0, 1), {(0, 0): -1, (1, 1): -1}), + (3, None, {(0, 0): -1, (1, 1): -1, (2, 2): -1}), + (3, 0, {(0, 0): -1}), + (3, 1, {(1, 1): -1}), + (3, 2, {(2, 2): -1}), + (3, (0,), {(0, 0): -1}), + (3, (1,), {(1, 1): -1}), + (3, (2,), {(2, 2): -1}), + (3, (0, 1), {(0, 0): -1, (1, 1): -1}), + (3, (0, 2), {(0, 0): -1, (2, 2): -1}), + (3, (1, 2), {(1, 1): -1, (2, 2): -1}), + (3, (0, 1, 2), {(0, 0): -1, (1, 1): -1, (2, 2): -1}), + ] + + for t in tests: + with self.subTest(f"{t}"): + expected = np.eye(t[0] + 1) + for ke, kv in t[2].items(): + expected[ke] = kv + self._test_flip_impl(t[0], t[1], expected) + class TestArrayTransforms(unittest.TestCase): diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index bda9865658..87901494cc 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -220,21 +220,22 @@ def test_create_rotate(self): (3, (0, 0, np.pi / 2)), np.array([[0.0, -1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), ) + def test_create_rotate_90(self): expected = np.eye(3) - test_assert(create_rotate_90, (2, 0, 0), expected) + test_assert(create_rotate_90, (2, (0, 1), 0), expected) expected = np.eye(3) expected[0:2, 0:2] = [[0, -1], [1, 0]] - test_assert(create_rotate_90, (2, 0, 1), expected) + test_assert(create_rotate_90, (2, (0, 1), 1), expected) expected = np.eye(3) expected[0:2, 0:2] = [[-1, 0], [0, -1]] - test_assert(create_rotate_90, (2, 0, 2), expected) + test_assert(create_rotate_90, (2, (0, 1), 2), expected) expected = np.eye(3) expected[0:2, 0:2] = [[0, 1], [-1, 0]] - test_assert(create_rotate_90, (2, 0, 3), expected) + test_assert(create_rotate_90, (2, (0, 1), 3), expected) def test_create_shear(self): test_assert(create_shear, (2, 1.0), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) From eb7692d74e9226d3c420ab3eb3799a69a3746152 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 14 Sep 2022 11:36:20 +0100 Subject: [PATCH 24/30] Further work on transforms --- monai/transforms/spatial/array.py | 169 +++++++++++++++++++----------- 1 file changed, 105 insertions(+), 64 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 14da37300a..e72deaea38 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -15,10 +15,11 @@ import warnings from copy import deepcopy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch +from numpy.lib.stride_tricks import as_strided from monai.config import 
USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor @@ -389,7 +390,7 @@ def __call__( RuntimeError: When ``dst_meta`` is missing. ValueError: When the affine matrix of the source image is not invertible. Returns: - Resampled input tensor or MetaTensor. + Resampled input image, Metadata """ if img_dst is None: raise RuntimeError("`img_dst` is missing.") @@ -481,6 +482,8 @@ def __call__( align_corners: Optional[bool] = None, dtype: DtypeLike = None, output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None, + pixdim: Optional[Union[Sequence[float], float, np.ndarray]] = None, + diagonal: Optional[bool] = None, ) -> torch.Tensor: """ Args: @@ -508,23 +511,22 @@ def __call__( ValueError: When ``pixdim`` is nonpositive. Returns: - data tensor or MetaTensor (resampled into `self.pixdim`). + data_array (resampled into `self.pixdim`), original affine, current affine. """ original_spatial_shape = data_array.shape[1:] sr = len(original_spatial_shape) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") - input_affine: Optional[NdarrayOrTensor] = None affine_: np.ndarray - if affine is not None: - warnings.warn("arg `affine` is deprecated, the affine of MetaTensor in data_array has higher priority.") - input_affine = data_array.affine if isinstance(data_array, MetaTensor) else affine - if input_affine is None: + affine_np: np.ndarray + if isinstance(data_array, MetaTensor): + affine_np, *_ = convert_data_type(data_array.affine, np.ndarray) + affine_ = to_affine_nd(sr, affine_np) + else: warnings.warn("`data_array` is not of type MetaTensor, assuming affine to be identity.") # default to identity - input_affine = np.eye(sr + 1, dtype=np.float64) - affine_ = to_affine_nd(sr, convert_data_type(input_affine, np.ndarray)[0]) + affine_ = np.eye(sr + 1, dtype=np.float64) out_d = self.pixdim[:sr] if out_d.size < sr: @@ -596,7 +598,13 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels - def __call__(self, data_array: torch.Tensor) -> torch.Tensor: + def __call__( + self, + data_array: torch.Tensor, + axcodes: Optional[str] = None, + as_closest_canonical: Optional[bool] = None, + labels: Optional[Sequence[Tuple[str, str]]] = None + ) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. If input type is `torch.Tensor`, original affine is assumed to be identity. @@ -630,21 +638,27 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: affine_ = np.eye(sr + 1, dtype=np.float64) src = nib.io_orientation(affine_) - if self.as_closest_canonical: + + _axcodes = self.axcodes if axcodes is None else axcodes + _as_closest_canonical =\ + self.as_closest_canonical if as_closest_canonical is None else as_closest_canonical + _labels = self.labels if labels is None else labels + + if _as_closest_canonical: spatial_ornt = src else: - if self.axcodes is None: + if _axcodes is None: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") - if sr < len(self.axcodes): + if sr < len(_axcodes): warnings.warn( - f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" + f"axcodes ('{_axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]}," "please make sure the input is in the channel-first format." 
)
-        dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
+        dst = nib.orientations.axcodes2ornt(_axcodes[:sr], labels=_labels)
         if len(dst) < sr:
             raise ValueError(
-                f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
+                f"axcodes must match data_array spatially, got axcodes={len(_axcodes)}D data_array={sr}D"
             )
         spatial_ornt = nib.orientations.ornt_transform(src, dst)
         new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape)
@@ -703,8 +717,18 @@ class Flip(InvertibleTransform):
 
     backend = [TransformBackends.TORCH]
 
-    def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
-        self.spatial_axis = spatial_axis
+    @deprecated_arg(name="spatial_axis", since="1.0", msg_suffix="please use `spatial_axes` instead.")
+    def __init__(
+        self,
+        spatial_axis: Optional[Union[Sequence[int], int]] = None,
+        spatial_axes: Optional[Union[Sequence[int], int]] = None
+    ) -> None:
+        if spatial_axis is not None and spatial_axes is not None:
+            raise ValueError("Only one of 'spatial_axis' and 'spatial_axes' may be set; "
+                             f"'spatial_axis' is {spatial_axis} "
+                             f"and 'spatial_axes' is {spatial_axes}")
+
+        self.spatial_axes = spatial_axis if spatial_axes is None else spatial_axes
 
     def update_meta(self, img, shape, axes):
         # shape and axes include the channel dim
@@ -718,13 +742,18 @@ def update_meta(self, img, shape, axes):
     def forward_image(self, img, axes) -> torch.Tensor:
         return torch.flip(img, axes)
 
-    def __call__(self, img: torch.Tensor) -> torch.Tensor:
+    def __call__(
+        self,
+        img: torch.Tensor,
+        spatial_axes: Optional[Union[Sequence[int], int]] = None
+    ) -> torch.Tensor:
         """
         Args:
             img: channel first array, must have shape: (num_channels, H[, W, ..., ])
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
-        axes = map_spatial_axes(img.ndim, self.spatial_axis)
+        spatial_axes_ = self.spatial_axes if spatial_axes is None else spatial_axes
+        axes = map_spatial_axes(img.ndim, spatial_axes_)
         out = self.forward_image(img, axes)
         if get_track_meta():
             self.update_meta(out, out.shape, axes)
@@ -733,7 +762,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
         self.pop_transform(data)
-        flipper = Flip(spatial_axis=self.spatial_axis)
+        flipper = Flip(spatial_axes=self.spatial_axes)
         with flipper.trace_transform(False):
             return flipper(data)
 
@@ -842,12 +871,10 @@ def __call__(
             scale = self.spatial_size / max(img_size)
             spatial_size_ = tuple(int(round(s * scale)) for s in img_size)
 
-        original_sp_size = img.shape[1:]
-        _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode)
-        _align_corners = self.align_corners if align_corners is None else align_corners
         if tuple(img.shape[1:]) == spatial_size_:  # spatial shape is already the desired
-            img = convert_to_tensor(img, track_meta=get_track_meta())  # type: ignore
-            return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim)
+
return convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore + + original_sp_size = img.shape[1:] img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): @@ -864,25 +891,25 @@ def __call__( img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) img = convert_to_tensor(img, track_meta=get_track_meta()) + _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) + _align_corners = self.align_corners if align_corners is None else align_corners + resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners ) out, *_ = convert_to_dst_type(resized.squeeze(0), img) - return self._post_process(out, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) - - def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corners, ndim) -> torch.Tensor: if get_track_meta(): - self.update_meta(img, orig_size, sp_size) + self.update_meta(out, original_sp_size, spatial_size_) self.push_transform( - img, - orig_size=orig_size, + out, + orig_size=original_sp_size, extra_info={ - "mode": mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "new_dim": len(orig_size) - ndim, # additional dims appended + "mode": _mode, + "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, + "new_dim": len(original_sp_size) - input_ndim, # additional dims appended }, ) - return img + return out def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) @@ -953,6 +980,8 @@ def __call__( padding_mode: Optional[str] = None, align_corners: Optional[bool] = None, dtype: Union[DtypeLike, torch.dtype] = None, + angle: Optional[Union[Sequence[float], float]] = None, + keep_size: Optional[bool] = None, ) -> torch.Tensor: """ Args: @@ -982,10 +1011,13 @@ def __call__( input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") - _angle = ensure_tuple_rep(self.angle, 1 if input_ndim == 2 else 3) + + _keep_size = self.keep_size if keep_size is None else keep_size + _angle = self.angle if angle is None else angle + _angle = ensure_tuple_rep(_angle, 1 if input_ndim == 2 else 3) transform = create_rotate(input_ndim, _angle) shift = create_translate(input_ndim, ((im_shape - 1) / 2).tolist()) - if self.keep_size: + if _keep_size: output_shape = im_shape else: corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape( @@ -1114,6 +1146,8 @@ def __call__( mode: Optional[str] = None, padding_mode: Optional[str] = None, align_corners: Optional[bool] = None, + zoom: Optional[Union[Sequence[float], float]] = None, + keep_size: Optional[bool] = None ) -> torch.Tensor: """ Args: @@ -1137,6 +1171,9 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) img_t = img.to(torch.float32) + _keep_size = self.keep_size if keep_size is None else keep_size + + _zoom = self.zoom if zoom is None else zoom _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value _align_corners = self.align_corners if align_corners is None else align_corners @@ -1155,7 +1192,7 @@ def __call__( out, *_ = convert_to_dst_type(zoomed, dst=img) if get_track_meta(): self.update_meta(out, 
orig_size[1:], z_size[1:]) - do_pad_crop = self.keep_size and not np.allclose(orig_size, z_size) + do_pad_crop = _keep_size and not np.allclose(orig_size, z_size) if do_pad_crop: _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=_padding_mode) out = _pad_crop(out) @@ -1505,10 +1542,14 @@ class RandAxisFlip(RandomizableTransform, InvertibleTransform): backend = Flip.backend - def __init__(self, prob: float = 0.1) -> None: + def __init__( + self, + prob: Optional[float]=0.1, + spatial_axes: Optional[Union[Sequence[int], int]]=None + ) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - self.flipper = Flip(spatial_axis=self._axis) + self.flipper = Flip(spatial_axes=spatial_axes) def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) @@ -1526,7 +1567,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize(data=img) if self._do_transform: - self.flipper.spatial_axis = self._axis + self.flipper.spatial_axes = self._axis out = self.flipper(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) @@ -1540,7 +1581,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) if not transform[TraceKeys.DO_TRANSFORM]: return data - flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axes"]) + flipper = Flip(spatial_axes=transform[TraceKeys.EXTRA_INFO]["axes"]) with flipper.trace_transform(False): return flipper(data) @@ -2972,31 +3013,31 @@ def __call__( split_size, steps = self._get_params(image.shape[1:], input_size) patches: List[NdarrayOrTensor] - as_strided_func: Callable if isinstance(image, torch.Tensor): - as_strided_func = torch.as_strided - c_stride, x_stride, y_stride = image.stride() # type: ignore + unfolded_image = ( + image.unfold(1, split_size[0], steps[0]) + .unfold(2, split_size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + ) + # Make a list of contiguous patches + patches = [p.contiguous() for p in unfolded_image] elif isinstance(image, np.ndarray): - as_strided_func = np.lib.stride_tricks.as_strided + x_step, y_step = steps c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] + strided_image = as_strided( + image, + shape=(*self.grid, n_channels, split_size[0], split_size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + ) + # Flatten the first two dimensions + strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) + # Make a list of contiguous patches + patches = [np.ascontiguousarray(p) for p in strided_image] else: raise ValueError(f"Input type [{type(image)}] is not supported.") - x_step, y_step = steps - n_channels = image.shape[0] - strided_image = as_strided_func( - image, - (*self.grid, n_channels, split_size[0], split_size[1]), - (x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), - ) - # Flatten the first two dimensions - strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) - # Make a list of contiguous patches - if isinstance(image, torch.Tensor): - patches = [p.contiguous() for p in strided_image] - elif isinstance(image, np.ndarray): - patches = [np.ascontiguousarray(p) for p in strided_image] - return patches def _get_params( From b353e3e343940ade021f1fc4b3c76ca5b6c52fdb Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 7 Oct 2022 14:43:36 +0100 Subject: [PATCH 25/30] Fixes for zoom and rotate; rename of spaced to spacingd; introduction of CachedTransform --- 
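This patch fixes `zoom` to invert the user-facing factors before building its homogeneous matrix. That is consistent with the sampling convention used throughout the series: the matrix maps output-grid coordinates back into the input image (as `grid_sample` expects), so magnifying content by a factor `f` means reading input coordinates scaled by `1/f`. A small numpy sketch of the matrix the `test_zoom` case below expects for a 2D zoom of factor 2:

import numpy as np

def zoom_matrix(factors):
    # Homogeneous matrix taking output coordinates to input coordinates,
    # hence the reciprocal of each user-facing zoom factor on the diagonal.
    reciprocals = [1.0 / f for f in factors]
    return np.diag(reciprocals + [1.0])

print(zoom_matrix([2.0, 2.0]))
# [[0.5 0.  0. ]
#  [0.  0.5 0. ]
#  [0.  0.  1. ]]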
monai/transforms/atmostonce/array.py | 20 +-- monai/transforms/atmostonce/dictionary.py | 4 +- monai/transforms/atmostonce/functional.py | 15 +- monai/transforms/atmostonce/utility.py | 64 +++++++++ tests/test_atmostonce.py | 159 ++++++++++++++++++---- 5 files changed, 215 insertions(+), 47 deletions(-) create mode 100644 monai/transforms/atmostonce/utility.py diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 9b164c6f6c..cd760ef25d 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -229,8 +229,8 @@ class Zoom(LazyTransform, InvertibleTransform): def __init__( self, - zoom: Union[Sequence[float], float], - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + factor: Union[Sequence[float], float], + mode: Union[InterpolateMode, str] = InterpolateMode.BILINEAR, padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: Optional[bool] = True, @@ -239,35 +239,37 @@ def __init__( **kwargs ): LazyTransform.__init__(self, lazy_evaluation) - self.zoom = zoom + self.factor = factor self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size self.dtype = dtype self.kwargs = kwargs + print("mode =", self.mode) def __call__( self, img: NdarrayOrTensor, - zoom: Optional[Union[Sequence[float], float]] = None, + factor: Optional[Union[Sequence[float], float]] = None, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: - mode = self.mode or mode - padding_mode = self.padding_mode or padding_mode - align_corners = self.align_corners or align_corners + factor = self.factor if factor is None else factor + mode = self.mode if mode is None else mode + padding_mode = self.padding_mode if padding_mode is None else padding_mode + align_corners = self.align_corners if align_corners is None else align_corners keep_size = self.keep_size dtype = self.dtype shape_override_ = shape_override if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) - - img_t, transform, metadata = zoom(img, self.zoom, mode, padding_mode, align_corners, + print("mode =", mode) + img_t, transform, metadata = zoom(img, factor, mode, padding_mode, align_corners, keep_size, dtype, shape_override_) # TODO: candidate for refactoring into a LazyTransform method diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py index 453f21cf06..408e7b8c64 100644 --- a/monai/transforms/atmostonce/dictionary.py +++ b/monai/transforms/atmostonce/dictionary.py @@ -75,7 +75,7 @@ def expand_potential_tuple(keys, value): # raise NotImplementedError(msg) -class Spaced(LazyTransform, MapTransform, InvertibleTransform): +class Spacingd(LazyTransform, MapTransform, InvertibleTransform): def __init__(self, keys: KeysCollection, @@ -192,7 +192,7 @@ def __call__(self, d: Mapping): keys_present = self.keys for ik, k in enumerate(keys_present): - tx = Resize(spatial_size, size_mode, self.modes[ik], self.align_corners, + tx = Resize(self.spatial_size, self.size_mode, self.modes[ik], self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype) rd[k] = tx(d[k]) diff --git 
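The `Zoom.__call__` rewrite above replaces `mode = self.mode or mode` with explicit `is None` tests. The distinction matters twice over: with `or`, a truthy constructor value always wins, so call-time overrides are ignored, and a falsy override such as `align_corners=False` is silently dropped. A short demonstration of the two behaviours:

INIT_ALIGN_CORNERS = True

def broken(call_value):
    # A truthy constructor value shadows every call-time override.
    return INIT_ALIGN_CORNERS or call_value

def fixed(call_value):
    # Only an omitted (None) argument falls back to the constructor value.
    return INIT_ALIGN_CORNERS if call_value is None else call_value

assert broken(False) is True    # explicit False is discarded
assert fixed(False) is False    # override is honoured
assert fixed(None) is True      # default still applies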
a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index 8bd9e0cf3d..6dbcef3c0b 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -243,16 +243,13 @@ def rotate( raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) - to_center_tx = create_translate(input_ndim, [d / 2 for d in input_shape[1:]]) - rotate_tx = create_rotate(input_ndim, angle_) + rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_)) im_extents = extents_from_shape(input_shape) if not keep_size: im_extents = [rotate_tx @ e for e in im_extents] spatial_shape = shape_from_extents(input_shape, im_extents) else: spatial_shape = input_shape - from_center_tx = create_translate(input_ndim, [-d / 2 for d in input_shape[1:]]) - # transform = from_center_tx @ rotate_tx @ to_center_tx transform = rotate_tx metadata = { "angle": angle_, @@ -269,8 +266,8 @@ def rotate( def zoom( img: torch.Tensor, - zoom: Union[Sequence[float], float], - mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + factor: Union[Sequence[float], float], + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.BILINEAR, padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, align_corners: Optional[bool] = False, keep_size: Optional[bool] = True, @@ -297,7 +294,8 @@ def zoom( input_shape = img_.shape if shape_override is None else shape_override input_ndim = len(input_shape) - 1 - zoom_factors = ensure_tuple_rep(zoom, input_ndim) + zoom_factors = ensure_tuple_rep(factor, input_ndim) + zoom_factors = [1 / f for f in zoom_factors] mode_ = look_up_option(mode, GridSampleMode) padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) @@ -312,7 +310,7 @@ def zoom( shape_override_ = input_shape metadata = { - "zoom": zoom_factors, + "factor": zoom_factors, "mode": mode_, "padding_mode": padding_mode_, "align_corners": align_corners, @@ -346,6 +344,7 @@ def rotate90( "spatial_axes": spatial_axes, "shape_override": shape_override } + return img_, transform, metadata def translate( diff --git a/monai/transforms/atmostonce/utility.py b/monai/transforms/atmostonce/utility.py new file mode 100644 index 0000000000..cd0a85ed7c --- /dev/null +++ b/monai/transforms/atmostonce/utility.py @@ -0,0 +1,64 @@ +from typing import Callable, Sequence + +from abc import ABC + + +class CacheMechanism(ABC): + """ + The interface for caching mechanisms to be used with CachedTransform. This interface provides + the ability to check whether cached objects are present, test and fetch simultaneously, and + store items. It makes no other assumptions about the caching mechanism, capacity, cache eviction + strategies or any other aspect of cache implementation + """ + + def try_fetch( + self, + key + ): + raise NotImplementedError() + + def store( + self, + key, + value + ): + raise NotImplementedError() + + +class CachedTransform: + """ + CachedTransform provides the functionality to cache the output of one or more transforms such + that they only need to be run once. Each time that CachedTransform is run, it checks whether + a cached entity is present, and if that entity is present, it loads it and returns the + resulting tensor / tensors as output. If that entity is not present in the cache, it executes + the transforms in its internal pipeline and caches the result before returning it. 
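+
+    A minimal usage sketch (illustrative; ``my_cache`` stands for any
+    ``CacheMechanism`` implementation and ``my_pipeline`` for any callable):
+
+        ct = CachedTransform(transforms=my_pipeline, cache=my_cache)
+        out_a = ct("subject_001", img)  # cache miss: runs my_pipeline(img), stores it
+        out_b = ct("subject_001", img)  # cache hit: returns the stored result
+
+    The ``key`` argument identifies the cached entity, so callers must ensure
+    it is unique per distinct input.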
+ """ + + def __init__( + self, + transforms: Callable, + cache: CacheMechanism + ): + """ + Args: + transforms: A sequence of callable objects + cache: A caching mechanism that implements the `CacheMechanism` interface + """ + self.transforms = transforms + self.cache = cache + + def __call__( + self, + key, + *args, + **kwargs + ): + is_present, value = self.cache.try_fetch(key) + + if is_present: + return value + + result = self.transforms(*args, **kwargs) + self.cache.store(key, result) + + return result diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index e0de31c96c..884152abef 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -14,13 +14,15 @@ from monai.utils import TransformBackends from monai.transforms import Affined, Affine, Flip -from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing, flip +from monai.transforms.atmostonce.functional import croppad, resize, rotate, zoom, spacing, flip from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply from monai.transforms.atmostonce.dictionary import Rotated from monai.transforms.compose import Compose from monai.utils.enums import GridSampleMode, GridSamplePadMode from monai.utils.mapping_stack import MatrixFactory +from monai.transforms.atmostonce.utility import CachedTransform, CacheMechanism + def get_img(size, dtype=torch.float32, offset=0): img = torch.zeros(size, dtype=dtype) @@ -154,44 +156,69 @@ def test_mult_matrices(self): class TestFunctional(unittest.TestCase): + def _test_functional_impl(self, + op, + image, + params, + expected_matrix): + r_image, r_transform, r_metadata = op(image, **params) + enumerate_results_of_op((r_image, r_transform, r_metadata)) + self.assertTrue(torch.allclose(r_transform, expected_matrix)) + # TODO: turn into proper test def test_spacing(self): - results = spacing(np.zeros((1, 24, 32), dtype=np.float32), - (0.5, 0.6), - (1.0, 1.0), - False, - "bilinear", - "border", - False) + kwargs = { + "pixdim": (0.5, 0.6), "src_pixdim": (1.0, 1.0), "diagonal": False, + "mode": "bilinear", "padding_mode": "border", "align_corners": None + } + expected_tx = torch.DoubleTensor([[2.0, 0.0, 0.0], + [0.0, 1.66666667, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(spacing, get_img((24, 32)), kwargs, expected_tx) # TODO: turn into proper test def test_resize(self): - results = resize(np.zeros((1, 24, 32), dtype=np.float32), - (40, 40), - "all", - "bilinear", - False) - enumerate_results_of_op(results) + kwargs = { + "spatial_size": (40, 40), "size_mode": "all", + "mode": "bilinear", "align_corners": None + } + expected_tx = torch.DoubleTensor([[1.66666667, 0.0, 0.0], + [0.0, 1.25, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(resize, get_img((24, 32)), kwargs, expected_tx) + # TODO: turn into proper test def test_rotate(self): - results = rotate(np.zeros((1, 64, 64), dtype=np.float32), - torch.pi / 4, - True, - "bilinear", - "border") - enumerate_results_of_op(results) + kwargs = { + "angle": torch.pi / 4, "keep_size": True, + "mode": "bilinear", "padding_mode": "border" + } + expected_tx = torch.DoubleTensor([[0.70710678, -0.70710678, 0.0], + [0.70710678, 0.70710678, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(rotate, get_img((24, 32)), kwargs, expected_tx) + + + def test_zoom(self): + # results = zoom(np.zeros((1, 64, 64), dtype=np.float32), + # 2, + # "bilinear", + # "zeros") + # enumerate_results_of_op(results) + kwargs = { + "factor": 2, "mode": "nearest", "padding_mode": 
"border", "keep_size": True + } + expected_tx = torch.DoubleTensor([[0.5, 0.0, 0.0], + [0.0, 0.5, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(zoom, get_img((24, 32)), kwargs, expected_tx) - results = rotate(np.zeros((1, 64, 64), dtype=np.float32), - torch.pi / 4, - False, - "bilinear", - "border") - enumerate_results_of_op(results) def _check_matrix(self, actual, expected): np.allclose(actual, expected) + def _test_rotate_90_impl(self, values, keep_dims, expected): results = rotate(np.zeros((1, 64, 64, 32), dtype=np.float32), values, @@ -410,8 +437,26 @@ def test_rotate_apply_lazy(self): data = get_img((32, 32)) data = r(data) data = apply(data) - print(data.shape) - print(data) + expected = torch.DoubleTensor([[0.70710677, 0.70710677, 0.0, -15.61269784], + [-0.70710677, 0.70710677, 0.0, 15.5], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0]]) + self.assertTrue(torch.allclose(expected, data.affine)) + + def test_zoom_apply_lazy(self): + r = amoa.Zoom(2, + mode="bilinear", + padding_mode="border", + keep_size=False) + r.lazy_evaluation = True + data = get_img((32, 32)) + data = r(data) + data = apply(data) + expected = torch.DoubleTensor([[0.5, 0.0, 0.0, 11.75], + [0.0, 0.5, 0.0, 11.75], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0]]) + self.assertTrue(torch.allclose(expected, data.affine)) def test_crop_then_rotate_apply_lazy(self): data = get_img((32, 32)) @@ -430,6 +475,7 @@ def test_crop_then_rotate_apply_lazy(self): data2 = lr1(data1) datas.append(data2) + class TestDictionaryTransforms(unittest.TestCase): def test_rotate_numpy(self): @@ -509,3 +555,60 @@ def test_value_to_tuple_range(self): self.assertTupleEqual(value_to_tuple_range([4.3, -2.1]), (-2.1, 4.3)) self.assertTupleEqual(value_to_tuple_range((4.3, -2.1)), (-2.1, 4.3)) + +# Utility transforms for compose compiler +# ================================================================================================= + +class TestMemoryCacheMechanism(CacheMechanism): + + def __init__( + self, + max_count: int + ): + self.max_count = max_count + self.contents = dict() + self.order = list() + + def try_fetch( + self, + key + ): + if key in self.contents: + return True, self.contents[key] + + return False, None + + def store( + self, + key, + value + ): + if key in self.contents: + self.contents[key] = value + else: + if len(self.contents) >= self.max_count: + last = self.order.pop() + del self.contents[last] + + self.contents[key] = value + self.order.append(key) + + +class TestUtilityTransforms(unittest.TestCase): + + def test_cached_transform(self): + + def generate_noise(shape): + def _inner(*args, **kwargs): + return np.random.normal(size=shape) + return _inner + + ct = CachedTransform(transforms=generate_noise((1, 16, 16)), + cache=TestMemoryCacheMechanism(4)) + + first = ct("foo") + second = ct("foo") + third = ct("bar") + + self.assertIs(first, second) + self.assertIsNot(first, third) From bcbfb682f86c8c1415b0292e1f5d41a7533a6db3 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 7 Oct 2022 15:19:56 +0100 Subject: [PATCH 26/30] Removing unnecessary comments in apply --- monai/transforms/atmostonce/apply.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index aa1993f155..453d6750d2 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -171,9 +171,6 @@ def apply(data: Union[torch.Tensor, MetaTensor], cur_device = cur_device if new_device is None else new_device cur_dtype = cur_dtype if 
new_dtype is None else new_dtype cur_shape = cur_shape if new_shape is None else new_shape - # TODO: figure out how to propagate extents properly - # TODO: resampling strategy: augment resample or perform multiple stages if necessary - # TODO: resampling strategy - antialiasing: can resample just be augmented? kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) From af043e9b2bec4fd51df76a91ecad39537fff1342 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 12 Oct 2022 08:34:36 +0100 Subject: [PATCH 27/30] Adding utility transforms for compose compiler --- monai/transforms/atmostonce/apply.py | 3 - monai/transforms/atmostonce/array.py | 20 +-- monai/transforms/atmostonce/dictionary.py | 4 +- monai/transforms/atmostonce/functional.py | 15 +- monai/transforms/atmostonce/utility.py | 99 ++++++++++++ monai/transforms/spatial/array.py | 169 ++++++++------------ tests/test_atmostonce.py | 181 ++++++++++++++++++---- 7 files changed, 335 insertions(+), 156 deletions(-) create mode 100644 monai/transforms/atmostonce/utility.py diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index aa1993f155..453d6750d2 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -171,9 +171,6 @@ def apply(data: Union[torch.Tensor, MetaTensor], cur_device = cur_device if new_device is None else new_device cur_dtype = cur_dtype if new_dtype is None else new_dtype cur_shape = cur_shape if new_shape is None else new_shape - # TODO: figure out how to propagate extents properly - # TODO: resampling strategy: augment resample or perform multiple stages if necessary - # TODO: resampling strategy - antialiasing: can resample just be augmented? kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 9b164c6f6c..cd760ef25d 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -229,8 +229,8 @@ class Zoom(LazyTransform, InvertibleTransform): def __init__( self, - zoom: Union[Sequence[float], float], - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + factor: Union[Sequence[float], float], + mode: Union[InterpolateMode, str] = InterpolateMode.BILINEAR, padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: Optional[bool] = True, @@ -239,35 +239,37 @@ def __init__( **kwargs ): LazyTransform.__init__(self, lazy_evaluation) - self.zoom = zoom + self.factor = factor self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size self.dtype = dtype self.kwargs = kwargs + print("mode =", self.mode) def __call__( self, img: NdarrayOrTensor, - zoom: Optional[Union[Sequence[float], float]] = None, + factor: Optional[Union[Sequence[float], float]] = None, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: - mode = self.mode or mode - padding_mode = self.padding_mode or padding_mode - align_corners = self.align_corners or align_corners + factor = self.factor if factor is None else factor + mode = self.mode if mode is None else mode + padding_mode = self.padding_mode if padding_mode is None else padding_mode + align_corners = 
self.align_corners if align_corners is None else align_corners keep_size = self.keep_size dtype = self.dtype shape_override_ = shape_override if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) - - img_t, transform, metadata = zoom(img, self.zoom, mode, padding_mode, align_corners, + print("mode =", mode) + img_t, transform, metadata = zoom(img, factor, mode, padding_mode, align_corners, keep_size, dtype, shape_override_) # TODO: candidate for refactoring into a LazyTransform method diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py index 453f21cf06..408e7b8c64 100644 --- a/monai/transforms/atmostonce/dictionary.py +++ b/monai/transforms/atmostonce/dictionary.py @@ -75,7 +75,7 @@ def expand_potential_tuple(keys, value): # raise NotImplementedError(msg) -class Spaced(LazyTransform, MapTransform, InvertibleTransform): +class Spacingd(LazyTransform, MapTransform, InvertibleTransform): def __init__(self, keys: KeysCollection, @@ -192,7 +192,7 @@ def __call__(self, d: Mapping): keys_present = self.keys for ik, k in enumerate(keys_present): - tx = Resize(spatial_size, size_mode, self.modes[ik], self.align_corners, + tx = Resize(self.spatial_size, self.size_mode, self.modes[ik], self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype) rd[k] = tx(d[k]) diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index 8bd9e0cf3d..6dbcef3c0b 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -243,16 +243,13 @@ def rotate( raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) - to_center_tx = create_translate(input_ndim, [d / 2 for d in input_shape[1:]]) - rotate_tx = create_rotate(input_ndim, angle_) + rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_)) im_extents = extents_from_shape(input_shape) if not keep_size: im_extents = [rotate_tx @ e for e in im_extents] spatial_shape = shape_from_extents(input_shape, im_extents) else: spatial_shape = input_shape - from_center_tx = create_translate(input_ndim, [-d / 2 for d in input_shape[1:]]) - # transform = from_center_tx @ rotate_tx @ to_center_tx transform = rotate_tx metadata = { "angle": angle_, @@ -269,8 +266,8 @@ def rotate( def zoom( img: torch.Tensor, - zoom: Union[Sequence[float], float], - mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + factor: Union[Sequence[float], float], + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.BILINEAR, padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, align_corners: Optional[bool] = False, keep_size: Optional[bool] = True, @@ -297,7 +294,8 @@ def zoom( input_shape = img_.shape if shape_override is None else shape_override input_ndim = len(input_shape) - 1 - zoom_factors = ensure_tuple_rep(zoom, input_ndim) + zoom_factors = ensure_tuple_rep(factor, input_ndim) + zoom_factors = [1 / f for f in zoom_factors] mode_ = look_up_option(mode, GridSampleMode) padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) @@ -312,7 +310,7 @@ def zoom( shape_override_ = input_shape metadata = { - "zoom": zoom_factors, + "factor": zoom_factors, "mode": mode_, "padding_mode": padding_mode_, "align_corners": align_corners, @@ -346,6 +344,7 @@ def rotate90( 
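The functional ops in this file each return a `(tensor, matrix, metadata)` triple, and the metadata dict rides along with the pending `MetaMatrix` so that a later `apply` can choose resampling arguments. A sketch of how such dicts can be folded, assuming the same last-writer-wins rule as the `cur_*`/`new_*` coalescing in `apply`; `resolve_apply_args` is a hypothetical helper, not code from this series:

def resolve_apply_args(metadata_stack):
    # Later pending transforms override earlier ones, field by field.
    args = {"mode": None, "padding_mode": None, "dtype": None}
    for md in metadata_stack:
        for name in args:
            if md.get(name) is not None:
                args[name] = md[name]
    return args

stack = [{"mode": "bilinear"}, {"padding_mode": "border"}, {"mode": "nearest"}]
assert resolve_apply_args(stack) == {"mode": "nearest", "padding_mode": "border", "dtype": None}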
"spatial_axes": spatial_axes, "shape_override": shape_override } + return img_, transform, metadata def translate( diff --git a/monai/transforms/atmostonce/utility.py b/monai/transforms/atmostonce/utility.py new file mode 100644 index 0000000000..a8fde06104 --- /dev/null +++ b/monai/transforms/atmostonce/utility.py @@ -0,0 +1,99 @@ +from typing import Callable, Sequence + +from abc import ABC + +import torch + + +class CacheMechanism(ABC): + """ + The interface for caching mechanisms to be used with CachedTransform. This interface provides + the ability to check whether cached objects are present, test and fetch simultaneously, and + store items. It makes no other assumptions about the caching mechanism, capacity, cache eviction + strategies or any other aspect of cache implementation + """ + + def try_fetch( + self, + key + ): + raise NotImplementedError() + + def store( + self, + key, + value + ): + raise NotImplementedError() + + +class CachedTransform: + """ + CachedTransform provides the functionality to cache the output of one or more transforms such + that they only need to be run once. Each time that CachedTransform is run, it checks whether + a cached entity is present, and if that entity is present, it loads it and returns the + resulting tensor / tensors as output. If that entity is not present in the cache, it executes + the transforms in its internal pipeline and caches the result before returning it. + """ + + def __init__( + self, + transforms: Callable, + cache: CacheMechanism + ): + """ + Args: + transforms: A sequence of callable objects + cache: A caching mechanism that implements the `CacheMechanism` interface + """ + self.transforms = transforms + self.cache = cache + + def __call__( + self, + key, + *args, + **kwargs + ): + is_present, value = self.cache.try_fetch(key) + + if is_present: + return value + + result = self.transforms(*args, **kwargs) + self.cache.store(key, result) + + return result + + +class MultiSampleTransform: + """ + Multi-sample takes the output of a transform that generates multiple samples and executes + each sample separately in a depth first fashion, gathering the results into an array that + is finally returned after all samples are processed + """ + def __init__( + self, + multi_sample: Callable, + transforms: Callable, + ): + self.multi_sample = multi_sample + self.transforms = transforms + + def __call__( + self, + t, + *args, + **kwargs + ): + output = list() + for mt in self.multi_sample(t): + mt_out = self.multi_sample(mt) + if isinstance(mt_out, torch.Tensor): + output.append(mt_out) + elif isinstance(mt_out, list): + output += mt_out + else: + raise ValueError(f"self.transform must return a Tensor or list of Tensors, but returned {mt_out}") + + return output diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e72deaea38..14da37300a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -15,11 +15,10 @@ import warnings from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch -from numpy.lib.stride_tricks import as_strided from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor @@ -390,7 +389,7 @@ def __call__( RuntimeError: When ``dst_meta`` is missing. ValueError: When the affine matrix of the source image is not invertible. 
Returns: - Resampled input image, Metadata + Resampled input tensor or MetaTensor. """ if img_dst is None: raise RuntimeError("`img_dst` is missing.") @@ -482,8 +481,6 @@ def __call__( align_corners: Optional[bool] = None, dtype: DtypeLike = None, output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None, - pixdim: Optional[Union[Sequence[float], float, np.ndarray]] = None, - diagonal: Optional[bool] = None, ) -> torch.Tensor: """ Args: @@ -511,22 +508,23 @@ def __call__( ValueError: When ``pixdim`` is nonpositive. Returns: - data_array (resampled into `self.pixdim`), original affine, current affine. + data tensor or MetaTensor (resampled into `self.pixdim`). """ original_spatial_shape = data_array.shape[1:] sr = len(original_spatial_shape) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") + input_affine: Optional[NdarrayOrTensor] = None affine_: np.ndarray - affine_np: np.ndarray - if isinstance(data_array, MetaTensor): - affine_np, *_ = convert_data_type(data_array.affine, np.ndarray) - affine_ = to_affine_nd(sr, affine_np) - else: + if affine is not None: + warnings.warn("arg `affine` is deprecated, the affine of MetaTensor in data_array has higher priority.") + input_affine = data_array.affine if isinstance(data_array, MetaTensor) else affine + if input_affine is None: warnings.warn("`data_array` is not of type MetaTensor, assuming affine to be identity.") # default to identity - affine_ = np.eye(sr + 1, dtype=np.float64) + input_affine = np.eye(sr + 1, dtype=np.float64) + affine_ = to_affine_nd(sr, convert_data_type(input_affine, np.ndarray)[0]) out_d = self.pixdim[:sr] if out_d.size < sr: @@ -598,13 +596,7 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels - def __call__( - self, - data_array: torch.Tensor, - axcodes: Optional[str] = None, - as_closest_canonical: Optional[bool] = None, - labels: Optional[Sequence[Tuple[str, str]]] = None - ) -> torch.Tensor: + def __call__(self, data_array: torch.Tensor) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. If input type is `torch.Tensor`, original affine is assumed to be identity. @@ -638,27 +630,21 @@ def __call__( affine_ = np.eye(sr + 1, dtype=np.float64) src = nib.io_orientation(affine_) - - _axcodes = self.axcodes if axcodes is None else axcodes - _as_closest_canonical =\ - self.as_closest_canonical if as_closest_canonical is None else as_closest_canonical - _labels = self.labels if labels is None else labels - - if _as_closest_canonical: + if self.as_closest_canonical: spatial_ornt = src else: - if _axcodes is None: + if self.axcodes is None: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") - if sr < len(_axcodes): + if sr < len(self.axcodes): warnings.warn( - f"axcodes ('{_axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" + f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]}," "please make sure the input is in the channel-first format." 
) - dst = nib.orientations.axcodes2ornt(_axcodes[:sr], labels=_labels) + dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels) if len(dst) < sr: raise ValueError( - f"axcodes must match data_array spatially, got axcodes={len(_axcodes)}D data_array={sr}D" + f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D" ) spatial_ornt = nib.orientations.ornt_transform(src, dst) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape) @@ -692,7 +678,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: # Create inverse transform orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) - inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=_labels) + inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels) # Apply inverse with inverse_transform.trace_transform(False): data = inverse_transform(data) @@ -717,18 +703,8 @@ class Flip(InvertibleTransform): backend = [TransformBackends.TORCH] - @deprecated_arg(name="spatial_axis", since="1.0", msg_suffix="please use `spatial_axes` instead.") - def __init__( - self, - spatial_axis: Optional[Union[Sequence[int], int]] = None, - spatial_axes: Optional[Union[Sequence[int], int]] = None - ) -> None: - if spatial_axis is not None and spatial_axes is not None: - raise ValueError("Only one of 'spatial_axis' and 'spatial_axes may be set; " - f"'spatial_axis' is {spatial_axis} " - f"and 'spatial_axes' is {spatial_axes}") - - self.spatial_axes = spatial_axis if spatial_axes is None else spatial_axes + def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: + self.spatial_axis = spatial_axis def update_meta(self, img, shape, axes): # shape and axes include the channel dim @@ -742,18 +718,13 @@ def update_meta(self, img, shape, axes): def forward_image(self, img, axes) -> torch.Tensor: return torch.flip(img, axes) - def __call__( - self, - img: torch.Tensor, - spatial_axes: Optional[Union[Sequence[int], int]] = None - ) -> torch.Tensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) """ img = convert_to_tensor(img, track_meta=get_track_meta()) - spatial_axes_ = self.spatial_axes if spatial_axes is None else self.spatial_axes - axes = map_spatial_axes(img.ndim, spatial_axes_) + axes = map_spatial_axes(img.ndim, self.spatial_axis) out = self.forward_image(img, axes) if get_track_meta(): self.update_meta(out, out.shape, axes) @@ -762,7 +733,7 @@ def __call__( def inverse(self, data: torch.Tensor) -> torch.Tensor: self.pop_transform(data) - flipper = Flip(spatial_axes=self.spatial_axes) + flipper = Flip(spatial_axis=self.spatial_axis) with flipper.trace_transform(False): return flipper(data) @@ -871,10 +842,12 @@ def __call__( scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired - return convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore - original_sp_size = img.shape[1:] + _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) + _align_corners = self.align_corners if align_corners is None else align_corners + if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired + img = convert_to_tensor(img, track_meta=get_track_meta()) # 
type: ignore + return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): @@ -891,25 +864,25 @@ def __call__( img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) img = convert_to_tensor(img, track_meta=get_track_meta()) - _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) - _align_corners = self.align_corners if align_corners is None else align_corners - resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners ) out, *_ = convert_to_dst_type(resized.squeeze(0), img) + return self._post_process(out, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) + + def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corners, ndim) -> torch.Tensor: if get_track_meta(): - self.update_meta(out, original_sp_size, spatial_size_) + self.update_meta(img, orig_size, sp_size) self.push_transform( - out, - orig_size=original_sp_size, + img, + orig_size=orig_size, extra_info={ - "mode": _mode, - "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, - "new_dim": len(original_sp_size) - input_ndim, # additional dims appended + "mode": mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "new_dim": len(orig_size) - ndim, # additional dims appended }, ) - return out + return img def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) @@ -980,8 +953,6 @@ def __call__( padding_mode: Optional[str] = None, align_corners: Optional[bool] = None, dtype: Union[DtypeLike, torch.dtype] = None, - angle: Optional[Union[Sequence[float], float]] = None, - keep_size: Optional[bool] = None, ) -> torch.Tensor: """ Args: @@ -1011,13 +982,10 @@ def __call__( input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") - - _keep_size = self.keep_size if keep_size is None else keep_size - _angle = self.angle if angle is None else angle - _angle = ensure_tuple_rep(_angle, 1 if input_ndim == 2 else 3) + _angle = ensure_tuple_rep(self.angle, 1 if input_ndim == 2 else 3) transform = create_rotate(input_ndim, _angle) shift = create_translate(input_ndim, ((im_shape - 1) / 2).tolist()) - if _keep_size: + if self.keep_size: output_shape = im_shape else: corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape( @@ -1146,8 +1114,6 @@ def __call__( mode: Optional[str] = None, padding_mode: Optional[str] = None, align_corners: Optional[bool] = None, - zoom: Optional[Union[Sequence[float], float]] = None, - keep_size: Optional[bool] = None ) -> torch.Tensor: """ Args: @@ -1171,9 +1137,6 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) img_t = img.to(torch.float32) - _keep_size = self.keep_size if keep_size is None else keep_size - - _zoom = self.zoom if zoom is None else zoom _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value _align_corners = self.align_corners if align_corners is None else align_corners @@ -1192,7 +1155,7 @@ def __call__( out, *_ = convert_to_dst_type(zoomed, dst=img) if get_track_meta(): self.update_meta(out, 
orig_size[1:], z_size[1:]) - do_pad_crop = _keep_size and not np.allclose(orig_size, z_size) + do_pad_crop = self.keep_size and not np.allclose(orig_size, z_size) if do_pad_crop: _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=_padding_mode) out = _pad_crop(out) @@ -1542,14 +1505,10 @@ class RandAxisFlip(RandomizableTransform, InvertibleTransform): backend = Flip.backend - def __init__( - self, - prob: Optional[float]=0.1, - spatial_axes: Optional[Union[Sequence[int], int]]=None - ) -> None: + def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - self.flipper = Flip(spatial_axes=spatial_axes) + self.flipper = Flip(spatial_axis=self._axis) def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) @@ -1567,7 +1526,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize(data=img) if self._do_transform: - self.flipper.spatial_axes = self._axis + self.flipper.spatial_axis = self._axis out = self.flipper(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) @@ -1581,7 +1540,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) if not transform[TraceKeys.DO_TRANSFORM]: return data - flipper = Flip(spatial_axes=transform[TraceKeys.EXTRA_INFO]["axes"]) + flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axes"]) with flipper.trace_transform(False): return flipper(data) @@ -3013,31 +2972,31 @@ def __call__( split_size, steps = self._get_params(image.shape[1:], input_size) patches: List[NdarrayOrTensor] + as_strided_func: Callable if isinstance(image, torch.Tensor): - unfolded_image = ( - image.unfold(1, split_size[0], steps[0]) - .unfold(2, split_size[1], steps[1]) - .flatten(1, 2) - .transpose(0, 1) - ) - # Make a list of contiguous patches - patches = [p.contiguous() for p in unfolded_image] + as_strided_func = torch.as_strided + c_stride, x_stride, y_stride = image.stride() # type: ignore elif isinstance(image, np.ndarray): - x_step, y_step = steps + as_strided_func = np.lib.stride_tricks.as_strided c_stride, x_stride, y_stride = image.strides - n_channels = image.shape[0] - strided_image = as_strided( - image, - shape=(*self.grid, n_channels, split_size[0], split_size[1]), - strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), - ) - # Flatten the first two dimensions - strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) - # Make a list of contiguous patches - patches = [np.ascontiguousarray(p) for p in strided_image] else: raise ValueError(f"Input type [{type(image)}] is not supported.") + x_step, y_step = steps + n_channels = image.shape[0] + strided_image = as_strided_func( + image, + (*self.grid, n_channels, split_size[0], split_size[1]), + (x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + ) + # Flatten the first two dimensions + strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) + # Make a list of contiguous patches + if isinstance(image, torch.Tensor): + patches = [p.contiguous() for p in strided_image] + elif isinstance(image, np.ndarray): + patches = [np.ascontiguousarray(p) for p in strided_image] + return patches def _get_params( diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index e0de31c96c..6801aa52e3 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -13,14 +13,16 @@ from monai.transforms.atmostonce.utils import value_to_tuple_range from monai.utils 
import TransformBackends -from monai.transforms import Affined, Affine, Flip -from monai.transforms.atmostonce.functional import croppad, resize, rotate, spacing, flip +from monai.transforms import Affined, Affine, Flip, RandSpatialCropSamplesd, RandRotated +from monai.transforms.atmostonce.functional import croppad, resize, rotate, zoom, spacing, flip from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply from monai.transforms.atmostonce.dictionary import Rotated from monai.transforms.compose import Compose from monai.utils.enums import GridSampleMode, GridSamplePadMode from monai.utils.mapping_stack import MatrixFactory +from monai.transforms.atmostonce.utility import CachedTransform, CacheMechanism + def get_img(size, dtype=torch.float32, offset=0): img = torch.zeros(size, dtype=dtype) @@ -154,44 +156,69 @@ def test_mult_matrices(self): class TestFunctional(unittest.TestCase): + def _test_functional_impl(self, + op, + image, + params, + expected_matrix): + r_image, r_transform, r_metadata = op(image, **params) + enumerate_results_of_op((r_image, r_transform, r_metadata)) + self.assertTrue(torch.allclose(r_transform, expected_matrix)) + # TODO: turn into proper test def test_spacing(self): - results = spacing(np.zeros((1, 24, 32), dtype=np.float32), - (0.5, 0.6), - (1.0, 1.0), - False, - "bilinear", - "border", - False) + kwargs = { + "pixdim": (0.5, 0.6), "src_pixdim": (1.0, 1.0), "diagonal": False, + "mode": "bilinear", "padding_mode": "border", "align_corners": None + } + expected_tx = torch.DoubleTensor([[2.0, 0.0, 0.0], + [0.0, 1.66666667, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(spacing, get_img((24, 32)), kwargs, expected_tx) # TODO: turn into proper test def test_resize(self): - results = resize(np.zeros((1, 24, 32), dtype=np.float32), - (40, 40), - "all", - "bilinear", - False) - enumerate_results_of_op(results) + kwargs = { + "spatial_size": (40, 40), "size_mode": "all", + "mode": "bilinear", "align_corners": None + } + expected_tx = torch.DoubleTensor([[1.66666667, 0.0, 0.0], + [0.0, 1.25, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(resize, get_img((24, 32)), kwargs, expected_tx) + # TODO: turn into proper test def test_rotate(self): - results = rotate(np.zeros((1, 64, 64), dtype=np.float32), - torch.pi / 4, - True, - "bilinear", - "border") - enumerate_results_of_op(results) + kwargs = { + "angle": torch.pi / 4, "keep_size": True, + "mode": "bilinear", "padding_mode": "border" + } + expected_tx = torch.DoubleTensor([[0.70710678, -0.70710678, 0.0], + [0.70710678, 0.70710678, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(rotate, get_img((24, 32)), kwargs, expected_tx) + + + def test_zoom(self): + # results = zoom(np.zeros((1, 64, 64), dtype=np.float32), + # 2, + # "bilinear", + # "zeros") + # enumerate_results_of_op(results) + kwargs = { + "factor": 2, "mode": "nearest", "padding_mode": "border", "keep_size": True + } + expected_tx = torch.DoubleTensor([[0.5, 0.0, 0.0], + [0.0, 0.5, 0.0], + [0.0, 0.0, 1.0]]) + self._test_functional_impl(zoom, get_img((24, 32)), kwargs, expected_tx) - results = rotate(np.zeros((1, 64, 64), dtype=np.float32), - torch.pi / 4, - False, - "bilinear", - "border") - enumerate_results_of_op(results) def _check_matrix(self, actual, expected): np.allclose(actual, expected) + def _test_rotate_90_impl(self, values, keep_dims, expected): results = rotate(np.zeros((1, 64, 64, 32), dtype=np.float32), values, @@ -410,8 +437,26 @@ def test_rotate_apply_lazy(self): data = get_img((32, 
32)) data = r(data) data = apply(data) - print(data.shape) - print(data) + expected = torch.DoubleTensor([[0.70710677, 0.70710677, 0.0, -15.61269784], + [-0.70710677, 0.70710677, 0.0, 15.5], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0]]) + self.assertTrue(torch.allclose(expected, data.affine)) + + def test_zoom_apply_lazy(self): + r = amoa.Zoom(2, + mode="bilinear", + padding_mode="border", + keep_size=False) + r.lazy_evaluation = True + data = get_img((32, 32)) + data = r(data) + data = apply(data) + expected = torch.DoubleTensor([[0.5, 0.0, 0.0, 11.75], + [0.0, 0.5, 0.0, 11.75], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0]]) + self.assertTrue(torch.allclose(expected, data.affine)) def test_crop_then_rotate_apply_lazy(self): data = get_img((32, 32)) @@ -430,6 +475,7 @@ def test_crop_then_rotate_apply_lazy(self): data2 = lr1(data1) datas.append(data2) + class TestDictionaryTransforms(unittest.TestCase): def test_rotate_numpy(self): @@ -509,3 +555,80 @@ def test_value_to_tuple_range(self): self.assertTupleEqual(value_to_tuple_range([4.3, -2.1]), (-2.1, 4.3)) self.assertTupleEqual(value_to_tuple_range((4.3, -2.1)), (-2.1, 4.3)) + +# Utility transforms for compose compiler +# ================================================================================================= + +class TestMemoryCacheMechanism(CacheMechanism): + + def __init__( + self, + max_count: int + ): + self.max_count = max_count + self.contents = dict() + self.order = list() + + def try_fetch( + self, + key + ): + if key in self.contents: + return True, self.contents[key] + + return False, None + + def store( + self, + key, + value + ): + if key in self.contents: + self.contents[key] = value + else: + if len(self.contents) >= self.max_count: + last = self.order.pop() + del self.contents[last] + + self.contents[key] = value + self.order.append(key) + + +class TestUtilityTransforms(unittest.TestCase): + + def test_cached_transform(self): + + def generate_noise(shape): + def _inner(*args, **kwargs): + return np.random.normal(size=shape) + return _inner + + ct = CachedTransform(transforms=generate_noise((1, 16, 16)), + cache=TestMemoryCacheMechanism(4)) + + first = ct("foo") + second = ct("foo") + third = ct("bar") + + self.assertIs(first, second) + self.assertIsNot(first, third) + + def test_multi_transform(self): + + def fake_multi_sample(keys, num_samples, roi_size): + def _inner(t): + for i in range(num_samples): + yield {'image': t[i:i+roi_size[0], i:i+roi_size[1]]} + return _inner + +# t1 = RandSpatialCropSamplesd(keys=('image',), num_samples=4, roi_size=(32, 32)) + t1 = fake_multi_sample(keys=('image',), num_samples=4, roi_size=(32, 32)) + t2 = RandRotated(keys=('image',), range_z=(-torch.pi/2, torch.pi/2)) + c = Compose([t1, t2]) + + d = torch.rand((1, 64, 64)) + + _d = d.data + _dd = d.data.clone() + d.data = _dd + r = c({'image': d}) From 6295e0252ae02d58ddbf1e1a89b30caae7eb000f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 14 Oct 2022 11:05:47 +0100 Subject: [PATCH 28/30] Adding RandCropPad and RandCropPadMultiSample array implementations --- monai/transforms/atmostonce/array.py | 379 +++++++++++++----- monai/transforms/atmostonce/compose.py | 35 ++ monai/transforms/atmostonce/functional.py | 23 ++ monai/transforms/atmostonce/lazy_transform.py | 5 +- monai/transforms/atmostonce/utility.py | 31 +- monai/transforms/transform.py | 3 +- tests/test_atmostonce.py | 220 +++++++++- 7 files changed, 561 insertions(+), 135 deletions(-) diff --git a/monai/transforms/atmostonce/array.py 
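Among the additions in this patch is a randomised crop that draws one start offset per spatial dimension. The sampling rule can be sketched independently of the transform machinery; this assumes channel-first images and uses an offset of zero for any dimension the crop fills exactly:

import numpy as np

def sample_crop_slices(spatial_shape, crop_shape, rng):
    # Offsets are drawn uniformly over the positions at which the crop
    # still fits inside the image.
    valid = [i - c for i, c in zip(spatial_shape, crop_shape)]
    offsets = [rng.integers(0, v + 1) if v > 0 else 0 for v in valid]
    return tuple(slice(o, o + c) for o, c in zip(offsets, crop_shape))

rng = np.random.default_rng(seed=0)
slices = sample_crop_slices((64, 64), (32, 32), rng)
assert all(s.stop - s.start == 32 for s in slices)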
b/monai/transforms/atmostonce/array.py index cd760ef25d..93dcbf3246 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -10,18 +10,59 @@ from monai.transforms import InvertibleTransform, RandomizableTransform from monai.transforms.atmostonce.apply import apply -from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate, rotate90, flip +from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate, rotate90, flip, \ + identity from monai.transforms.atmostonce.lazy_transform import LazyTransform +from monai.transforms.atmostonce.utility import IMultiSampleTransform, ILazyTransform, IRandomizableTransform from monai.transforms.atmostonce.utils import value_to_tuple_range from monai.utils import (GridSampleMode, GridSamplePadMode, - InterpolateMode, NumpyPadMode, PytorchPadMode) + InterpolateMode, NumpyPadMode, PytorchPadMode, look_up_option) from monai.utils.mapping_stack import MetaMatrix -from monai.utils.misc import ensure_tuple +from monai.utils.misc import ensure_tuple, ensure_tuple_rep # TODO: these transforms are intended to replace array transforms once development is done + +class Identity(LazyTransform, InvertibleTransform): + + def __init__( + self, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + lazy_evaluation: Optional[bool] = False + ): + LazyTransform.__init__(self, lazy_evaluation) + self.mode = mode + self.padding_mode = padding_mode + self.dtype = dtype + + def __call__( + self, + img: torch.Tensor, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None + ): + mode_ = mode or self.mode + padding_mode_ = padding_mode or self.mode + dtype_ = dtype or self.dtype + + img_t, transform, metadata = identity(img, mode_, padding_mode_, dtype_) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse(self, data): + return NotImplementedError() + + # spatial # ======= @@ -55,7 +96,7 @@ def __call__( mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, shape_override: Optional[Sequence] = None ): @@ -251,10 +292,10 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - factor: Optional[Union[Sequence[float], float]] = None, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, + factor: Optional[Union[Sequence[float], float]] = None, shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: @@ -339,9 +380,8 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) - if not self._do_transform: - return None - self.k = self.R.randint(self.max_k) + 1 + if self._do_transform: + self.k = self.R.randint(self.max_k) + 1 def __call__( self, @@ -390,10 +430,9 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if self._do_transform is True: + self.x, self.y, self.z = 0.0, 0.0, 0.0 - self.x, 
self.y, self.z = 0.0, 0.0, 0.0 - - if self._do_transform: self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) @@ -435,100 +474,142 @@ def __init__( spatial_axis: Optional[Union[Sequence[int], int]] = None ) -> None: RandomizableTransform.__init__(self, prob) + self.prob = prob self.spatial_axis = spatial_axis - + self.do_flip = False self.op = Flip(0, spatial_axis) + self.nop = Identity() + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + if not self._do_transform: + self.do_flip = self._do_transform + + def __call__( + self, + img: NdarrayOrTensor, + randomize: Optional[bool] = True + ): + if randomize: + self.randomize() + if self.do_flip is True: + return self.op(img, self.spatial_axis) + else: + return self.nop(img) + + return self.op(img, self.spatial_axis) + + def inverse( + self, + data: NdarrayOrTensor, + ): + raise NotImplementedError() + + +class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): + + def __init__( + self, + prob: float = 0.1 + ) -> None: + RandomizableTransform.__init__(self, prob) + self.prob = prob + self.spatial_axis = None + self.op = Flip(self.spatial_axis) + + def randomize( + self, + data: Optional[Any] = None + ) -> None: + super().randomize(None) + if self._do_transform: + self.spatial_axis = self.R.randint(0, data.ndim - 1) + + def __call__( + self, + img: NdarrayOrTensor, + randomize: Optional[bool] = True + ) -> NdarrayOrTensor: + if randomize: + self.randomize() + + if self._do_transform: + spatial_axis = self.spatial_axis + else: + spatial_axis = None + + return self.op(img, spatial_axis) + + def inverse( + self, + data: NdarrayOrTensor, + ): + raise NotImplementedError() -# class RandRotateOld(RandomizableTransform, InvertibleTransform, LazyTransform): -# -# def __init__( -# self, -# range_x: Optional[Union[Tuple[float, float], float]] = 0.0, -# range_y: Optional[Union[Tuple[float, float], float]] = 0.0, -# range_z: Optional[Union[Tuple[float, float], float]] = 0.0, -# prob: Optional[float] = 0.1, -# keep_size: bool = True, -# mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, -# padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, -# align_corners: bool = False, -# dtype: Union[DtypeLike, torch.dtype] = np.float32 -# ): -# RandomizableTransform.__init__(self, prob) -# self.range_x = ensure_tuple(range_x) -# if len(self.range_x) == 1: -# self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) -# self.range_y = ensure_tuple(range_y) -# if len(self.range_y) == 1: -# self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) -# self.range_z = ensure_tuple(range_z) -# if len(self.range_z) == 1: -# self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) -# -# self.keep_size = keep_size -# self.mode = mode -# self.padding_mode = padding_mode -# self.align_corners = align_corners -# self.dtype = dtype -# -# self.x = 0.0 -# self.y = 0.0 -# self.z = 0.0 -# -# def randomize(self, data: Optional[Any] = None) -> None: -# super().randomize(None) -# if not self._do_transform: -# return None -# self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) -# self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) -# self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) -# -# def __call__( -# self, -# img: NdarrayOrTensor, -# mode: Optional[Union[InterpolateMode, str]] = None, 
-# padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, -# align_corners: Optional[bool] = None, -# dtype: Optional[Union[DtypeLike, torch.dtype]] = None, -# randomize: Optional[bool] = True, -# get_matrix: Optional[bool] = False, -# shape_override: Optional[Sequence] = None -# ) -> NdarrayOrTensor: -# -# if randomize: -# self.randomize() -# -# img_dims = len(img.shape) - 1 -# if self._do_transform: -# angle = self.x if img_dims == 2 else (self.x, self.y, self.z) -# else: -# angle = 0 if img_dims == 2 else (0, 0, 0) -# -# mode = self.mode or mode -# padding_mode = self.padding_mode or padding_mode -# align_corners = self.align_corners or align_corners -# keep_size = self.keep_size -# dtype = self.dtype -# -# shape_override_ = shape_override -# if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): -# shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) -# -# img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, -# align_corners, dtype, shape_override_) -# -# # TODO: candidate for refactoring into a LazyTransform method -# img_t.push_pending_transform(MetaMatrix(transform, metadata)) -# if not self.lazy_evaluation: -# img_t = apply(img_t) -# -# return img_t -# -# def inverse( -# self, -# data: NdarrayOrTensor, -# ): -# raise NotImplementedError() + +class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): + + def __init__( + self, + prob: float = 0.1, + min_zoom: Optional[Union[Sequence[float], float]] = 0.9, + max_zoom: Optional[Union[Sequence[float], float]] = 1.1, + mode: Optional[Union[GridSampleMode, str]] = InterpolateMode.AREA, + padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = None, + keep_size: bool = True, + **kwargs + ) -> None: + RandomizableTransform.__init__(self, prob) + self.prob = prob + self.min_zoom = ensure_tuple(min_zoom) + self.max_zoom = ensure_tuple(max_zoom) + if len(self.min_zoom) != len(self.max_zoom): + raise AssertionError("min_zoom and max_zoom must have the same length ", + f"but are {min_zoom} and {max_zoom} respectively") + self.mode = look_up_option(mode, InterpolateMode) + self.padding_mode = padding_mode + self.align_corners = align_corners + self.keep_size = keep_size + self.factors = None + + self.op = Zoom(1.0, self.mode, self.padding_mode, self.align_corners, self.keep_size) + + def randomize( + self, + data: Optional[Any] = None + ) -> None: + super().randomize(None) + if not self._do_transform: + self.factors = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + if len(self.factors) == 1: + # to keep the spatial shape ratio, use same random zoom factor for all dims + self.factors = ensure_tuple_rep(self.factors[0], data.ndim - 1) + elif len(self.factors) == 2 and data.ndim > 3: + # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim + self.factors =\ + ensure_tuple_rep(self.factors[0], data.ndim - 2) + ensure_tuple(self.factors[-1]) + + def __call__( + self, + img: NdarrayOrTensor, + randomize: Optional[bool] = True + ) -> NdarrayOrTensor: + if randomize: + self.randomize(img) + + if self._do_transform: + factors_ = self.factors + else: + factors_ = 1.0 + + return self.op(img, factor=factors_) + + def inverse( + self, + data: NdarrayOrTensor, + ): + raise NotImplementedError() class Translate(LazyTransform, InvertibleTransform): @@ -586,7 +667,7 @@ def __init__( self, slices: 
Optional[Sequence[slice]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, - lazy_evaluation: Optional[bool] = True, + lazy_evaluation: Optional[bool] = True ): LazyTransform.__init__(self, lazy_evaluation) self.slices = slices @@ -618,3 +699,97 @@ def inverse( data: NdarrayOrTensor ): raise NotImplementedError() + + +class RandomCropPad(InvertibleTransform, RandomizableTransform, ILazyTransform): + + def __init__( + self, + sizes: Union[Sequence[int], int], + prob: Optional[float] = 0.1, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + lazy_evaluation: Optional[bool] = True + ): + RandomizableTransform.__init__(self, prob) + self.sizes = sizes + self.padding_mode = padding_mode + self.offsets = None + + self.op = CropPad(padding_mode=padding_mode, lazy_evaluation=lazy_evaluation) + + def randomize( + self, + img: torch.Tensor + ): + super().randomize(None) + if self._do_transform: + img_shape = img.shape[1:] + if isinstance(self.sizes, int): + crop_shape = tuple(self.sizes for _ in range(len(img_shape))) + else: + crop_shape = self.sizes + + valid_ranges = tuple(i - c for i, c in zip(img_shape, crop_shape)) + self.offsets = tuple(self.R.randint(0, r+1) if r > 0 else r for r in valid_ranges) + + def __call__( + self, + img: torch.Tensor, + randomize: Optional[bool] = True + ): + if randomize: + self.randomize(img) + + if self._do_transform: + offsets_ = self.offsets + slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes)) + return self.op(img, slices=slices) + else: + return self.op(img) + + def inverse( + self, + data: NdarrayOrTensor + ): + raise NotImplementedError() + + @property + def lazy_evaluation(self): + return self.op.lazy_evaluation + + +class RandomCropPadMultiSample( + InvertibleTransform, ILazyTransform, IRandomizableTransform, IMultiSampleTransform +): + + def __init__( + self, + sizes: Union[Sequence[int], int], + sample_count: int, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + lazy_evaluation: Optional[bool] = True + ): + self.sample_count = sample_count + self.op = RandomCropPad(sizes, 1.0, padding_mode, lazy_evaluation) + + def __call__( + self, + img: torch.Tensor, + randomize: Optional[bool] = True + ): + for i in range(self.sample_count): + yield self.op(img, randomize) + + def inverse( + self, + data: NdarrayOrTensor + ): + raise NotImplementedError() + + def set_random_state(self, seed=None, state=None): + self.op.set_random_state(seed, state) + + @property + def lazy_evaluation(self): + return self.op.lazy_evaluation + diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py index 0ec0367f8a..e46aa9d466 100644 --- a/monai/transforms/atmostonce/compose.py +++ b/monai/transforms/atmostonce/compose.py @@ -5,6 +5,7 @@ from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_lazy_transforms, flatten_sequences +from monai.transforms.atmostonce.utility import CachedTransformCompose from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED from monai.transforms import Randomizable, InvertibleTransform, OneOf, apply_transform @@ -80,6 +81,40 @@ def lazy_no_cache(transforms): return dest_transforms +class ComposeCompiler2: + + def compile(self, transforms, cache_mechanism): + + transforms_ = self.compile_caching(transforms, cache_mechanism) + + transforms__ = self.compile_multisampling(transforms_) + + transforms___ = 
self.compile_lazy_resampling(transforms__) + + return transforms___ + + def compile_caching(self, transforms, cache_stategy): + # given a list of transforms, determine where to add a cached transform object + # and what transforms to put in it + return transforms + + def compile_multisampling(self, transforms): + return transforms + + def compile_lazy_resampling(self, transforms): + return transforms + + def transform_is_container(self, t): + if isinstance(t, CachedTransform): + return True + return False + + def transform_is_multisampling(self, t): + # if isinstance(t, MultiSamplingTransform): + # return True + return False + + class Compose(Randomizable, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of callables together in diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py index 6dbcef3c0b..101fe58180 100644 --- a/monai/transforms/atmostonce/functional.py +++ b/monai/transforms/atmostonce/functional.py @@ -15,6 +15,29 @@ from monai.utils.mapping_stack import MatrixFactory +def identity( + img: torch.Tensor, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + mode_ = None if mode is None else look_up_option(mode, GridSampleMode) + padding_mode_ = None if padding_mode is None else look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) + + transform = MatrixFactory.from_tensor(img_).identity().matrix.matrix + + metadata = dict() + if mode_ is not None: + metadata["mode"] = mode_ + if padding_mode_ is not None: + metadata["padding_mode"] = padding_mode_ + metadata["dtype"] = dtype_ + return img_, transform, metadata + + def spacing( img: torch.Tensor, pixdim: Union[Sequence[float], float], diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py index a027fc978e..4f2179dd2c 100644 --- a/monai/transforms/atmostonce/lazy_transform.py +++ b/monai/transforms/atmostonce/lazy_transform.py @@ -2,6 +2,7 @@ from monai.data import MetaTensor from monai.transforms import Randomizable from monai.transforms.atmostonce.apply import Applyd +from monai.transforms.atmostonce.utility import ILazyTransform from monai.utils.mapping_stack import MetaMatrix @@ -63,9 +64,7 @@ def compile_cached_dataloading_transforms(transforms): flat.insert - - -class LazyTransform: +class LazyTransform(ILazyTransform): def __init__(self, lazy_evaluation): self.lazy_evaluation = lazy_evaluation diff --git a/monai/transforms/atmostonce/utility.py b/monai/transforms/atmostonce/utility.py index a8fde06104..4b7036e9db 100644 --- a/monai/transforms/atmostonce/utility.py +++ b/monai/transforms/atmostonce/utility.py @@ -1,10 +1,23 @@ from typing import Callable, Sequence +import abc from abc import ABC import torch +class ILazyTransform(abc.ABC): + pass + + +class IMultiSampleTransform(abc.ABC): + pass + + +class IRandomizableTransform(abc.ABC): + pass + + class CacheMechanism(ABC): """ The interface for caching mechanisms to be used with CachedTransform. This interface provides @@ -27,10 +40,10 @@ def store( raise NotImplementedError() -class CachedTransform: +class CachedTransformCompose: """ - CachedTransform provides the functionality to cache the output of one or more transforms such - that they only need to be run once. 
Each time that CachedTransform is run, it checks whether
+    CachedTransformCompose provides the functionality to cache the output of one or more transforms
+    such that they only need to be run once. Each time that CachedTransformCompose is run, it checks whether
     a cached entity is present, and if that entity is present, it loads it and returns the resulting
     tensor / tensors as output. If that entity is not present in the cache, it executes the transforms
     in its internal pipeline and caches the result before returning it.
@@ -66,11 +79,11 @@ def __call__(
         return result
 
 
-class MultiSampleTransform:
+class MultiSampleTransformCompose:
     """
-    Multi-sample takes the output of a transform that generates multiple samples and executes
-    each sample separately in a depth first fashion, gathering the results into an array that
-    is finally returned after all samples are processed
+    MultiSampleTransformCompose takes the output of a transform that generates multiple samples
+    and executes each sample separately in a depth-first fashion, gathering the results into an
+    array that is finally returned after all samples are processed.
     """
     def __init__(
         self,
@@ -88,8 +101,8 @@ def __call__(
     ):
         output = list()
         for mt in self.multi_sample(t):
-            mt_out = self.multi_sample(mt)
-            if isinstance(mt_out, torch.Tensor):
+            mt_out = self.transforms(mt)
+            if isinstance(mt_out, (torch.Tensor, dict)):
                 output.append(mt_out)
             elif isinstance(mt_out, list):
                 output += mt_out
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 730cb634c0..4c284619bd 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -22,6 +22,7 @@
 from monai import config, transforms
 from monai.config import KeysCollection
 from monai.data.meta_tensor import MetaTensor
+from monai.transforms.atmostonce.utility import IRandomizableTransform
 from monai.utils import MAX_SEED, ensure_tuple, first
 from monai.utils.enums import TransformBackends
 
@@ -243,7 +244,7 @@ def __call__(self, data: Any):
         raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
 
 
-class RandomizableTransform(Randomizable, Transform):
+class RandomizableTransform(Randomizable, Transform, IRandomizableTransform):
     """
     An interface for handling random state locally, currently based on a class
     variable `R`, which is an instance of `np.random.RandomState`. 
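
The IRandomizableTransform marker added to RandomizableTransform above is what allows a
compose compiler to detect where a pipeline stops being deterministic without inspecting
concrete transform types. A minimal, self-contained sketch of that classification rule
(the names below are illustrative, not the MONAI API):

    import abc
    import random

    class IRandomizableTransform(abc.ABC):
        """Marker base: transforms that draw random parameters."""

    class ScaleIntensity:                       # deterministic, safe to cache
        def __call__(self, x):
            return x * 2.0

    class RandNoise(IRandomizableTransform):    # randomized, must re-run every time
        def __call__(self, x):
            return x + random.random()

    def split_for_caching(transforms):
        # everything before the first randomized transform can be computed once
        # and cached; the remainder has to run on every invocation
        for i, t in enumerate(transforms):
            if isinstance(t, IRandomizableTransform):
                return list(transforms[:i]), list(transforms[i:])
        return list(transforms), []

    cacheable, dynamic = split_for_caching([ScaleIntensity(), ScaleIntensity(), RandNoise()])
    assert len(cacheable) == 2 and len(dynamic) == 1

This mirrors the behaviour that compile_caching relies on later in the series: the
deterministic prefix is wrapped for caching, and the first randomizable transform marks
the boundary.
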
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py index 6801aa52e3..b46657ffa3 100644 --- a/tests/test_atmostonce.py +++ b/tests/test_atmostonce.py @@ -13,15 +13,46 @@ from monai.transforms.atmostonce.utils import value_to_tuple_range from monai.utils import TransformBackends +from monai.transforms.spatial import array as spatialarray from monai.transforms import Affined, Affine, Flip, RandSpatialCropSamplesd, RandRotated from monai.transforms.atmostonce.functional import croppad, resize, rotate, zoom, spacing, flip from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply from monai.transforms.atmostonce.dictionary import Rotated from monai.transforms.compose import Compose from monai.utils.enums import GridSampleMode, GridSamplePadMode -from monai.utils.mapping_stack import MatrixFactory +from monai.utils.mapping_stack import MatrixFactory, MetaMatrix -from monai.transforms.atmostonce.utility import CachedTransform, CacheMechanism +from monai.transforms.atmostonce.utility import CachedTransformCompose, CacheMechanism, MultiSampleTransformCompose + + +class FakeRand(np.random.RandomState): + + def __init__(self, + rands=tuple(), + randints=tuple(), + uniforms=tuple() + ): + self.rands = rands + self.randind = 0 + self.randints = randints + self.randintind = 0 + self.uniforms = uniforms + self.uniformind = 0 + + def rand(self, *_, **__): + value = self.rands[self.randind] + self.randind += 1 + return value + + def randint(self, *_, **__): + value = self.randints[self.randintind] + self.randintind += 1 + return value + + def uniform(self, *_, **__): + value = self.uniforms[self.uniformind] + self.uniformind += 1 + return value def get_img(size, dtype=torch.float32, offset=0): @@ -59,6 +90,35 @@ def matrices_nearly_equal(actual, expected): f"{actual} vs {expected} respectively") +def test_array_op_multi_sample(tester, op, img, expected): + + s = 0 + for actual in op(img): + e = expected[s] + s += 1 + if op.lazy_evaluation is True: + actual = apply(actual) + + if not torch.allclose(actual, e): + print("torch.allclose test returned False") + print(actual) + print(e) + tester.assertTrue(False) + + +def test_array_op(tester, op, img, expected): + actual = op(img) + + if op.lazy_evaluation is True: + actual = apply(actual) + + if not torch.allclose(actual, expected): + print("torch.allclose test returned False") + print(actual) + print(expected) + tester.assertTrue(False) + + class TestLowLevel(unittest.TestCase): def test_extents_2(self): @@ -374,20 +434,6 @@ def test_flip(self): class TestArrayTransforms(unittest.TestCase): - # TODO: amo: add tests for matrix and result size - def test_croppad(self): - img = get_img((15, 15)).astype(int) - results = croppad(img, (slice(4, 8), slice(3, 9))) - enumerate_results_of_op(results) - m = results[1].matrix.matrix - # print(m) - result_size = results[2]['spatial_shape'] - a = Affine(affine=m, - padding_mode=GridSamplePadMode.ZEROS, - spatial_size=result_size) - img_, _ = a(img) - # print(img_.numpy()) - def test_apply(self): img = get_img((16, 16)) r = Rotate(torch.pi / 4, @@ -417,6 +463,36 @@ def test_rand_rotate(self): enumerate_results_of_op(results) enumerate_results_of_op(results.pending_transforms[-1].metadata) + def test_rand_zoom(self): + r = amoa.RandZoom(prob=1.0, + min_zoom=0.9, + max_zoom=1.1, + mode="nearest", + padding_mode="zeros", + keep_size=True) + + r.set_random_state(state=FakeRand((0.5,), (1.05,))) + img = np.zeros((1, 32, 32)) + results = r(img) + 
enumerate_results_of_op(results) + enumerate_results_of_op(results.pending_transforms[-1].metadata) + + + # TODO: amo: add tests for matrix and result size + def test_croppad(self): + img = get_img((15, 15)).astype(int) + results = croppad(img, (slice(4, 8), slice(3, 9))) + enumerate_results_of_op(results) + m = results[1].matrix.matrix + # print(m) + result_size = results[2]['spatial_shape'] + a = Affine(affine=m, + padding_mode=GridSamplePadMode.ZEROS, + spatial_size=result_size) + img_, _ = a(img) + # print(img_.numpy()) + + def test_rotate_apply_not_lazy(self): r = amoa.Rotate(-torch.pi / 4, mode="bilinear", @@ -476,6 +552,20 @@ def test_crop_then_rotate_apply_lazy(self): datas.append(data2) +class TestOldTransforms(unittest.TestCase): + + def test_rand_zoom(self): + + r = spatialarray.RandZoom(1.0, 0.9, 1.1) + t = torch.rand((1, 32, 32)) + t_out = r(t) + print(t_out.shape) + + r = spatialarray.RandZoom(1.0, (0.9, 0.9, 0.9), (1.1, 1.1, 1.1)) + t_out = r(t) + print(t_out.shape) + + class TestDictionaryTransforms(unittest.TestCase): def test_rotate_numpy(self): @@ -556,6 +646,95 @@ def test_value_to_tuple_range(self): self.assertTupleEqual(value_to_tuple_range((4.3, -2.1)), (-2.1, 4.3)) +class TestCropPad(unittest.TestCase): + + def _test_functional(self, targs, img, expected): + result, tx, md = amoa.croppad(img, **targs) + result.push_pending_transform(MetaMatrix(tx, md)) + actual = apply(result) + if not torch.allclose(actual, expected): + print("torch.allclose test returned False") + print(actual) + print(expected) + self.assertTrue(False) + + def _test_rand(self, targs, rng_fac, img, expected): + targs['lazy_evaluation'] = False + r = amoa.RandomCropPad(**targs) + r.set_random_state(state=rng_fac()) + actual = r(img) + + if not torch.allclose(actual, expected): + print("torch.allclose test returned False") + print(actual) + print(expected) + self.assertTrue(False) + + targs['lazy_evaluation'] = True + r = amoa.RandomCropPad(**targs) + # a = amoa.apply() + r.set_random_state(state=rng_fac()) + actual = amoa.apply(r(img)) + + if not torch.allclose(actual, expected): + print("torch.allclose test returned False") + print(actual) + print(expected) + self.assertTrue(False) + + def test_croppad_all_valid(self): + targs = {'slices': None, 'padding_mode': 'zeros'} + img = get_img((16, 16)) + for j in range(8): + for i in range(8): + expected = torch.FloatTensor( + [[(i + j * 16) + ii + jj * 16 for ii in range(8)] for jj in range(8)]) + targs['slices'] = (slice(i, i+8), slice(j, j+8)) + self._test_functional(targs, img, expected) + + def test_randcroppad(self): + targs = {'sizes': (8, 8), 'prob': 1.0, 'padding_mode': 'zeros'} + rng_fac = lambda: FakeRand(rands=(0.5,), randints=(2, 6)) + img = get_img((16, 16)) + expected = torch.FloatTensor([[98 + i + j * 16 for i in range(8)] for j in range(8)]) + + self._test_rand(targs, rng_fac, img, expected) + + def test_randcroppad_ysmall(self): + targs = {'sizes': (8, 8), 'prob': 1.0, 'padding_mode': 'zeros'} + rng_fac = lambda: FakeRand(rands=(0.5,), randints=(6,)) + img = get_img((16, 6)) + expected = torch.FloatTensor([[102 + i + j * 16 for i in range(8)] for j in range(8)]) + + self._test_rand(targs, rng_fac, img, expected) + + def test_rand_croppad(self): + r = amoa.RandomCropPad((8, 8), 1.0, padding_mode="zeros", lazy_evaluation=False) + rng = FakeRand(rands=(0.5,), randints=(2, 6)) + r.set_random_state(state=rng) + + img = get_img((16, 16)) + + actual = r(img) + expected = torch.FloatTensor([[102 + i + j * 16 for i in range(8)] for j in 
range(8)]) + print(actual) + print(expected) + self.assertTrue(torch.allclose(actual, expected)) + + def test_randcroppadmulti(self): + op = amoa.RandomCropPadMultiSample((8, 8), 4, padding_mode="zeros", lazy_evaluation=False) + rng = FakeRand(rands=(0.5, 0.5, 0.5, 0.5), randints=(2, 6, 3, 5, 4, 4, 5, 3)) + op.set_random_state(state=rng) + img = get_img((16, 16)) + expected = [ + torch.FloatTensor([[38 + i + j * 16 for i in range(8)] for j in range(8)]), + torch.FloatTensor([[53 + i + j * 16 for i in range(8)] for j in range(8)]), + torch.FloatTensor([[68 + i + j * 16 for i in range(8)] for j in range(8)]), + torch.FloatTensor([[83 + i + j * 16 for i in range(8)] for j in range(8)]) + ] + test_array_op_multi_sample(self, op, img, expected) + + # Utility transforms for compose compiler # ================================================================================================= @@ -603,8 +782,8 @@ def _inner(*args, **kwargs): return np.random.normal(size=shape) return _inner - ct = CachedTransform(transforms=generate_noise((1, 16, 16)), - cache=TestMemoryCacheMechanism(4)) + ct = CachedTransformCompose(transforms=generate_noise((1, 16, 16)), + cache=TestMemoryCacheMechanism(4)) first = ct("foo") second = ct("foo") @@ -618,13 +797,14 @@ def test_multi_transform(self): def fake_multi_sample(keys, num_samples, roi_size): def _inner(t): for i in range(num_samples): - yield {'image': t[i:i+roi_size[0], i:i+roi_size[1]]} + yield {'image': t['image'][i:i+roi_size[0], i:i+roi_size[1]]} return _inner # t1 = RandSpatialCropSamplesd(keys=('image',), num_samples=4, roi_size=(32, 32)) t1 = fake_multi_sample(keys=('image',), num_samples=4, roi_size=(32, 32)) t2 = RandRotated(keys=('image',), range_z=(-torch.pi/2, torch.pi/2)) - c = Compose([t1, t2]) + mst = MultiSampleTransformCompose(t1, Compose([t2])) + c = Compose([mst]) d = torch.rand((1, 64, 64)) From b76e965c4429882425f029a993153691b3902cbf Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 19 Oct 2022 14:50:35 +0100 Subject: [PATCH 29/30] Compose compile; initial multisample generic croppad (array and dict) implementations --- monai/transforms/atmostonce/apply.py | 27 +++ monai/transforms/atmostonce/array.py | 11 +- monai/transforms/atmostonce/compose.py | 192 +++++++++++++--------- monai/transforms/atmostonce/dictionary.py | 152 ++++++++++++++++- tests/test_atmostonce.py | 167 ++++++++++++++++++- 5 files changed, 456 insertions(+), 93 deletions(-) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index 453d6750d2..d5c6e5f0ec 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -110,6 +110,14 @@ def matrix_from_matrix_container(matrix): def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[dict] = None): + + if isinstance(data, dict): + rd = dict() + for k, v in data.items(): + result = apply(v) + rd[k] = result + return rd + pending_ = pending pending_ = data.pending_transforms @@ -188,17 +196,36 @@ def apply(data: Union[torch.Tensor, MetaTensor], return data +# make Apply universal for arrays and dictionaries; it just calls through to functional apply class Apply(InvertibleTransform): def __init__(self): super().__init__() + def __call__(self, *args, **kwargs): + return apply(*args, **kwargs) + + def inverse(self, data): + return NotImplementedError() + class Applyd(MapTransform, InvertibleTransform): def __init__(self): super().__init__() + def __call__( + self, + d: dict + ): + rd = dict() + for k, v in d.items(): + rd[k] = apply(v) + + def 
inverse(self, data): + return NotImplementedError() + + # class Applyd(MapTransform, InvertibleTransform): # # def __init__(self, diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index 93dcbf3246..e5361fe2dd 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -737,15 +737,20 @@ def __call__( img: torch.Tensor, randomize: Optional[bool] = True ): + + img_shape = img.shape[:1] + if randomize: self.randomize(img) if self._do_transform: offsets_ = self.offsets - slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes)) - return self.op(img, slices=slices) else: - return self.op(img) + # center crop if this sample isn't random + offsets_ = tuple((i - s) // 2 for i, s in zip(img_shape, self.sizes)) + slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes)) + return self.op(img, slices=slices) + def inverse( self, diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py index e46aa9d466..0576656870 100644 --- a/monai/transforms/atmostonce/compose.py +++ b/monai/transforms/atmostonce/compose.py @@ -3,9 +3,10 @@ import numpy as np - +from monai.transforms.atmostonce.apply import Apply from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_lazy_transforms, flatten_sequences -from monai.transforms.atmostonce.utility import CachedTransformCompose +from monai.transforms.atmostonce.utility import CachedTransformCompose, MultiSampleTransformCompose, \ + IMultiSampleTransform, IRandomizableTransform, ILazyTransform from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED from monai.transforms import Randomizable, InvertibleTransform, OneOf, apply_transform @@ -13,75 +14,75 @@ # TODO: this is intended to replace Compose once development is done -class ComposeCompiler: - """ - Args: - transforms: A sequence of callable transforms - lazy_resampling: Whether to resample the data after each transform or accumulate - changes and then resample according to the accumulated changes as few times as - possible. Defaults to True as this nearly always improves speed and quality - caching_policy: Whether to cache deterministic transforms before the first - randomised transforms. This can be one of "off", "drive", "memory" - caching_favor: Whether to cache primarily for "speed" or for "quality". 
"speed" will - favor doing more work before caching, whereas "quality" will favour delaying - resampling until after caching - """ - def __init__( - self, - transforms: Union[Sequence[Callable], Callable], - lazy_resampling: Optional[bool] = True, - caching_policy: Optional[str] = "off", - caching_favor: Optional[str] = "quality" - ): - valid_caching_policies = ("off", "drive", "memory") - if caching_policy not in valid_caching_policies: - raise ValueError("parameter 'caching_policy' must be one of " - f"{valid_caching_policies} but is '{caching_policy}'") - - dest_transforms = None - - if caching_policy == "off": - if lazy_resampling is False: - dest_transforms = [t for t in transforms] - else: - dest_transforms = ComposeCompiler.lazy_no_cache() - else: - if caching_policy == "drive": - raise NotImplementedError() - elif caching_policy == "memory": - raise NotImplementedError() - - self.src_transforms = [t for t in transforms] - self.dest_transforms = dest_transforms - - def __getitem__( - self, - index - ): - return self.dest_transforms[index] - - def __len__(self): - return len(self.dest_transforms) - - @staticmethod - def lazy_no_cache(transforms): - dest_transforms = [] - # TODO: replace with lazy transform - cur_lazy = [] - for i_t in range(1, len(transforms)): - if isinstance(transforms[i_t], LazyTransform): - # add this to the stack of transforms to be handled lazily - cur_lazy.append(transforms[i_t]) - else: - if len(cur_lazy) > 0: - dest_transforms.append(cur_lazy) - # TODO: replace with lazy transform - cur_lazy = [] - dest_transforms.append(transforms[i_t]) - return dest_transforms +# class ComposeCompiler: +# """ +# Args: +# transforms: A sequence of callable transforms +# lazy_resampling: Whether to resample the data after each transform or accumulate +# changes and then resample according to the accumulated changes as few times as +# possible. Defaults to True as this nearly always improves speed and quality +# caching_policy: Whether to cache deterministic transforms before the first +# randomised transforms. This can be one of "off", "drive", "memory" +# caching_favor: Whether to cache primarily for "speed" or for "quality". 
"speed" will +# favor doing more work before caching, whereas "quality" will favour delaying +# resampling until after caching +# """ +# def __init__( +# self, +# transforms: Union[Sequence[Callable], Callable], +# lazy_resampling: Optional[bool] = True, +# caching_policy: Optional[str] = "off", +# caching_favor: Optional[str] = "quality" +# ): +# valid_caching_policies = ("off", "drive", "memory") +# if caching_policy not in valid_caching_policies: +# raise ValueError("parameter 'caching_policy' must be one of " +# f"{valid_caching_policies} but is '{caching_policy}'") +# +# dest_transforms = None +# +# if caching_policy == "off": +# if lazy_resampling is False: +# dest_transforms = [t for t in transforms] +# else: +# dest_transforms = ComposeCompiler.lazy_no_cache() +# else: +# if caching_policy == "drive": +# raise NotImplementedError() +# elif caching_policy == "memory": +# raise NotImplementedError() +# +# self.src_transforms = [t for t in transforms] +# self.dest_transforms = dest_transforms +# +# def __getitem__( +# self, +# index +# ): +# return self.dest_transforms[index] +# +# def __len__(self): +# return len(self.dest_transforms) +# +# @staticmethod +# def lazy_no_cache(transforms): +# dest_transforms = [] +# # TODO: replace with lazy transform +# cur_lazy = [] +# for i_t in range(1, len(transforms)): +# if isinstance(transforms[i_t], LazyTransform): +# # add this to the stack of transforms to be handled lazily +# cur_lazy.append(transforms[i_t]) +# else: +# if len(cur_lazy) > 0: +# dest_transforms.append(cur_lazy) +# # TODO: replace with lazy transform +# cur_lazy = [] +# dest_transforms.append(transforms[i_t]) +# return dest_transforms -class ComposeCompiler2: +class ComposeCompiler: def compile(self, transforms, cache_mechanism): @@ -93,26 +94,59 @@ def compile(self, transforms, cache_mechanism): return transforms___ - def compile_caching(self, transforms, cache_stategy): + def compile_caching(self, transforms, cache_mechanism): + # TODO: handle being passed a transform list with containers # given a list of transforms, determine where to add a cached transform object # and what transforms to put in it - return transforms + cacheable = list() + for t in transforms: + if self.transform_is_random(t) is False: + cacheable.append(t) + else: + break + + if len(cacheable) == 0: + return list(transforms) + else: + return [CachedTransformCompose(cacheable, cache_mechanism)] + transforms[len(cacheable):] def compile_multisampling(self, transforms): - return transforms + for i in reversed(range(len(transforms))): + if self.transform_is_multisampling(transforms[i]) is True: + transforms_ = transforms[:i] + [MultiSampleTransformCompose(transforms[i], + transforms[i+1:])] + return self.compile_multisampling(transforms_) + + return list(transforms) def compile_lazy_resampling(self, transforms): - return transforms + result = list() + lazy = list() + for i in range(len(transforms)): + if self.transform_is_lazy(transforms[i]): + lazy.append(transforms[i]) + else: + if len(lazy) > 0: + result.extend(lazy) + result.append(Apply()) + lazy = list() + result.append(transforms[i]) + if len(lazy) > 0: + result.extend(lazy) + result.append(Apply()) + return result + + def transform_is_random(self, t): + return isinstance(t, IRandomizableTransform) def transform_is_container(self, t): - if isinstance(t, CachedTransform): - return True - return False + return isinstance(t, CachedTransformCompose, MultiSampleTransformCompose) def transform_is_multisampling(self, t): - # if isinstance(t, 
MultiSamplingTransform): - # return True - return False + return isinstance(t, IMultiSampleTransform) + + def transform_is_lazy(self, t): + return isinstance(t, ILazyTransform) class Compose(Randomizable, InvertibleTransform): diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py index 408e7b8c64..131f442fd0 100644 --- a/monai/transforms/atmostonce/dictionary.py +++ b/monai/transforms/atmostonce/dictionary.py @@ -4,13 +4,14 @@ import torch -from monai.transforms.atmostonce.array import Rotate, Resize, Spacing, Zoom +from monai.transforms.atmostonce.array import Rotate, Resize, Spacing, Zoom, CropPad +from monai.transforms.atmostonce.utility import ILazyTransform, IRandomizableTransform, IMultiSampleTransform from monai.utils import ensure_tuple_rep from monai.config import KeysCollection, DtypeLike, SequenceStr from monai.transforms.atmostonce.lazy_transform import LazyTransform from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform +from monai.transforms.transform import MapTransform, RandomizableTransform from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, \ PytorchPadMode from monai.utils.mapping_stack import MatrixFactory @@ -36,6 +37,7 @@ def get_backend_from_data(data): msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" raise ValueError(msg.format(type(data))) + # TODO: reconcile multiple definitions to one in utils def expand_potential_tuple(keys, value): if not isinstance(value, (tuple, list)): @@ -43,6 +45,16 @@ def expand_potential_tuple(keys, value): return value +def keys_to_process( + keys: Sequence[str], + dictionary: dict, + allow_missing_keys: bool, +): + if allow_missing_keys is True: + return {k for k in keys if k in dictionary} + return keys + + # class MappingStackTransformd(MapTransform, InvertibleTransform): # # def __init__(self, @@ -266,3 +278,139 @@ def __call__(self, d: Mapping): rd[k] = data return rd + + +class CropPadd(MapTransform, InvertibleTransform, ILazyTransform): + + def __init__( + self, + keys: KeysCollection, + slices: Optional[Sequence[slice]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + lazy_evaluation: Optional[bool] = True + ): + self.keys = keys + self.slices = slices + self.padding_modes = padding_mode + self.lazy_evaluation = lazy_evaluation + + + def __call__( + self, + d: dict + ): + keys = keys_to_process(self.keys, d, self.allow_missing_keys) + + rd = dict(d) + for ik, k in enumerate(keys): + tx = CropPad(slices=self.slices, + padding_mode=self.padding_modes, + lazy_evaluation=self.lazy_evaluation) + + rd[k] = tx(d[k]) + + return rd + + +class RandomCropPadd(MapTransform, InvertibleTransform, RandomizableTransform, ILazyTransform): + + def __init__( + self, + keys: KeysCollection, + sizes: Union[Sequence[int], int], + prob: Optional[float] = 0.1, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + allow_missing_keys: bool=False, + lazy_evaluation: Optional[bool] = True + ): + RandomizableTransform.__init__(self, prob) + self.keys = keys + self.sizes = sizes + self.padding_mode = padding_mode + self.offsets = None + self.allow_missing_keys = allow_missing_keys + + self.op = CropPad(None, padding_mode) + + def randomize( + self, + img: torch.Tensor, + ): + super().randomize(None) + if self._do_transform: + img_shape = img.shape[1:] + if isinstance(self.sizes, int): + crop_shape = 
tuple(self.sizes for _ in range(len(img_shape)))
+            else:
+                crop_shape = self.sizes
+
+            valid_ranges = tuple(i - c for i, c in zip(img_shape, crop_shape))
+            self.offsets = tuple(self.R.randint(0, r+1) if r > 0 else r for r in valid_ranges)
+
+    def __call__(
+        self,
+        d: dict,
+        randomize: Optional[bool] = True
+    ):
+        keys = list(keys_to_process(self.keys, d, self.allow_missing_keys))
+
+        img = d[keys[0]]
+        img_shape = img.shape[1:]
+
+        if randomize:
+            self.randomize(img)
+
+        if self._do_transform:
+            offsets_ = self.offsets
+        else:
+            # center crop if this sample isn't random
+            offsets_ = tuple((i - s) // 2 for i, s in zip(img_shape, self.sizes))
+
+        slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes))
+
+        rd = dict(d)
+        for k in keys:
+            rd[k] = self.op(d[k], slices=slices)
+
+        return rd
+
+    @property
+    def lazy_evaluation(self):
+        return self.op.lazy_evaluation
+
+
+class RandomCropPadMultiSampled(
+    InvertibleTransform, IRandomizableTransform, ILazyTransform, IMultiSampleTransform
+):
+
+    def __init__(
+        self,
+        keys: Sequence[str],
+        sizes: Union[Sequence[int], int],
+        sample_count: int,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        self.sample_count = sample_count
+        self.op = RandomCropPadd(keys, sizes, 1.0, padding_mode, lazy_evaluation=lazy_evaluation)
+
+    def __call__(
+        self,
+        d: dict,
+        randomize: Optional[bool] = True
+    ):
+        for _ in range(self.sample_count):
+            yield self.op(d, randomize)
+
+    def inverse(
+        self,
+        data: dict
+    ):
+        raise NotImplementedError()
+
+    def set_random_state(self, seed=None, state=None):
+        self.op.set_random_state(seed, state)
+
+    @property
+    def lazy_evaluation(self):
+        return self.op.lazy_evaluation
\ No newline at end of file
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index b46657ffa3..035b235295 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -2,13 +2,15 @@
 
 import math
 
 import numpy as np
 
 import torch
 
 from monai.transforms.atmostonce import array as amoa
+from monai.transforms.atmostonce import dictionary as amod
 from monai.transforms.atmostonce.array import Rotate, CropPad
+from monai.transforms.atmostonce.compose import ComposeCompiler
 from monai.transforms.atmostonce.lazy_transform import compile_lazy_transforms
 from monai.transforms.atmostonce.utils import value_to_tuple_range
 from monai.utils import TransformBackends
@@ -16,13 +19,15 @@
 from monai.transforms.spatial import array as spatialarray
 from monai.transforms import Affined, Affine, Flip, RandSpatialCropSamplesd, RandRotated
 from monai.transforms.atmostonce.functional import croppad, resize, rotate, zoom, spacing, flip
-from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply
+from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply, Apply
 from monai.transforms.atmostonce.dictionary import Rotated
+import monai.transforms.croppad.array as ocpa
 from monai.transforms.compose import Compose
 from monai.utils.enums import GridSampleMode, GridSamplePadMode
 from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
 
-from monai.transforms.atmostonce.utility import CachedTransformCompose, CacheMechanism, MultiSampleTransformCompose
+from monai.transforms.atmostonce.utility import CachedTransformCompose, CacheMechanism, MultiSampleTransformCompose, \
+    IMultiSampleTransform, IRandomizableTransform, ILazyTransform
 
 
 class 
FakeRand(np.random.RandomState): @@ -99,11 +104,19 @@ def test_array_op_multi_sample(tester, op, img, expected): if op.lazy_evaluation is True: actual = apply(actual) - if not torch.allclose(actual, e): - print("torch.allclose test returned False") - print(actual) - print(e) - tester.assertTrue(False) + if isinstance(e, dict): + for k, v in e.items(): + if not torch.allclose(actual[k], v): + print("torch.allclose test returned False") + print(actual) + print(e) + tester.assertTrue(False) + else: + if not torch.allclose(actual, e): + print("torch.allclose test returned False") + print(actual) + print(e) + tester.assertTrue(False) def test_array_op(tester, op, img, expected): @@ -434,7 +447,7 @@ def test_flip(self): class TestArrayTransforms(unittest.TestCase): - def test_apply(self): + def test_apply_function(self): img = get_img((16, 16)) r = Rotate(torch.pi / 4, keep_size=False, @@ -565,6 +578,16 @@ def test_rand_zoom(self): t_out = r(t) print(t_out.shape) + def test_center_spatial_crop(self): + r = ocpa.CenterSpatialCrop(4) + img = get_img((8, 8)) + result = r(img) + print(result) + + img = get_img((9, 9)) + result = r(img) + print(result) + class TestDictionaryTransforms(unittest.TestCase): @@ -734,6 +757,26 @@ def test_randcroppadmulti(self): ] test_array_op_multi_sample(self, op, img, expected) + def test_randcropppadmultid(self): + op = amod.RandomCropPadMultiSampled(('img', 'lbl'), + (8, 8), + 4, + padding_mode="zeros", + lazy_evaluation=False) + rng = FakeRand(rands=(0.5, 0.5, 0.5, 0.5), randints=(2, 6, 3, 5, 4, 4, 5, 3)) + op.set_random_state(state=rng) + img = get_img((16, 16)) + lbl = get_img((16, 16)) + d = {'img': img, 'lbl': lbl} + expected_ts = [ + torch.FloatTensor([[38 + i + j * 16 for i in range(8)] for j in range(8)]), + torch.FloatTensor([[53 + i + j * 16 for i in range(8)] for j in range(8)]), + torch.FloatTensor([[68 + i + j * 16 for i in range(8)] for j in range(8)]), + torch.FloatTensor([[83 + i + j * 16 for i in range(8)] for j in range(8)]) + ] + expected = [{'img': e, 'lbl': e} for e in expected_ts] + test_array_op_multi_sample(self, op, d, expected) + # Utility transforms for compose compiler # ================================================================================================= @@ -797,7 +840,7 @@ def test_multi_transform(self): def fake_multi_sample(keys, num_samples, roi_size): def _inner(t): for i in range(num_samples): - yield {'image': t['image'][i:i+roi_size[0], i:i+roi_size[1]]} + yield {'image': t['image'][0:1, i:i+roi_size[0], i:i+roi_size[1]]} return _inner # t1 = RandSpatialCropSamplesd(keys=('image',), num_samples=4, roi_size=(32, 32)) @@ -812,3 +855,109 @@ def _inner(t): _dd = d.data.clone() d.data = _dd r = c({'image': d}) + print(r) + + def test_compile_caching(self): + class NotRandomizable: + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"NR<{self.name}>" + + class Randomizable(IRandomizableTransform): + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"R<{self.name}>" + + a = NotRandomizable("a") + b = NotRandomizable("b") + c = Randomizable("c") + d = Randomizable("d") + e = NotRandomizable("e") + + source_transforms = [a, b, c, d, e] + + cc = ComposeCompiler() + + actual = cc.compile_caching(source_transforms, CacheMechanism()) + + self.assertIsInstance(actual[0], CachedTransformCompose) + self.assertEqual(len(actual[0].transforms), 2) + self.assertTrue(actual[0].transforms[0], a) + self.assertTrue(actual[0].transforms[1], b) + self.assertTrue(actual[1], c) + 
self.assertTrue(actual[2], d) + self.assertTrue(actual[3], e) + + + def test_compile_multisampling(self): + class NotMultiSampling: + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"NMS<{self.name}>" + + class MultiSampling(IMultiSampleTransform): + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"MS<{self.name}>" + + a = NotMultiSampling("a") + b = NotMultiSampling("b") + c = MultiSampling("c") + d = NotMultiSampling("d") + e = MultiSampling("e") + f = NotMultiSampling("f") + + source_transforms = [a, b, c, d, e, f] + + cc = ComposeCompiler() + + actual = cc.compile_multisampling(source_transforms) + + self.assertEqual(actual[0], a) + self.assertEqual(actual[1], b) + self.assertIsInstance(actual[2], MultiSampleTransformCompose) + self.assertEqual(actual[2].multi_sample, c) + self.assertEqual(len(actual[2].transforms), 2) + self.assertEqual(actual[2].transforms[0], d) + self.assertIsInstance(actual[2].transforms[1], MultiSampleTransformCompose) + self.assertEqual(actual[2].transforms[1].multi_sample, e) + self.assertEqual(len(actual[2].transforms[1].transforms), 1) + self.assertEqual(actual[2].transforms[1].transforms[0], f) + + def test_compile_lazy_resampling(self): + class NotLazy: + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"NL<{self.name}>" + + class Lazy(ILazyTransform): + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"L<{self.name}>" + + a = NotLazy("a") + b = Lazy("b") + c = Lazy("c") + d = NotLazy("d") + e = Lazy("e") + f = Lazy("f") + + source_transforms = [a, b, c, d, e, f] + + cc = ComposeCompiler() + + actual = cc.compile_lazy_resampling(source_transforms) + + print(actual) \ No newline at end of file From 62f8172f020c3c8b6ac182b6b4ca5c3805f9877b Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 26 Oct 2022 17:25:28 +0100 Subject: [PATCH 30/30] More work towards lazy resampling --- monai/data/meta_tensor.py | 3 +- monai/transforms/atmostonce/apply.py | 71 +----- monai/transforms/atmostonce/array.py | 253 +++++++++++++++++---- monai/transforms/atmostonce/compose.py | 79 +------ monai/transforms/atmostonce/dictionary.py | 63 ++++- monai/transforms/atmostonce/functional.py | 80 ++++++- monai/transforms/atmostonce/randomizers.py | 105 +++++++++ monai/transforms/atmostonce/utility.py | 9 +- tests/test_atmostonce.py | 78 +++++++ 9 files changed, 543 insertions(+), 198 deletions(-) create mode 100644 monai/transforms/atmostonce/randomizers.py diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index deda678ef9..39fac39ca2 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -158,8 +158,9 @@ def __init__( def push_pending_transform(self, meta_matrix): self._pending_transforms.append(meta_matrix) + @property def has_pending_transforms(self): - return len(self._pending_transforms) + return len(self._pending_transforms) > 0 def peek_pending_transform(self): return copy.deepcopy(self._pending_transforms[-1]) diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py index d5c6e5f0ec..56c5802dda 100644 --- a/monai/transforms/atmostonce/apply.py +++ b/monai/transforms/atmostonce/apply.py @@ -141,10 +141,8 @@ def apply(data: Union[torch.Tensor, MetaTensor], for meta_matrix in pending_: next_matrix = meta_matrix.matrix - print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) - # cumulative_matrix = matmul(next_matrix, cumulative_matrix) + # print("intermediate 
matrix\n", matrix_from_matrix_container(cumulative_matrix)) cumulative_matrix = matmul(cumulative_matrix, next_matrix) - # cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents] cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] new_mode = meta_matrix.metadata.get('mode', None) @@ -160,12 +158,10 @@ def apply(data: Union[torch.Tensor, MetaTensor], if (mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False): - print("intermediate apply required") # carry out an intermediate resample here due to incompatibility between arguments kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - print(f"intermediate applying with cumulative matrix\n {cumulative_matrix_}") a = Affine(norm_coords=False, affine=cumulative_matrix_, **kwargs) @@ -184,7 +180,7 @@ def apply(data: Union[torch.Tensor, MetaTensor], cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - print(f"applying with cumulative matrix\n {cumulative_matrix_}") + # print(f"applying with cumulative matrix\n {cumulative_matrix_}") a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], @@ -224,66 +220,3 @@ def __call__( def inverse(self, data): return NotImplementedError() - - -# class Applyd(MapTransform, InvertibleTransform): -# -# def __init__(self, -# keys: KeysCollection, -# modes: GridSampleModeSequence, -# padding_modes: GridSamplePadModeSequence, -# normalized: bool = False, -# device: Optional[torch.device] = None, -# dtypes: Optional[DtypeSequence] = np.float32): -# self.keys = keys -# self.modes = modes -# self.padding_modes = padding_modes -# self.device = device -# self.dtypes = dtypes -# self.resamplers = dict() -# -# if isinstance(dtypes, (list, tuple)): -# if len(keys) != len(dtypes): -# raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence") -# -# # create a resampler for each output data type -# unique_resamplers = dict() -# for d in dtypes: -# if d not in unique_resamplers: -# unique_resamplers[d] = Resample(norm_coords=not normalized, device=device, dtype=d) -# -# # assign each named data input the appropriate resampler for that data type -# for k, d in zip(keys, dtypes): -# if k not in self.resamplers: -# self.resamplers[k] = unique_resamplers[d] -# -# else: -# # share the resampler across all named data inputs -# resampler = Resample(norm_coords=not normalized, device=device, dtype=dtypes) -# for k in keys: -# self.resamplers[k] = resampler -# -# def __call__(self, -# data: Mapping[Hashable, NdarrayOrTensor], -# allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]: -# d = dict(data) -# mapping_stack = d["mappings"] -# keys = d.keys() -# for key_tuple in self.key_iterator(d, -# expand_scalar_to_tuple(self.modes, len(keys)), -# expand_scalar_to_tuple(self.padding_modes, len(keys)), -# expand_scalar_to_tuple(self.dtypes, len(keys))): -# key, mode, padding_mode, dtype = key_tuple -# affine = mapping_stack[key].transform() -# data = d[key] -# spatial_size = data.shape[1:] -# grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype) -# _device = grid.device -# -# _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY -# -# grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=grid.dtype) -# affine, *_ = convert_to_dst_type(affine, grid) -# d[key] 
= self.resamplers[key](data, grid=grid, mode=mode, padding_mode=padding_mode) -# -# return d diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py index e5361fe2dd..aa6c4ed3ff 100644 --- a/monai/transforms/atmostonce/array.py +++ b/monai/transforms/atmostonce/array.py @@ -3,6 +3,9 @@ import numpy as np import torch +from monai.networks.utils import meshgrid_ij + +from monai.transforms.spatial.array import RandRange from monai.config import DtypeLike, NdarrayOrTensor from monai.data import MetaTensor @@ -11,8 +14,9 @@ from monai.transforms.atmostonce.apply import apply from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate, rotate90, flip, \ - identity + identity, grid_distortion, elastic_3d from monai.transforms.atmostonce.lazy_transform import LazyTransform +from monai.transforms.atmostonce.randomizers import RotateRandomizer, Elastic3DRandomizer from monai.transforms.atmostonce.utility import IMultiSampleTransform, ILazyTransform, IRandomizableTransform from monai.transforms.atmostonce.utils import value_to_tuple_range @@ -106,7 +110,7 @@ def __call__( dtype_ = dtype or self.dtype shape_override_ = shape_override - if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms: shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) img_t, transform, metadata = spacing(img, self.pixdim, self.src_pixdim, self.diagonal, @@ -143,7 +147,7 @@ def __call__( spatial_axis_ = self.spatial_axis = spatial_axis shape_override_ = shape_override if (shape_override_ is None and - isinstance(img, MetaTensor) and img.has_pending_transforms()): + isinstance(img, MetaTensor) and img.has_pending_transforms): shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) img_t, transform, metadata = flip(img, spatial_axis_, shape_override_) @@ -193,7 +197,7 @@ def __call__( anti_aliasing_sigma_ = anti_aliasing_sigma or self.anti_aliasing_sigma shape_override_ = shape_override - if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms: shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) img_t, transform, metadata = resize(img, self.spatial_size, self.size_mode, mode_, @@ -237,19 +241,20 @@ def __call__( align_corners: Optional[bool] = None, shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: - angle = self.angle - mode = mode or self.mode - padding_mode = padding_mode or self.padding_mode - align_corners = align_corners or self.align_corners + angle_ = self.angle or angle + mode_ = mode or self.mode or mode + padding_mode_ = padding_mode or self.padding_mode + align_corners_ = align_corners or self.align_corners keep_size = self.keep_size - dtype = self.dtype + dtype_ = self.dtype shape_override_ = shape_override - if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms: shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) - img_t, transform, metadata = rotate(img, angle, keep_size, mode, padding_mode, - align_corners, dtype, shape_override_) + # TODO: We should be tracking random rotate rather than just rotate + img_t, transform, 
metadata = rotate(img, angle_, keep_size, mode_, padding_mode_, + align_corners_, dtype_, shape_override_) # TODO: candidate for refactoring into a LazyTransform method img_t.push_pending_transform(MetaMatrix(transform, metadata)) @@ -307,7 +312,7 @@ def __call__( dtype = self.dtype shape_override_ = shape_override - if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms: shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) print("mode =", mode) img_t, transform, metadata = zoom(img, factor, mode, padding_mode, align_corners, @@ -348,7 +353,7 @@ def __call__( shape_override_ = shape_override if (shape_override_ is None and - isinstance(img, MetaTensor) and img.has_pending_transforms()): + isinstance(img, MetaTensor) and img.has_pending_transforms): shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) img_t, transform, metadata = rotate90(img, k_, spatial_axes_, shape_override_) @@ -404,7 +409,69 @@ def inverse( raise NotImplementedError() -class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): +# class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): +# +# def __init__( +# self, +# range_x: Optional[Union[Tuple[float, float], float]] = 0.0, +# range_y: Optional[Union[Tuple[float, float], float]] = 0.0, +# range_z: Optional[Union[Tuple[float, float], float]] = 0.0, +# prob: Optional[float] = 0.1, +# keep_size: Optional[bool] = True, +# mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, +# padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, +# align_corners: Optional[bool] = False, +# dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, +# lazy_evaluation: Optional[bool] = True +# ): +# RandomizableTransform.__init__(self, prob) +# self.range_x = value_to_tuple_range(range_x) +# self.range_y = value_to_tuple_range(range_y) +# self.range_z = value_to_tuple_range(range_z) +# +# self.x, self.y, self.z = 0.0, 0.0, 0.0 +# +# self.op = Rotate(0, keep_size, mode, padding_mode, align_corners, dtype, lazy_evaluation) +# +# def randomize(self, data: Optional[Any] = None) -> None: +# super().randomize(None) +# if self._do_transform is True: +# self.x, self.y, self.z = 0.0, 0.0, 0.0 +# +# self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) +# self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) +# self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) +# +# def __call__( +# self, +# img: NdarrayOrTensor, +# mode: Optional[Union[InterpolateMode, str]] = None, +# padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, +# align_corners: Optional[bool] = None, +# dtype: Optional[Union[DtypeLike, torch.dtype]] = None, +# randomize: Optional[bool] = True, +# shape_override: Optional[Sequence] = None +# ) -> NdarrayOrTensor: +# +# if randomize: +# self.randomize() +# +# img_dims = len(img.shape) - 1 +# if self._do_transform: +# angle = self.x if img_dims == 2 else (self.x, self.y, self.z) +# else: +# angle = 0 if img_dims == 2 else (0, 0, 0) +# +# return self.op(img, angle, mode, padding_mode, align_corners, shape_override) +# +# def inverse( +# self, +# data: NdarrayOrTensor, +# ): +# raise NotImplementedError() + + +class RandRotate(InvertibleTransform, ILazyTransform, IRandomizableTransform): def __init__( self, @@ -419,24 +486,13 @@ def __init__( dtype: 
Optional[Union[DtypeLike, torch.dtype]] = np.float32, lazy_evaluation: Optional[bool] = True ): - RandomizableTransform.__init__(self, prob) - self.range_x = value_to_tuple_range(range_x) - self.range_y = value_to_tuple_range(range_y) - self.range_z = value_to_tuple_range(range_z) - - self.x, self.y, self.z = 0.0, 0.0, 0.0 + self.randomizer = RotateRandomizer(value_to_tuple_range(range_x), + value_to_tuple_range(range_y), + value_to_tuple_range(range_z), + prob) self.op = Rotate(0, keep_size, mode, padding_mode, align_corners, dtype, lazy_evaluation) - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - if self._do_transform is True: - self.x, self.y, self.z = 0.0, 0.0, 0.0 - - self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) - self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) - self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) - def __call__( self, img: NdarrayOrTensor, @@ -448,16 +504,17 @@ def __call__( shape_override: Optional[Sequence] = None ) -> NdarrayOrTensor: - if randomize: - self.randomize() + angles = self.randomizer.sample(img) - img_dims = len(img.shape) - 1 - if self._do_transform: - angle = self.x if img_dims == 2 else (self.x, self.y, self.z) - else: - angle = 0 if img_dims == 2 else (0, 0, 0) + return self.op(img, angles, mode, padding_mode, align_corners, shape_override) + + @property + def lazy_evaluation(self): + return self.op.lazy_evaluation - return self.op(img, angle, mode, padding_mode, align_corners, shape_override) + @lazy_evaluation.setter + def lazy_evaluation(self, value): + self.op.lazy_evaluation = value def inverse( self, @@ -547,7 +604,7 @@ def inverse( raise NotImplementedError() -class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): +class RandZoom(RandomizableTransform, InvertibleTransform, ILazyTransform): def __init__( self, @@ -558,6 +615,7 @@ def __init__( padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, + lazy_evaluation: Optional[bool] = True, **kwargs ) -> None: RandomizableTransform.__init__(self, prob) @@ -573,7 +631,8 @@ def __init__( self.keep_size = keep_size self.factors = None - self.op = Zoom(1.0, self.mode, self.padding_mode, self.align_corners, self.keep_size) + self.op = Zoom(1.0, self.mode, self.padding_mode, self.align_corners, self.keep_size, + lazy_evaluation=lazy_evaluation) def randomize( self, @@ -612,6 +671,116 @@ def inverse( raise NotImplementedError() +class GridDistortion(LazyTransform): + + def __init__( + self, + num_cells: Union[Tuple[int], int], + distort_steps: Sequence[Sequence[float]], + mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = NumpyPadMode.EDGE, + ): + self.num_cells = num_cells + self.distort_steps = distort_steps + self.mode = mode + self.padding_mode = padding_mode + + def __call__( + self, + img: torch.Tensor, + distort_steps: Optional[Sequence[Sequence[float]]], + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = None, + ): + distort_steps_ = self.distort_steps if distort_steps is None else distort_steps + num_cells_ = ensure_tuple_rep(self.num_cells, len(img.shape)-1) + mode_ = mode or self.mode + padding_mode_ = mode_ or self.padding_mode + + shape_override_ = None + if isinstance(img, MetaTensor) and img.has_pending_transforms: + 
shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) + + img_t, transform, metadata = grid_distortion(img, num_cells_, distort_steps_, + mode_, padding_mode_, shape_override_) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse(self, data): + raise NotImplementedError() + + +class Rand3DElastic(LazyTransform, IRandomizableTransform): + + def __init__( + self, + sigma_range: Tuple[float, float], + magnitude_range: Tuple[float, float], + prob: float = 0.1, + rotate_range: RandRange = None, + shear_range: RandRange = None, + translate_range: RandRange = None, + scale_range: RandRange = None, + spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, + mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = GridSamplePadMode.REFLECTION, + as_tensor_output: bool = False, + device: Optional[torch.device] = None, + lazy_evaluation: Optional[bool] = True + ): + LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation) + self.spatial_size = spatial_size + self.mode = mode + self.padding_mode = padding_mode + self.device = device + + self.randomizer = Elastic3DRandomizer(sigma_range, magnitude_range, prob) + + self.nop = Identity() + + def __call__( + self, + img: torch.Tensor, + spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = None, + randomize: Optional[bool] = True, + shape_override: Tuple[int] = None + ): + mode_ = mode or self.mode + padding_mode_ = padding_mode or self.padding_mode + spatial_size_ = spatial_size or self.spatial_size or img.shape[1:] + + shape_override_ = shape_override + if shape_override is None and (isinstance(img, MetaTensor) and img.has_pending_transforms): + shape_override_ = img.peek_pending_transform().metadata.get("shape_override") + + rand_offsets, magnitude, sigma = self.randomizer.sample(spatial_size_, self.device) + if rand_offsets is None: + return self.nop(img) + else: + + img_t, transform, metadata = elastic_3d(img, + sigma, magnitude, rand_offsets, + spatial_size_, mode_, padding_mode_, + self.device, shape_override=shape_override_) + + # TODO: candidate for refactoring into a LazyTransform method + img_t.push_pending_transform(MetaMatrix(transform, metadata)) + if not self.lazy_evaluation: + img_t = apply(img_t) + + return img_t + + def inverse(self, data): + raise NotImplementedError() + + class Translate(LazyTransform, InvertibleTransform): def __init__( self, @@ -641,7 +810,7 @@ def __call__( dtype = self.dtype shape_override_ = shape_override - if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms: shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None) img_t, transform, metadata = translate(img, self.translation, @@ -682,7 +851,7 @@ def __call__( slices_ = slices if self.slices is None else self.slices shape_override_ = shape_override - if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms(): + if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms: shape_override_ = 
diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py
index 0576656870..cd6597769c 100644
--- a/monai/transforms/atmostonce/compose.py
+++ b/monai/transforms/atmostonce/compose.py
@@ -14,85 +14,18 @@
 
 # TODO: this is intended to replace Compose once development is done
 
-# class ComposeCompiler:
-#     """
-#     Args:
-#         transforms: A sequence of callable transforms
-#         lazy_resampling: Whether to resample the data after each transform or accumulate
-#             changes and then resample according to the accumulated changes as few times as
-#             possible. Defaults to True as this nearly always improves speed and quality
-#         caching_policy: Whether to cache deterministic transforms before the first
-#             randomised transforms. This can be one of "off", "drive", "memory"
-#         caching_favor: Whether to cache primarily for "speed" or for "quality". "speed" will
-#             favor doing more work before caching, whereas "quality" will favour delaying
-#             resampling until after caching
-#     """
-#     def __init__(
-#             self,
-#             transforms: Union[Sequence[Callable], Callable],
-#             lazy_resampling: Optional[bool] = True,
-#             caching_policy: Optional[str] = "off",
-#             caching_favor: Optional[str] = "quality"
-#     ):
-#         valid_caching_policies = ("off", "drive", "memory")
-#         if caching_policy not in valid_caching_policies:
-#             raise ValueError("parameter 'caching_policy' must be one of "
-#                              f"{valid_caching_policies} but is '{caching_policy}'")
-#
-#         dest_transforms = None
-#
-#         if caching_policy == "off":
-#             if lazy_resampling is False:
-#                 dest_transforms = [t for t in transforms]
-#             else:
-#                 dest_transforms = ComposeCompiler.lazy_no_cache()
-#         else:
-#             if caching_policy == "drive":
-#                 raise NotImplementedError()
-#             elif caching_policy == "memory":
-#                 raise NotImplementedError()
-#
-#         self.src_transforms = [t for t in transforms]
-#         self.dest_transforms = dest_transforms
-#
-#     def __getitem__(
-#             self,
-#             index
-#     ):
-#         return self.dest_transforms[index]
-#
-#     def __len__(self):
-#         return len(self.dest_transforms)
-#
-#     @staticmethod
-#     def lazy_no_cache(transforms):
-#         dest_transforms = []
-#         # TODO: replace with lazy transform
-#         cur_lazy = []
-#         for i_t in range(1, len(transforms)):
-#             if isinstance(transforms[i_t], LazyTransform):
-#                 # add this to the stack of transforms to be handled lazily
-#                 cur_lazy.append(transforms[i_t])
-#             else:
-#                 if len(cur_lazy) > 0:
-#                     dest_transforms.append(cur_lazy)
-#                     # TODO: replace with lazy transform
-#                     cur_lazy = []
-#                 dest_transforms.append(transforms[i_t])
-#         return dest_transforms
-
 
 class ComposeCompiler:
 
     def compile(self, transforms, cache_mechanism):
-        transforms_ = self.compile_caching(transforms, cache_mechanism)
+        transforms1 = self.compile_caching(transforms, cache_mechanism)
 
-        transforms__ = self.compile_multisampling(transforms_)
+        transforms2 = self.compile_multisampling(transforms1)
 
-        transforms___ = self.compile_lazy_resampling(transforms__)
+        transforms3 = self.compile_lazy_resampling(transforms2)
 
-        return transforms___
+        return transforms3
 
     def compile_caching(self, transforms, cache_mechanism):
         # TODO: handle being passed a transform list with containers
@@ -114,7 +47,7 @@ def compile_multisampling(self, transforms):
         for i in reversed(range(len(transforms))):
             if self.transform_is_multisampling(transforms[i]) is True:
                 transforms_ = transforms[:i] + [MultiSampleTransformCompose(transforms[i],
-                                                                            transforms[i+1:])]
+                                                                             transforms[i+1:])]
                 return self.compile_multisampling(transforms_)
         return list(transforms)
 
@@ -140,7 +73,7 @@ def transform_is_random(self, t):
         return isinstance(t, IRandomizableTransform)
 
     def transform_is_container(self, t):
-        return isinstance(t, CachedTransformCompose, MultiSampleTransformCompose)
+        return isinstance(t, (CachedTransformCompose, MultiSampleTransformCompose))
 
     def transform_is_multisampling(self, t):
         return isinstance(t, IMultiSampleTransform)
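The recursion in `compile_multisampling` above scans the transform list from the end, folds the last multi-sampling transform together with everything that follows it, and recurses on the shortened list, so nested multi-sampling transforms end up nested in the output. A toy, runnable reimplementation of the same recursion, purely to illustrate the nesting it produces (the stand-in transforms and tuple wrapper are hypothetical):

    def wrap_multisampling(transforms, is_multi):
        # mirrors ComposeCompiler.compile_multisampling: scan from the end,
        # fold the last multi-sampling transform over its successors, recurse
        for i in reversed(range(len(transforms))):
            if is_multi(transforms[i]):
                folded = transforms[:i] + [(transforms[i], transforms[i + 1:])]
                return wrap_multisampling(folded, is_multi)
        return list(transforms)

    result = wrap_multisampling(["a", "M1", "b", "M2", "c"],
                                lambda t: str(t).startswith("M"))
    # -> ["a", ("M1", ["b", ("M2", ["c"])])]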
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
index 131f442fd0..59f6a5c102 100644
--- a/monai/transforms/atmostonce/dictionary.py
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -1,11 +1,12 @@
-from typing import Any, Mapping, Optional, Sequence, Union
+from typing import Any, Hashable, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
 import torch
 
-from monai.transforms.atmostonce.array import Rotate, Resize, Spacing, Zoom, CropPad
+from monai.transforms.atmostonce.array import Rotate, Resize, Spacing, Zoom, CropPad, RotateRandomizer
 from monai.transforms.atmostonce.utility import ILazyTransform, IRandomizableTransform, IMultiSampleTransform
+from monai.transforms.atmostonce.utils import value_to_tuple_range
 
 from monai.utils import ensure_tuple_rep
 from monai.config import KeysCollection, DtypeLike, SequenceStr
@@ -46,8 +47,8 @@ def expand_potential_tuple(keys, value):
 
 
 def keys_to_process(
-    keys: Sequence[str],
-    dictionary: dict,
+    keys: KeysCollection,
+    dictionary: Mapping[Hashable, torch.Tensor],
     allow_missing_keys: bool,
 ):
     if allow_missing_keys is True:
@@ -155,13 +156,10 @@ def __init__(self,
         self.allow_missing_keys = allow_missing_keys
 
     def __call__(self, d: Mapping):
+        keys = keys_to_process(self.keys, d, self.allow_missing_keys)
         rd = dict(d)
-        if self.allow_missing_keys is True:
-            keys_present = {k for k in self.keys if k in d}
-        else:
-            keys_present = self.keys
 
-        for ik, k in enumerate(keys_present):
+        for ik, k in enumerate(keys):
             tx = Rotate(self.angle, self.keep_size,
                         self.modes[ik], self.padding_modes[ik],
                         self.align_corners, self.dtypes[ik])
@@ -173,6 +171,49 @@ def inverse(self, data: Any):
         raise NotImplementedError()
 
 
+class RandRotated(MapTransform, InvertibleTransform, LazyTransform, IRandomizableTransform):
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        range_x: Union[Tuple[float, float], float] = 0.0,
+        range_y: Union[Tuple[float, float], float] = 0.0,
+        range_z: Union[Tuple[float, float], float] = 0.0,
+        prob: float = 0.1,
+        keep_size: bool = True,
+        mode: SequenceStr = GridSampleMode.BILINEAR,
+        padding_mode: SequenceStr = GridSamplePadMode.BORDER,
+        align_corners: Union[Sequence[bool], bool] = False,
+        dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32,
+        lazy_evaluation: Optional[bool] = True,
+        allow_missing_keys: Optional[bool] = False,
+    ):
+        self.keys = keys
+        self.allow_missing_keys = allow_missing_keys
+        self.randomizer = RotateRandomizer(value_to_tuple_range(range_x),
+                                           value_to_tuple_range(range_y),
+                                           value_to_tuple_range(range_z),
+                                           prob)
+        self.op = Rotate(0, keep_size, mode, padding_mode, align_corners, dtype, lazy_evaluation)
+
+    def __call__(
+        self,
+        data: Mapping[Hashable, torch.Tensor]
+    ):
+        keys = keys_to_process(self.keys, data, self.allow_missing_keys)
+        rd = dict(data)
+
+        angles = self.randomizer.sample(data[keys[0]])
+
+        for ik, k in enumerate(keys):
+            rd[k] = self.op(data[k], angles)
+
+        return rd
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
 class Resized(LazyTransform, MapTransform, InvertibleTransform):
 
     def __init__(self,
@@ -399,7 +440,7 @@ def __call__(
         d: dict,
         randomize: Optional[bool] = True
     ):
-        for i in range(self.sample_count):
+        for _ in range(self.sample_count):
             yield self.op(d, randomize)
 
     def inverse(
@@ -413,4 +454,4 @@ def set_random_state(self, seed=None, state=None):
 
     @property
     def lazy_evaluation(self):
-        return self.op.lazy_evaluation
\ No newline at end of file
+        return self.op.lazy_evaluation
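One behavioural point worth highlighting in the RandRotated class above: the randomizer is sampled once per call, on the first key's tensor, and the same angles are then applied to every key, which keeps images and labels geometrically aligned. A sketch, with hypothetical dictionary contents:

    import torch
    from monai.transforms.atmostonce.dictionary import RandRotated

    data = {"image": torch.rand(1, 16, 16), "label": torch.rand(1, 16, 16)}
    rot = RandRotated(keys=("image", "label"), range_x=3.14159, prob=1.0,
                      mode="bilinear", padding_mode="border",
                      lazy_evaluation=False)
    out = rot(data)  # "image" and "label" rotated by the same sampled angle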
diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py
index 101fe58180..22f5c81eb8 100644
--- a/monai/transforms/atmostonce/functional.py
+++ b/monai/transforms/atmostonce/functional.py
@@ -3,8 +3,11 @@
 import numpy as np
 
 import torch
+from monai.networks.layers import GaussianFilter
 
-from monai.transforms import create_rotate, create_translate, map_spatial_axes
+from monai.networks.utils import meshgrid_ij
+
+from monai.transforms import create_rotate, create_translate, map_spatial_axes, create_grid
 
 from monai.data import get_track_meta
 from monai.transforms.atmostonce.apply import extents_from_shape, shape_from_extents
@@ -370,6 +373,81 @@ def rotate90(
     return img_, transform, metadata
 
 
+def grid_distortion(
+    img: torch.Tensor,
+    num_cells: Union[Tuple[int], int],
+    distort_steps: Sequence[Sequence[float]],
+    mode: str = GridSampleMode.BILINEAR,
+    padding_mode: str = GridSamplePadMode.BORDER,
+    shape_override: Optional[Tuple[int]] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+
+    all_ranges = []
+    num_cells = ensure_tuple_rep(num_cells, len(img.shape) - 1)
+    for dim_idx, dim_size in enumerate(img.shape[1:]):
+        dim_distort_steps = distort_steps[dim_idx]
+        ranges = torch.zeros(dim_size, dtype=torch.float32)
+        cell_size = dim_size // num_cells[dim_idx]
+        prev = 0
+        for idx in range(num_cells[dim_idx] + 1):
+            start = int(idx * cell_size)
+            end = start + cell_size
+            if end > dim_size:
+                end = dim_size
+                cur = dim_size
+            else:
+                cur = prev + cell_size * dim_distort_steps[idx]
+            ranges[start:end] = torch.linspace(prev, cur, end - start + 1)[1:]
+            prev = cur
+        ranges = ranges - (dim_size - 1.0) / 2.0
+        all_ranges.append(ranges)
+    coords = meshgrid_ij(*all_ranges)
+    grid = torch.stack([*coords, torch.ones_like(coords[0])])
+
+    metadata = {
+        "num_cells": num_cells,
+        "distort_steps": distort_steps,
+        "mode": mode,
+        "padding_mode": padding_mode
+    }
+    if shape_override is not None:
+        metadata["shape_override"] = shape_override
+
+    return img_, grid, metadata
+
+
+def elastic_3d(
+    img: torch.Tensor,
+    sigma: float,
+    magnitude: float,
+    offsets: torch.Tensor,
+    spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
+    mode: str = GridSampleMode.BILINEAR,
+    padding_mode: str = GridSamplePadMode.REFLECTION,
+    device: Optional[torch.device] = None,
+    shape_override: Optional[Tuple[float]] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+
+    sp_size = fall_back_tuple(spatial_size, img.shape[1:])
+    device_ = img.device if isinstance(img, torch.Tensor) else device
+    grid = create_grid(spatial_size=sp_size, device=device_, backend="torch")
+    gaussian = GaussianFilter(3, sigma, 3.0).to(device=device_)
+    grid[:3] += gaussian(offsets)[0] * magnitude
+
+    metadata = {
+        "sigma": sigma,
+        "magnitude": magnitude,
+        "offsets": offsets,
+    }
+    if spatial_size is not None:
+        metadata["spatial_size"] = spatial_size
+    if mode is not None:
+        metadata["mode"] = mode
+    if padding_mode is not None:
+        metadata["padding_mode"] = padding_mode
+    if shape_override is not None:
+        metadata["shape_override"] = shape_override
+
+    return img_, grid, metadata
+
+
 def translate(
     img: torch.Tensor,
     translation: Sequence[float],
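To make the per-dimension arithmetic in `grid_distortion` concrete: for a dimension of size 16 split into 4 cells, `cell_size` is 4, and each of the 5 step factors stretches or shrinks one cell's worth of coordinates before the whole range is recentred around zero. A standalone sketch of the same computation for one dimension, assuming multiplicative step factors close to 1.0 as above:

    import torch

    dim_size, num_cells = 16, 4
    steps = [1.0, 1.1, 0.9, 1.05, 1.0]        # one factor per cell boundary
    cell_size = dim_size // num_cells          # 4
    ranges = torch.zeros(dim_size)
    prev = 0.0
    for idx in range(num_cells + 1):
        start = idx * cell_size
        end = min(start + cell_size, dim_size)
        cur = float(dim_size) if start + cell_size > dim_size \
            else prev + cell_size * steps[idx]
        ranges[start:end] = torch.linspace(prev, cur, end - start + 1)[1:]
        prev = cur
    ranges = ranges - (dim_size - 1.0) / 2.0   # recentre about the origin

The second cell (factor 1.1) spans 4.0 to 8.4 instead of 4.0 to 8.0, so sample positions inside it are spread apart; the third (factor 0.9) is compressed correspondingly.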
diff --git a/monai/transforms/atmostonce/randomizers.py b/monai/transforms/atmostonce/randomizers.py
new file mode 100644
index 0000000000..fbcb6381aa
--- /dev/null
+++ b/monai/transforms/atmostonce/randomizers.py
@@ -0,0 +1,105 @@
+import numpy as np
+
+import torch
+
+
+class Randomizer:
+
+    def __init__(
+        self,
+        prob: float = 1.0,
+        seed=None,
+        state=None
+    ):
+        self.R = None
+        self.set_random_state(seed, state)
+
+        if not 0.0 <= prob <= 1.0:
+            raise ValueError(f"'prob' must be between 0.0 and 1.0 inclusive but is {prob}")
+        self.prob = prob
+
+    def set_random_state(self, seed=None, state=None):
+        if seed is not None:
+            self.R = np.random.RandomState(seed)
+        elif state is not None:
+            self.R = state
+        else:
+            self.R = np.random.RandomState()
+
+    def do_random(self):
+        return self.R.uniform() < self.prob
+
+    def sample(self):
+        return self.R.uniform()
+
+
+class RotateRandomizer(Randomizer):
+
+    def __init__(
+        self,
+        range_x,
+        range_y,
+        range_z,
+        prob: float = 1.0,
+        seed=None,
+        state=None,
+    ):
+        super().__init__(prob, seed, state)
+        self.range_x = range_x
+        self.range_y = range_y
+        self.range_z = range_z
+
+    def sample(
+        self,
+        data: torch.Tensor = None
+    ):
+        if not isinstance(data, (np.ndarray, torch.Tensor)):
+            raise ValueError("data must be a numpy ndarray or torch tensor but is of "
+                             f"type {type(data)}")
+
+        spatial_dims = len(data.shape[1:])
+        if spatial_dims == 2:
+            if self.do_random():
+                return self.R.uniform(self.range_x[0], self.range_x[1])
+            return 0.0
+        elif spatial_dims == 3:
+            if self.do_random():
+                x = self.R.uniform(self.range_x[0], self.range_x[1])
+                y = self.R.uniform(self.range_y[0], self.range_y[1])
+                z = self.R.uniform(self.range_z[0], self.range_z[1])
+                return x, y, z
+            return 0.0, 0.0, 0.0
+        else:
+            raise ValueError("data should be a tensor with 2 or 3 spatial dimensions but it "
+                             f"has {spatial_dims} spatial dimensions")
+
+
+class Elastic3DRandomizer(Randomizer):
+
+    def __init__(
+        self,
+        sigma_range,
+        magnitude_range,
+        prob=1.0,
+        grid_size=None,
+        seed=None,
+        state=None,
+    ):
+        super().__init__(prob, seed, state)
+        self.grid_size = grid_size
+        self.sigma_range = sigma_range
+        self.magnitude_range = magnitude_range
+
+    def sample(
+        self,
+        grid_size,
+        device
+    ):
+        if self.do_random():
+            rand_offsets = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32, copy=False)
+            rand_offsets = torch.as_tensor(rand_offsets, device=device).unsqueeze(0)
+            sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1])
+            magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1])
+            return rand_offsets, magnitude, sigma
+
+        return None, None, None
diff --git a/monai/transforms/atmostonce/utility.py b/monai/transforms/atmostonce/utility.py
index 4b7036e9db..43e84ec4fa 100644
--- a/monai/transforms/atmostonce/utility.py
+++ b/monai/transforms/atmostonce/utility.py
@@ -7,7 +7,14 @@
 
 
 class ILazyTransform(abc.ABC):
-    pass
+
+    @property
+    def lazy_evaluation(self):
+        raise NotImplementedError()
+
+    @lazy_evaluation.setter
+    def lazy_evaluation(self, lazy_evaluation):
+        raise NotImplementedError()
 
 
 class IMultiSampleTransform(abc.ABC):
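The Randomizer hierarchy above deliberately owns its np.random.RandomState, so tests can inject either a seed or a fully scripted state object. The tests below rely on a FakeRand class for the scripted case; its definition is not shown in this patch, but a stub along these lines would satisfy the calls made by Randomizer (this sketch is illustrative, not the test suite's actual FakeRand):

    class FakeRand:
        # replays scripted values from uniform(), wrapping around; the first
        # draw decides do_random(), subsequent draws become rotation angles
        def __init__(self, uniforms=()):
            self.uniforms = tuple(uniforms)
            self.i = 0

        def uniform(self, *args, **kwargs):
            value = self.uniforms[self.i % len(self.uniforms)]
            self.i += 1
            return value

With uniforms=(0.25, torch.pi) and prob=0.5, the 0.25 draw makes do_random() succeed and pi becomes the sampled angle; with uniforms=(0.75, torch.pi) the transform is skipped.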
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
index 035b235295..beabfe3da6 100644
--- a/tests/test_atmostonce.py
+++ b/tests/test_atmostonce.py
@@ -578,6 +578,14 @@ def test_rand_zoom(self):
             t_out = r(t)
             print(t_out.shape)
 
+    def test_deform_grid(self):
+        r = spatialarray.Rand2DElastic((1, 1),
+                                       (0.1, 0.2),
+                                       1.0)
+        img = get_img((16, 16))
+        result = r(img)
+        print(result)
+
     def test_center_spatial_crop(self):
         r = ocpa.CenterSpatialCrop(4)
         img = get_img((8, 8))
@@ -669,6 +677,75 @@ def test_value_to_tuple_range(self):
         self.assertTupleEqual(value_to_tuple_range((4.3, -2.1)), (-2.1, 4.3))
 
 
+class TestRotate(unittest.TestCase):
+
+    def _test_rotate_array_nonlazy(self, r, t, expected):
+        t_out = r(t)
+        self.assertTrue(torch.allclose(t_out.affine, expected))
+        self.assertFalse(t_out.has_pending_transforms)
+
+    def _test_rotate_array_lazy(self, r, t, expected):
+        t_out = r(t)
+        self.assertTrue(torch.allclose(t_out.affine, torch.eye(4, 4, dtype=torch.double)))
+        self.assertTrue(t_out.has_pending_transforms)
+        self.assertTrue(torch.allclose(t_out.peek_pending_transform().matrix.matrix, expected))
+
+    def test_rotate(self):
+        r = amoa.Rotate(torch.pi,
+                        keep_size=True,
+                        mode="nearest",
+                        padding_mode="zeros",
+                        lazy_evaluation=False)
+        t = get_img((16, 16))
+
+        expected = torch.eye(4, 4, dtype=torch.double)
+        expected[0, :] = torch.DoubleTensor([-1, 0, 0, 15])
+        expected[1, :] = torch.DoubleTensor([0, -1, 0, 15])
+        self._test_rotate_array_nonlazy(r, t, expected)
+
+    def test_rand_rotate(self):
+        r = amoa.RandRotate((0, torch.pi * 2),
+                            (0, torch.pi * 2),
+                            (0, torch.pi * 2),
+                            prob=0.5,
+                            keep_size=True,
+                            mode="nearest",
+                            padding_mode="zeros",
+                            lazy_evaluation=False)
+        t = get_img((16, 16))
+
+        expected = torch.eye(4, 4, dtype=torch.double)
+        expected[0, :] = torch.DoubleTensor([-1, 0, 0, 15])
+        expected[1, :] = torch.DoubleTensor([0, -1, 0, 15])
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.25, torch.pi)))
+        self._test_rotate_array_nonlazy(r, t, expected)
+
+        expected = torch.eye(4, 4, dtype=torch.double)
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.75, torch.pi)))
+        self._test_rotate_array_nonlazy(r, t, expected)
+
+        r.lazy_evaluation = True
+
+        expected = torch.eye(3, 3, dtype=torch.double)
+        expected[0, :] = torch.DoubleTensor([-1, 0, 0])
+        expected[1, :] = torch.DoubleTensor([0, -1, 0])
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.25, torch.pi)))
+        self._test_rotate_array_lazy(r, t, expected)
+
+        expected = torch.eye(3, 3, dtype=torch.double)
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.75, torch.pi)))
+        self._test_rotate_array_lazy(r, t, expected)
+
+
+class TestRand3DElastic(unittest.TestCase):
+
+    def test_array(self):
+        img = get_img((16, 16, 8))
+        r = amoa.Rand3DElastic((0.5, 1.5), (0, 1), 1.0, mode="nearest", padding_mode="zeros")
+        result = r(img)
+        print(result.shape)
+
+
 class TestCropPad(unittest.TestCase):
 
     def _test_functional(self, targs, img, expected):
@@ -920,6 +997,7 @@ def __repr__(self):
 
         cc = ComposeCompiler()
         actual = cc.compile_multisampling(source_transforms)
+        print(actual)
 
         self.assertEqual(actual[0], a)
         self.assertEqual(actual[1], b)
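A closing note on the expected matrices in TestRotate: rotating a 16x16 image by pi flips both spatial axes, and keeping the rotation centred on the image means the homogeneous matrix must map index 0 to index 15, which is where the translation column comes from. A sketch of the arithmetic (not code from the patch):

    import torch

    theta = torch.tensor(torch.pi, dtype=torch.double)
    c, s = torch.cos(theta), torch.sin(theta)
    rot = torch.eye(4, dtype=torch.double)
    rot[0, 0], rot[0, 1] = c, -s   # ~ -1, 0
    rot[1, 0], rot[1, 1] = s, c    # ~  0, -1

    # recentre: translate the image centre to the origin, rotate, translate back
    centre = (16 - 1) / 2.0
    to_origin = torch.eye(4, dtype=torch.double)
    to_origin[:2, 3] = -centre
    back = torch.eye(4, dtype=torch.double)
    back[:2, 3] = centre
    affine = back @ rot @ to_origin
    # affine[0] ~ [-1, 0, 0, 15], affine[1] ~ [0, -1, 0, 15]

The lazy variants of the same tests check a 3x3 matrix instead because the pending MetaMatrix holds only the 2D homogeneous transform; the 4x4 affine on the tensor stays at identity until apply() runs.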