Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions monai/handlers/transform_inverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from monai.data import BatchInverseTransform
from monai.data.utils import no_collation
from monai.engines.utils import CommonKeys
from monai.transforms import InvertibleTransform, allow_missing_keys_mode
from monai.transforms import InvertibleTransform, allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils import InverseKeys, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
Expand All @@ -32,12 +32,6 @@ class TransformInverter:
Ignite handler to automatically invert all the pre-transforms that support `inverse`.
It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`.

Note:
This handler is experimental API in v0.5, the interpolation mode in the transforms
and inverse transforms are the same, so maybe it's not correct as we may want to use `bilinear`
for input image but use `nearest` when inverting transforms for model outout.
For this case, a solution is to set `batch_key` to the label field if we have labels.

"""

def __init__(
Expand All @@ -48,6 +42,7 @@ def __init__(
batch_key: str = CommonKeys.IMAGE,
output_key: str = CommonKeys.PRED,
postfix: str = "inverted",
nearest_interp: bool = True,
) -> None:
"""
Args:
Expand All @@ -59,13 +54,16 @@ def __init__(
for this input data, then invert them for the model output, default to "image".
output_key: the key of model output in `ignite.engine.output`, invert transforms on it.
postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`.
nearest_interp: whether to use `nearest` interpolation mode when inverting spatial transforms,
default to `True`. if `False`, use the same interpolation mode as the original transform.

"""
self.transform = transform
self.inverter = BatchInverseTransform(transform=transform, loader=loader, collate_fn=collate_fn)
self.batch_key = batch_key
self.output_key = output_key
self.postfix = postfix
self.nearest_interp = nearest_interp

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -84,9 +82,13 @@ def __call__(self, engine: Engine) -> None:
warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.")
return

transform_info = engine.state.batch[transform_key]
if self.nearest_interp:
convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None)

segs_dict = {
self.batch_key: engine.state.output[self.output_key].detach().cpu(),
transform_key: engine.state.batch[transform_key],
transform_key: transform_info,
}

with allow_missing_keys_mode(self.transform): # type: ignore
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@
)
from .utils import (
allow_missing_keys_mode,
convert_inverse_interp_mode,
copypaste_arrays,
create_control_grid,
create_grid,
Expand Down
26 changes: 18 additions & 8 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

from copy import deepcopy
from enum import Enum
from itertools import chain
from math import floor
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key)
self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
d[key] = self.padder(d[key], mode=m)
return d

Expand Down Expand Up @@ -193,7 +194,7 @@ def __init__(
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key)
self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
d[key] = self.padder(d[key], mode=m)
return d

Expand Down Expand Up @@ -259,7 +260,7 @@ def __init__(
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key)
self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
d[key] = self.padder(d[key], mode=m)
return d

Expand Down Expand Up @@ -826,6 +827,7 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform):
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
It also can be a sequence of string, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand All @@ -834,18 +836,26 @@ def __init__(
self,
keys: KeysCollection,
spatial_size: Union[Sequence[int], int],
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, mode=mode)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.key_iterator(d):
for key, m in self.key_iterator(d, self.mode):
orig_size = d[key].shape[1:]
d[key] = self.padcropper(d[key])
self.push_transform(d, key, orig_size=orig_size)
d[key] = self.padcropper(d[key], mode=m)
self.push_transform(
d,
key,
orig_size=orig_size,
extra_info={
"mode": m.value if isinstance(m, Enum) else m,
},
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def push_transform(
info = {
InverseKeys.CLASS_NAME: self.__class__.__name__,
InverseKeys.ID: id(self),
InverseKeys.ORIG_SIZE: orig_size or data[key].shape[1:],
InverseKeys.ORIG_SIZE: orig_size or (data[key].shape[1:] if hasattr(data[key], "shape") else None),
}
if extra_info is not None:
info[InverseKeys.EXTRA_INFO] = extra_info
Expand Down
Loading