Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
d2a040b
inverse transformations
rijobro Mar 1, 2021
7dd3539
autofix
rijobro Mar 1, 2021
3236334
fix rand elastic 3d
rijobro Mar 1, 2021
faf771c
fixes
rijobro Mar 1, 2021
8d89a10
Merge branch 'master' into inverse_transformation
rijobro Mar 1, 2021
1ca5939
fix 2d elastic
rijobro Mar 1, 2021
7e7d32f
Merge remote-tracking branch 'rijobro/inverse_transformation' into in…
rijobro Mar 1, 2021
a5138bb
code format
rijobro Mar 1, 2021
3bedaa7
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 2, 2021
0e89139
update inverse docstrings
rijobro Mar 2, 2021
c3fa403
inverse for randaxisflipd
rijobro Mar 2, 2021
b8cea25
remove shuffle
rijobro Mar 2, 2021
d4bfbb3
merge
rijobro Mar 2, 2021
088f626
debug message
rijobro Mar 2, 2021
d260e35
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 2, 2021
4b7ba75
don't write to file if no nibabel
rijobro Mar 2, 2021
415efd9
Merge branch 'master' into inverse_transformation
rijobro Mar 3, 2021
e8bbdd7
skip if no meta data key
rijobro Mar 3, 2021
af791a5
more lenient thresholds for tests
rijobro Mar 3, 2021
ffc1b9b
isort
rijobro Mar 3, 2021
2d6f9a1
mypy
rijobro Mar 3, 2021
7a5c44c
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 3, 2021
e5804e1
tests require nibabel
rijobro Mar 3, 2021
e62a4ce
undo skip if no metadata
rijobro Mar 3, 2021
df76f07
update test_decollate
rijobro Mar 3, 2021
cd0f305
isort
rijobro Mar 3, 2021
ed2f12c
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 3, 2021
014f7df
changes
rijobro Mar 4, 2021
b61aaf5
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 4, 2021
96c34c1
Merge branch 'master' into inverse_transformation
rijobro Mar 4, 2021
665237b
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 5, 2021
d4e2c54
inverse_transform.py -> inverse.py
rijobro Mar 5, 2021
2a0eb62
with AllowMissingKeysMode
rijobro Mar 5, 2021
53c58b0
remove keys from inverse method
rijobro Mar 5, 2021
62f8a6a
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 5, 2021
7036ac8
autofixes
rijobro Mar 5, 2021
3c41bdd
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 8, 2021
dbc0770
update for allow_missing_keys_mode
rijobro Mar 8, 2021
2915346
inverse to use apply_transform
rijobro Mar 8, 2021
8efd75e
add enums
rijobro Mar 8, 2021
0ec2e98
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 8, 2021
a0ad428
push_ and pop_transform
rijobro Mar 8, 2021
b7b17a0
update doc
rijobro Mar 8, 2021
9313cff
basic API
rijobro Mar 8, 2021
bc559af
code format
rijobro Mar 8, 2021
d8ad0ee
code format
rijobro Mar 8, 2021
8f258e5
Merge branch 'master' into inv_trans_API
rijobro Mar 8, 2021
c7d59d5
formatting docstring
wyli Mar 9, 2021
01adfcc
Merge remote-tracking branch 'MONAI/master' into inv_trans_API
rijobro Mar 10, 2021
6c913d8
enum changes
rijobro Mar 10, 2021
5a8a08b
put matplotlib functionality in docstrings
rijobro Mar 10, 2021
841f19c
update module list
wyli Mar 10, 2021
af24559
Merge remote-tracking branch 'MONAI/master' into inv_trans_API
rijobro Mar 10, 2021
595f1cd
skip decollate id check for windows and mac
rijobro Mar 10, 2021
2c6b48e
fix test for windows and mac
rijobro Mar 10, 2021
915d479
Merge branch 'inv_trans_API' into inverse_transformation
rijobro Mar 10, 2021
bc529c7
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 10, 2021
358d653
update merge
rijobro Mar 10, 2021
917c47e
update merge 2
rijobro Mar 10, 2021
08bcc36
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 15, 2021
97e0ad3
remove duplicate tests
rijobro Mar 15, 2021
ea31dd7
lossless inverse
rijobro Mar 15, 2021
2a82685
Merge remote-tracking branch 'MONAI/master' into inverse_transform_lo…
rijobro Mar 15, 2021
28f5a36
code format
rijobro Mar 15, 2021
fed515e
remove extra tests
rijobro Mar 15, 2021
4c3074b
Merge remote-tracking branch 'MONAI/master' into inverse_transform_lo…
rijobro Mar 15, 2021
463d034
Merge branch 'master' into inverse_transform_lossless
wyli Mar 15, 2021
c72407a
update tests
rijobro Mar 16, 2021
b7b53cf
Merge remote-tracking branch 'MONAI/master' into inverse_transform_lo…
rijobro Mar 16, 2021
0d26844
Merge remote-tracking branch 'rijobro/inverse_transform_lossless' int…
rijobro Mar 16, 2021
5f7f2c7
Merge branch 'master' into inverse_transform_lossless
rijobro Mar 16, 2021
14c978e
Merge branch 'inverse_transform_lossless' into inverse_transformation
rijobro Mar 16, 2021
070ae46
update after merge
rijobro Mar 16, 2021
e4b80e3
test fixes
rijobro Mar 16, 2021
07155e8
merge
rijobro Mar 16, 2021
9907eeb
Zoomd and RandZoomd
rijobro Mar 16, 2021
ae2e60f
add SpatialPad
rijobro Mar 16, 2021
1eae3a7
merge
rijobro Mar 16, 2021
797a94e
Inverse Spacingd
rijobro Mar 16, 2021
b26ead6
merge
rijobro Mar 16, 2021
08165fb
Inverse Resized
rijobro Mar 16, 2021
d4f65c5
merge
rijobro Mar 16, 2021
f22f7c0
inverse RandAffined
rijobro Mar 16, 2021
ff14329
undo unintentional change
rijobro Mar 16, 2021
3b985b3
inverse Affined
rijobro Mar 16, 2021
62bbb21
code format
rijobro Mar 16, 2021
1c43fb2
Merge branch 'inverse_Affined' into inverse_transformation
rijobro Mar 16, 2021
42e8104
move affine
rijobro Mar 16, 2021
384013a
remove duplicate tests
rijobro Mar 16, 2021
f33a731
Inverse Rotated and RandRotated
rijobro Mar 16, 2021
100bc7d
add random
rijobro Mar 16, 2021
067e161
add return rotation matrix
rijobro Mar 16, 2021
6403f93
remove typo
rijobro Mar 16, 2021
293c700
merge
rijobro Mar 16, 2021
86c63c1
Merge remote-tracking branch 'MONAI/master' into inverse_Affined
rijobro Mar 17, 2021
a3f50e2
remove return_affine
rijobro Mar 17, 2021
e2a1dac
code format
rijobro Mar 17, 2021
c4a8c48
Merge branch 'inverse_Affined' into inverse_transformation
rijobro Mar 17, 2021
4a8233b
Merge remote-tracking branch 'MONAI/master' into inverse_transformation
rijobro Mar 18, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
ThresholdIntensityD,
ThresholdIntensityDict,
)
from .inverse import InvertibleTransform
from .inverse import InvertibleTransform, NonRigidTransform
from .io.array import LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .post.array import (
Expand Down
130 changes: 130 additions & 0 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Dict, Hashable, Optional, Tuple

import numpy as np
import torch

from monai.transforms.transform import RandomizableTransform, Transform
from monai.utils.enums import InverseKeys
from monai.utils.module import optional_import

sitk, has_sitk = optional_import("SimpleITK")
vtk, has_vtk = optional_import("vtk")
vtk_numpy_support, _ = optional_import("vtk.util.numpy_support")

__all__ = ["InvertibleTransform"]

Expand Down Expand Up @@ -119,3 +125,127 @@ def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]:

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class NonRigidTransform(Transform):
@staticmethod
def _get_disp_to_def_arr(shape, spacing):
def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64)
for idx, i in enumerate(shape):
# shift for origin (in MONAI, center of image)
def_to_disp[idx] -= (i - 1) / 2
# if supplied, account for spacing (e.g., for control point grids)
if spacing is not None:
def_to_disp[idx] *= spacing[idx]
return def_to_disp

@staticmethod
def _inv_disp_w_sitk(fwd_disp, num_iters):
fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True)
inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters)
inv_disp = sitk.GetArrayFromImage(inv_disp_sitk)
return inv_disp

@staticmethod
def _inv_disp_w_vtk(fwd_disp):
orig_shape = fwd_disp.shape
required_num_tensor_components = 3
# VTK requires 3 tensor components, so if shape was (H, W, 2), make it
# (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s)
while fwd_disp.shape[-1] < required_num_tensor_components:
fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1)
fwd_disp = fwd_disp[..., None, :]

# Create VTKDoubleArray. Shape needs to be (H*W*D, 3)
fwd_disp_flattened = fwd_disp.reshape(-1, required_num_tensor_components) # need to keep this in memory
vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened)

# Generating the vtkImageData
fwd_disp_vtk = vtk.vtkImageData()
fwd_disp_vtk.SetOrigin(0, 0, 0)
fwd_disp_vtk.SetSpacing(1, 1, 1)
fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1][::-1]) # VTK spacing opposite order to numpy
fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array)

if __debug__:
fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0))
assert fwd_disp_vtk_np.size == fwd_disp.size
assert fwd_disp_vtk_np.min() == fwd_disp.min()
assert fwd_disp_vtk_np.max() == fwd_disp.max()
assert fwd_disp_vtk.GetNumberOfScalarComponents() == required_num_tensor_components

# create b-spline coefficients for the displacement grid
bspline_filter = vtk.vtkImageBSplineCoefficients()
bspline_filter.SetInputData(fwd_disp_vtk)
bspline_filter.Update()

# use these b-spline coefficients to create a transform
bspline_transform = vtk.vtkBSplineTransform()
bspline_transform.SetCoefficientData(bspline_filter.GetOutput())
bspline_transform.Update()

# invert the b-spline transform onto a new grid
grid_maker = vtk.vtkTransformToGrid()
grid_maker.SetInput(bspline_transform.GetInverse())
grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin())
grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing())
grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent())
grid_maker.SetGridScalarTypeToFloat()
grid_maker.Update()

# Get inverse displacement as an image
inv_disp_vtk = grid_maker.GetOutput()

# Convert back to numpy and reshape
inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0))
# if there were originally < 3 tensor components, remove the zeros we added at the start
inv_disp = inv_disp[..., : orig_shape[-1]]
# reshape to original
inv_disp = inv_disp.reshape(orig_shape)

return inv_disp

@staticmethod
def compute_inverse_deformation(
num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "vtk"
):
"""Package can be vtk or sitk."""
if use_package.lower() == "vtk" and not has_vtk:
warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified")
return None
if use_package.lower() == "sitk" and not has_sitk:
warnings.warn(
"Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified"
)
return None

# Convert to numpy if necessary
if isinstance(fwd_def_orig, torch.Tensor):
fwd_def_orig = fwd_def_orig.cpu().numpy()
# Remove any extra dimensions (we'll add them back in at the end)
fwd_def = fwd_def_orig[:num_spatial_dims]
# Def -> disp
def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing)
fwd_disp = fwd_def - def_to_disp
# move tensor component to end (T,H,W,[D])->(H,W,[D],T)
fwd_disp = np.moveaxis(fwd_disp, 0, -1)

# If using vtk...
if use_package.lower() == "vtk":
inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp)
# If using sitk...
elif use_package.lower() == "sitk":
inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters)
else:
raise RuntimeError("Enter vtk or sitk for inverse calculation")

# move tensor component back to beginning
inv_disp = np.moveaxis(inv_disp, -1, 0)
# Disp -> def
inv_def = inv_disp + def_to_disp
# Add back in any removed dimensions
ndim_in = fwd_def_orig.shape[0]
ndim_out = inv_def.shape[0]
inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]])

return inv_def
4 changes: 3 additions & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def __call__(
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
) -> np.ndarray:
return_rotation_matrix: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
Args:
img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
Expand All @@ -447,6 +448,7 @@ def __call__(
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
return_rotation_matrix: whether or not to return the applied rotation matrix.

Raises:
ValueError: When ``img`` spatially is not one of [2D, 3D].
Expand Down
126 changes: 107 additions & 19 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.layers import AffineTransform
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse import InvertibleTransform, NonRigidTransform
from monai.transforms.spatial.array import (
Affine,
AffineGrid,
Expand All @@ -50,9 +50,9 @@
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
optional_import,
)
from monai.utils.enums import InverseKeys
from monai.utils.module import optional_import

nib, _ = optional_import("nibabel")

Expand Down Expand Up @@ -730,7 +730,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
return d


class Rand2DElasticd(RandomizableTransform, MapTransform):
class Rand2DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`.
"""
Expand Down Expand Up @@ -822,6 +822,17 @@ def randomize(self, spatial_size: Sequence[int]) -> None:
super().randomize(None)
self.rand_2d_elastic.randomize(spatial_size)

@staticmethod
def cpg_to_dvf(cpg, spacing, output_shape):
grid = torch.nn.functional.interpolate(
recompute_scale_factor=True,
input=cpg.unsqueeze(0),
scale_factor=ensure_tuple_rep(spacing, 2),
mode=InterpolateMode.BICUBIC.value,
align_corners=False,
)
return CenterSpatialCrop(roi_size=output_shape)(grid[0])

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
Expand All @@ -831,25 +842,63 @@ def __call__(
self.randomize(spatial_size=sp_size)

if self._do_transform:
grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
grid = torch.nn.functional.interpolate( # type: ignore
recompute_scale_factor=True,
input=grid.unsqueeze(0),
scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2),
mode=InterpolateMode.BICUBIC.value,
align_corners=False,
)
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
cpg_w_affine, affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg, return_affine=True)
grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size)
extra_info: Optional[Dict] = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)}
else:
grid = create_grid(spatial_size=sp_size)
extra_info = None

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
self.push_transform(d, key, extra_info=extra_info)
d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
# This variable will be `not None` if vtk or sitk is present
inv_def_no_affine = None

for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
if transform[InverseKeys.DO_TRANSFORM.value]:
orig_size = transform[InverseKeys.ORIG_SIZE.value]
# Only need to calculate inverse deformation once as it is the same for all keys
if idx == 0:
# If magnitude == 0, then non-rigid component is identity -- so just create blank
if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0):
inv_def_no_affine = create_grid(spatial_size=orig_size)
else:
fwd_cpg_no_affine = transform[InverseKeys.EXTRA_INFO.value]["cpg"]
fwd_def_no_affine = self.cpg_to_dvf(
fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size
)
inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine)
# if inverse did not succeed (sitk or vtk present), data will not be changed.
if inv_def_no_affine is not None:
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
inv_affine = np.linalg.inv(fwd_affine)
inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)(
grid=inv_def_no_affine
)
# Back to original size
inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore
# Apply inverse transform
if inv_def_no_affine is not None:
out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode)
d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out

else:
d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


class Rand3DElasticd(RandomizableTransform, MapTransform):
class Rand3DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`.
"""
Expand Down Expand Up @@ -949,17 +998,54 @@ def __call__(
sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:])

self.randomize(grid_size=sp_size)
grid = create_grid(spatial_size=sp_size)
grid_no_affine = create_grid(spatial_size=sp_size)
affine = np.eye(4)
if self._do_transform:
device = self.rand_3d_elastic.device
grid = torch.tensor(grid).to(device)
grid_no_affine = torch.tensor(grid_no_affine).to(device)
gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device)
offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0)
grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude
grid = self.rand_3d_elastic.rand_affine_grid(grid=grid)
grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude
grid_w_affine, affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine, return_affine=True)
else:
grid_w_affine = grid_no_affine
affine = np.eye(len(sp_size) + 1)

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
self.push_transform(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine})
d[key] = self.rand_3d_elastic.resampler(d[key], grid_w_affine, mode=mode, padding_mode=padding_mode)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))

for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
if transform[InverseKeys.DO_TRANSFORM.value]:
orig_size = transform[InverseKeys.ORIG_SIZE.value]
# Only need to calculate inverse deformation once as it is the same for all keys
if idx == 0:
fwd_def_no_affine = transform[InverseKeys.EXTRA_INFO.value]["grid_no_affine"]
inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine)
# if inverse did not succeed (sitk or vtk present), data will not be changed.
if inv_def_no_affine is not None:
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
inv_affine = np.linalg.inv(fwd_affine)
inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)(
grid=inv_def_no_affine
)
# Back to original size
inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore
# Apply inverse transform
if inv_def_w_affine is not None:
out = self.rand_3d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode)
d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out
else:
d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


Expand Down Expand Up @@ -1169,6 +1255,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
padding_mode=padding_mode,
align_corners=align_corners,
dtype=dtype,
return_rotation_matrix=True,
)
rot_mat = self.rotator.get_rotation_matrix()
self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat})
Expand Down Expand Up @@ -1302,6 +1389,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
padding_mode=padding_mode,
align_corners=align_corners,
dtype=dtype,
return_rotation_matrix=True,
)
rot_mat = rotator.get_rotation_matrix()
self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat})
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ sphinx-autodoc-typehints==1.11.1
sphinx-rtd-theme==0.5.0
cucim==0.18.1
openslide-python==1.1.2
vtk
Loading