Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions monai/metrics/occlusion_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
# limitations under the License.

from collections.abc import Sequence
from typing import Union
from functools import partial
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn

try:
from tqdm import trange

trange = partial(trange, desc="Computing occlusion sensitivity")
except (ImportError, AttributeError):
trange = range

Expand Down Expand Up @@ -84,7 +87,9 @@ def compute_occlusion_sensitivity(
pad_val: float = 0.0,
margin: Union[int, Sequence] = 2,
n_batch: int = 128,
b_box: Union[Sequence, None] = None,
b_box: Optional[Sequence] = None,
stride: Union[int, Sequence] = 1,
upsample_mode: str = "nearest",
) -> np.ndarray:
"""
This function computes the occlusion sensitivity for a model's prediction
Expand Down Expand Up @@ -123,6 +128,13 @@ def compute_occlusion_sensitivity(
speed the analysis up, which might be useful for larger images.
* Min and max are inclusive, so [0, 63, ...] will have size (64, ...).
* Use negative values to default to 0 for the min values and ``im.shape[x]-1`` for the max of the xth dimension.
stride: Stride for performing occlusions. Can be a single value or a sequence
(for a different stride in each spatial direction). Each value should be >= 1.
upsample_mode: If stride != 1 is used, we'll upsample such that the size
of the voxels in the output image match the input. Upsampling is done with
``torch.nn.Upsample``, and mode can be set to:
* ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear``
* default is ``nearest``.
Returns:
Numpy array. If no bounding box is supplied, this will be the same size
as the input image. If a bounding box is used, the output image will be
Expand All @@ -147,12 +159,28 @@ def compute_occlusion_sensitivity(
# If no bounding box supplied, output shape is same as input shape.
# If bounding box is present, shape is max - min + 1
output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1
num_required_predictions = np.prod(output_im_shape)

# Calculate the downsampled shape
if not isinstance(stride, Sequence):
stride_np = np.full_like(im_shape, stride, dtype=np.int32)
stride_np[0] = 1 # always do stride 1 in channel dimension
else:
# Convert to numpy array and check dimensions match
stride_np = np.array(stride, dtype=np.int32)
if stride_np.size != im_shape.size:
raise ValueError("Sizes of image shape and stride should match.")

# Obviously if stride = 1, downsampled_im_shape == output_im_shape
downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32)
downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1
num_required_predictions = np.prod(downsampled_im_shape)

# Loop 1D over image
for i in trange(num_required_predictions):
# Get corresponding ND index
idx = np.unravel_index(i, output_im_shape)
idx = np.unravel_index(i, downsampled_im_shape)
# Multiply by stride
idx *= stride_np
# If a bounding box is being used, we need to add on
# the min to shift to start of region of interest
if b_box_min is not None:
Expand All @@ -178,11 +206,20 @@ def compute_occlusion_sensitivity(
batch_images = []
batch_ids = []

# Subtract from baseline
sensitivity_im = baseline - sensitivity_im

# Reshape to match downsampled image
sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape))

# If necessary, upsample
if np.any(stride_np != 1):
output_im_shape = tuple(output_im_shape[1:]) # needs to be given as 3D tuple
upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode)
sensitivity_im = upsampler(sensitivity_im.unsqueeze(0))

# Convert tensor to numpy
sensitivity_im = sensitivity_im.cpu().numpy()

# Reshape to size of output image
sensitivity_im = sensitivity_im.reshape(output_im_shape)

# Squeeze, subtract from baseline and return
return baseline - np.squeeze(sensitivity_im)
# Squeeze and return
return np.squeeze(sensitivity_im)
1 change: 1 addition & 0 deletions tests/test_compute_occlusion_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"label": 0,
"b_box": [-1, -1, 2, 3, -1, -1, -1, -1],
"n_batch": 10,
"stride": 2,
},
(2, 6, 6),
]
Expand Down