diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 6dc4de154f..bb9ef59e5c 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -303,7 +303,8 @@ def __call__( # type: ignore # upsample if self.upsampler is not None: - if np.any(output_im_shape != x.shape[1:]): + assert len(sensitivity_ims_list[i].shape) == len(x.shape) + if np.any(sensitivity_ims_list[i].shape != x.shape): img_spatial = tuple(output_im_shape[1:]) sensitivity_ims_list[i] = self.upsampler(img_spatial)(sensitivity_ims_list[i])