diff --git a/monai/metrics/occlusion_sensitivity.py b/monai/metrics/occlusion_sensitivity.py index 900cfe4645..9879f472a9 100644 --- a/monai/metrics/occlusion_sensitivity.py +++ b/monai/metrics/occlusion_sensitivity.py @@ -10,7 +10,8 @@ # limitations under the License. from collections.abc import Sequence -from typing import Union +from functools import partial +from typing import Optional, Union import numpy as np import torch @@ -18,6 +19,8 @@ try: from tqdm import trange + + trange = partial(trange, desc="Computing occlusion sensitivity") except (ImportError, AttributeError): trange = range @@ -84,7 +87,9 @@ def compute_occlusion_sensitivity( pad_val: float = 0.0, margin: Union[int, Sequence] = 2, n_batch: int = 128, - b_box: Union[Sequence, None] = None, + b_box: Optional[Sequence] = None, + stride: Union[int, Sequence] = 1, + upsample_mode: str = "nearest", ) -> np.ndarray: """ This function computes the occlusion sensitivity for a model's prediction @@ -123,6 +128,13 @@ def compute_occlusion_sensitivity( speed the analysis up, which might be useful for larger images. * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. + stride: Stride for performing occlusions. Can be single value or sequence + (for varying stride in the different directions). Should be >= 1. + upsample_mode: If stride != 1 is used, we'll upsample such that the size + of the voxels in the output image match the input. Upsampling is done with + ``torch.nn.Upsample``, and mode can be set to: + * ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear`` + * default is ``nearest``. Returns: Numpy array. If no bounding box is supplied, this will be the same size as the input image. If a bounding box is used, the output image will be @@ -147,12 +159,28 @@ def compute_occlusion_sensitivity( # If no bounding box supplied, output shape is same as input shape. # If bounding box is present, shape is max - min + 1 output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 - num_required_predictions = np.prod(output_im_shape) + + # Calculate the downsampled shape + if not isinstance(stride, Sequence): + stride_np = np.full_like(im_shape, stride, dtype=np.int32) + stride_np[0] = 1 # always do stride 1 in channel dimension + else: + # Convert to numpy array and check dimensions match + stride_np = np.array(stride, dtype=np.int32) + if stride_np.size != im_shape.size: + raise ValueError("Sizes of image shape and stride should match.") + + # Obviously if stride = 1, downsampled_im_shape == output_im_shape + downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) + downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 + num_required_predictions = np.prod(downsampled_im_shape) # Loop 1D over image for i in trange(num_required_predictions): # Get corresponding ND index - idx = np.unravel_index(i, output_im_shape) + idx = np.unravel_index(i, downsampled_im_shape) + # Multiply by stride + idx *= stride_np # If a bounding box is being used, we need to add on # the min to shift to start of region of interest if b_box_min is not None: @@ -178,11 +206,20 @@ def compute_occlusion_sensitivity( batch_images = [] batch_ids = [] + # Subtract from baseline + sensitivity_im = baseline - sensitivity_im + + # Reshape to match downsampled image + sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)) + + # If necessary, upsample + if np.any(stride_np != 1): + output_im_shape = tuple(output_im_shape[1:]) # needs to be given as 3D tuple + upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode) + sensitivity_im = upsampler(sensitivity_im.unsqueeze(0)) + # Convert tensor to numpy sensitivity_im = sensitivity_im.cpu().numpy() - # Reshape to size of output image - sensitivity_im = sensitivity_im.reshape(output_im_shape) - - # Squeeze, subtract from baseline and return - return baseline - np.squeeze(sensitivity_im) + # Squeeze and return + return np.squeeze(sensitivity_im) diff --git a/tests/test_compute_occlusion_sensitivity.py b/tests/test_compute_occlusion_sensitivity.py index 897177c6ed..9f30162c47 100644 --- a/tests/test_compute_occlusion_sensitivity.py +++ b/tests/test_compute_occlusion_sensitivity.py @@ -43,6 +43,7 @@ "label": 0, "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], "n_batch": 10, + "stride": 2, }, (2, 6, 6), ]