diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index c821bf42f5..9ad61fa3f2 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -9,8 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .visualizer import default_normalizer, default_upsampler # isort:skip -from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks +from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer from .img2tensorboard import ( add_animated_gif, add_animated_gif_no_channels, @@ -18,3 +17,4 @@ plot_2d_or_3d_image, ) from .occlusion_sensitivity import OcclusionSensitivity +from .visualizer import default_upsampler diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index f0453f051c..a917bcf800 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -12,15 +12,30 @@ import warnings from typing import Callable, Dict, Sequence, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from monai.networks.utils import eval_mode, train_mode +from monai.transforms import ScaleIntensity from monai.utils import ensure_tuple -from monai.visualize import default_normalizer, default_upsampler +from monai.visualize.visualizer import default_upsampler -__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks"] +__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"] + + +def default_normalizer(x) -> np.ndarray: + """ + A linear intensity scaling by mapping the (min, max) to (1, 0). + + N.B.: This will flip magnitudes (i.e., smallest will become biggest and vice versa). + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + scaler = ScaleIntensity(minv=1.0, maxv=0.0) + x = [scaler(x) for x in x] + return np.stack(x, axis=0) class ModelWithHooks: @@ -221,7 +236,8 @@ def __init__( N dimensional linear (bilinear, trilinear, etc.) depending on num spatial dimensions of input. postprocessing: a callable that applies on the upsampled output image. - default is normalising between 0 and 1. + Default is normalizing between min=1 and max=0 (i.e., largest input will become 0 and + smallest input will become 1). 
""" super().__init__( nn_module=nn_module, diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 0d7ac63467..6dc4de154f 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -11,14 +11,14 @@ from collections.abc import Sequence from functools import partial -from typing import Callable, Optional, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn from monai.networks.utils import eval_mode -from monai.visualize import default_normalizer, default_upsampler +from monai.visualize.visualizer import default_upsampler try: from tqdm import trange @@ -27,6 +27,12 @@ except (ImportError, AttributeError): trange = range +# For stride two (for example), +# if input array is: |0|1|2|3|4|5|6|7| +# downsampled output is: | 0 | 1 | 2 | 3 | +# So the upsampling should do it by the corners of the image, not their centres +default_upsampler = partial(default_upsampler, align_corners=True) + def _check_input_image(image): """Check that the input image is as expected.""" @@ -35,19 +41,6 @@ def _check_input_image(image): raise RuntimeError("Expected batch size of 1.") -def _check_input_label(model, label, image): - """Check that the input label is as expected.""" - if label is None: - label = model(image).argmax(1) - # If necessary turn the label into a 1-element tensor - elif not isinstance(label, torch.Tensor): - label = torch.tensor([[label]], dtype=torch.int64).to(image.device) - # make sure there's only 1 element - if label.numel() != image.shape[0]: - raise RuntimeError("Expected as many labels as batches.") - return label - - def _check_input_bounding_box(b_box, im_shape): """Check that the bounding box (if supplied) is as expected.""" # If no bounding box has been supplied, set min and max to None @@ -75,34 +68,56 @@ def _check_input_bounding_box(b_box, im_shape): return b_box_min, b_box_max -def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): - """For given number of images, get probability of predicting - a given label. Append to previous evaluations.""" +def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims): + """Infer given images. Append to previous evaluations. Store each class separately.""" batch_images = torch.cat(batch_images, dim=0) - batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device) - scores = model(batch_images).detach().gather(1, batch_ids) - return torch.cat((sensitivity_im, scores)) + scores = model(batch_images).detach() + for i in range(scores.shape[1]): + sensitivity_ims[i] = torch.cat((sensitivity_ims[i], scores[:, i])) + return sensitivity_ims + + +def _get_as_np_array(val, numel): + # If not a sequence, then convert scalar to numpy array + if not isinstance(val, Sequence): + out = np.full(numel, val, dtype=np.int32) + out[0] = 1 # mask_size and stride always 1 in channel dimension + else: + # Convert to numpy array and check dimensions match + out = np.array(val, dtype=np.int32) + # Add stride of 1 to the channel direction (since user input was only for spatial dimensions) + out = np.insert(out, 0, 1) + if out.size != numel: + raise ValueError( + "If supplying stride/mask_size as sequence, number of elements should match number of spatial dimensions." + ) + return out class OcclusionSensitivity: """ - This class computes the occlusion sensitivity for a model's prediction - of a given image. 
By occlusion sensitivity, we mean how the probability of a given - prediction changes as the occluded section of an image changes. This can - be useful to understand why a network is making certain decisions. + This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity, + we mean how the probability of a given prediction changes as the occluded section of an image changes. This can be + useful to understand why a network is making certain decisions. + + As important parts of the image are occluded, the probability of classifying the image correctly will decrease. + Hence, more negative values imply the corresponding occluded volume was more important in the decision process. + + Two ``torch.Tensor`` objects will be returned by the ``__call__`` method: an occlusion map and an image of the most probable + class. Both images will be cropped if a bounding box is used, but voxel sizes will always match the input. - The result is given as ``baseline`` (the probability of - a certain output) minus the probability of the output with the occluded - area. + The occlusion map shows the inference probabilities when the corresponding part of the image is occluded. Hence, + more -ve values imply that region was important in the decision process. The map will have shape ``BCHW(D)N``, + where ``N`` is the number of classes to be inferred by the network. Hence, the occlusion for class ``i`` can + be seen with ``map[...,i]``. - Therefore, higher values in the output image mean there was a - greater the drop in certainty, indicating the occluded region was more - important in the decision process. + The most probable class is an image of the probable class when the corresponding part of the image is occluded + (equivalent to ``occ_map.argmax(dim=-1)``). See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via - Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74 + Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74. - Examples + Examples: .. code-block:: python @@ -112,7 +127,7 @@ class OcclusionSensitivity: model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) occ_sens = OcclusionSensitivity(nn_module=model_2d) - result = occ_sens(x=torch.rand((1, 1, 48, 64)), class_idx=None, b_box=[-1, -1, 2, 40, 1, 62]) + occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[-1, -1, 2, 40, 1, 62]) # densenet 3d from monai.networks.nets import DenseNet @@ -120,7 +135,7 @@ class OcclusionSensitivity: model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=2) - result = occ_sens(torch.rand(1, 1, 6, 6, 6), class_idx=1, b_box=[-1, -1, 2, 3, -1, -1, -1, -1]) + occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[-1, -1, 2, 3, -1, -1, -1, -1]) See Also: @@ -130,168 +145,172 @@ class OcclusionSensitivity: def __init__( self, nn_module: nn.Module, - pad_val: float = 0.0, - margin: Union[int, Sequence] = 2, + pad_val: Optional[float] = None, + mask_size: Union[int, Sequence] = 15, n_batch: int = 128, stride: Union[int, Sequence] = 1, - upsampler: Callable = default_upsampler, - postprocessing: Callable = default_normalizer, + upsampler: Optional[Callable] = default_upsampler, verbose: bool = True, ) -> None: """Occlusion sensitivitiy constructor. 
- :param nn_module: classification model to use for inference - :param pad_val: when occluding part of the image, which values should we put - in the image? - :param margin: we'll create a cuboid/cube around the voxel to be occluded. if - ``margin==2``, then we'll create a cube that is +/- 2 voxels in - all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` - can be supplied to have a margin of different sizes (i.e., create - a cuboid). - :param n_batch: number of images in a batch before inference. - :param b_box: Bounding box on which to perform the analysis. The output image - will also match in size. There should be a minimum and maximum for - all dimensions except batch: ``[min1, max1, min2, max2,...]``. - * By default, the whole image will be used. Decreasing the size will - speed the analysis up, which might be useful for larger images. - * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). - * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. - :param stride: Stride in spatial directions for performing occlusions. Can be single - value or sequence (for varying stride in the different directions). - Should be >= 1. Striding in the channel direction will always be 1. - :param upsampler: An upsampling method to upsample the output image. Default is - N dimensional linear (bilinear, trilinear, etc.) depending on num spatial - dimensions of input. - :param postprocessing: a callable that applies on the upsampled output image. - default is normalising between 0 and 1. - :param verbose: use ``tdqm.trange`` output (if available). + Args: + nn_module: Classification model to use for inference. + pad_val: When occluding part of the image, which values should we put + in the image? If ``None`` is used, then the average of the image will be used. + mask_size: Size of box to be occluded, centred on the central voxel. To ensure that the occluded area + is correctly centred, ``mask_size`` and ``stride`` should both be odd or even. + n_batch: Number of images in a batch for inference. + stride: Stride in spatial directions for performing occlusions. Can be single + value or sequence (for varying stride in the different directions). + Should be >= 1. Striding in the channel direction will always be 1. + upsampler: An upsampling method to upsample the output image. Default is + N-dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + verbose: Use ``tqdm.trange`` output (if available). 
""" self.nn_module = nn_module self.upsampler = upsampler - self.postprocessing = postprocessing self.pad_val = pad_val - self.margin = margin + self.mask_size = mask_size self.n_batch = n_batch self.stride = stride self.verbose = verbose - def _compute_occlusion_sensitivity(self, x, class_idx, b_box): + def _compute_occlusion_sensitivity(self, x, b_box): # Get bounding box im_shape = np.array(x.shape[1:]) b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) - # Get baseline probability - baseline = self.nn_module(x).detach()[0, class_idx].item() + # Get the number of prediction classes + num_classes = self.nn_module(x).numel() + + #  If pad val not supplied, get the mean of the image + pad_val = x.mean() if self.pad_val is None else self.pad_val - # Create some lists + # List containing a batch of images to be inferred batch_images = [] - batch_ids = [] - sensitivity_im = torch.empty(0, dtype=torch.float32, device=x.device) + # List of sensitivity images, one for each inferred class + sensitivity_ims = num_classes * [torch.empty(0, dtype=torch.float32, device=x.device)] # If no bounding box supplied, output shape is same as input shape. # If bounding box is present, shape is max - min + 1 output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 - # Calculate the downsampled shape - if not isinstance(self.stride, Sequence): - stride_np = np.full_like(im_shape, self.stride, dtype=np.int32) - stride_np[0] = 1 # always do stride 1 in channel dimension - else: - # Convert to numpy array and check dimensions match - stride_np = np.array(self.stride, dtype=np.int32) - if stride_np.size != im_shape - 1: # should be 1 less to get spatial dimensions + # Get the stride and mask_size as numpy arrays + self.stride = _get_as_np_array(self.stride, len(im_shape)) + self.mask_size = _get_as_np_array(self.mask_size, len(im_shape)) + + # For each dimension, ... + for o, s in zip(output_im_shape, self.stride): + # if the size is > 1, then check that the stride is a factor of the output image shape + if o > 1 and o % s != 0: raise ValueError( - "If supplying stride as sequence, number of elements of stride should match number of spatial dimensions." + "Stride should be a factor of the image shape. Im shape " + + f"(taking bounding box into account): {output_im_shape}, stride: {self.stride}" ) - # Obviously if stride = 1, downsampled_im_shape == output_im_shape - downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) + # to ensure the occluded area is nicely centred if stride is even, ensure that so is the mask_size + if np.any(self.mask_size % 2 != self.stride % 2): + raise ValueError( + "Stride and mask size should both be odd or even (element-wise). 
" + + f"``stride={self.stride}``, ``mask_size={self.mask_size}``" + ) + + downsampled_im_shape = (output_im_shape / self.stride).astype(np.int32) downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 num_required_predictions = np.prod(downsampled_im_shape) + # Get bottom left and top right corners of occluded region + lower_corner = (self.stride - self.mask_size) // 2 + upper_corner = (self.stride + self.mask_size) // 2 + # Loop 1D over image verbose_range = trange if self.verbose else range for i in verbose_range(num_required_predictions): # Get corresponding ND index idx = np.unravel_index(i, downsampled_im_shape) # Multiply by stride - idx *= stride_np + idx *= self.stride # If a bounding box is being used, we need to add on # the min to shift to start of region of interest if b_box_min is not None: idx += b_box_min - # Get min and max index of box to occlude - min_idx = [max(0, i - self.margin) for i in idx] - max_idx = [min(j, i + self.margin) for i, j in zip(idx, im_shape)] + # Get min and max index of box to occlude (and make sure it's in bounds) + min_idx = np.maximum(idx + lower_corner, 0) + max_idx = np.minimum(idx + upper_corner, im_shape) # Clone and replace target area with `pad_val` occlu_im = x.detach().clone() - occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val + occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val # Add to list batch_images.append(occlu_im) - batch_ids.append(class_idx) # Once the batch is complete (or on last iteration) if len(batch_images) == self.n_batch or i == num_required_predictions - 1: - # Do the predictions and append to sensitivity map - sensitivity_im = _append_to_sensitivity_im(self.nn_module, batch_images, batch_ids, sensitivity_im) + # Do the predictions and append to sensitivity maps + sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims) # Clear lists batch_images = [] - batch_ids = [] - - # Subtract baseline from sensitivity so that +ve values mean more important in decision process - sensitivity_im = baseline - sensitivity_im # Reshape to match downsampled image, and unsqueeze to add batch dimension back in - sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)).unsqueeze(0) + for i in range(num_classes): + sensitivity_ims[i] = sensitivity_ims[i].reshape(tuple(downsampled_im_shape)).unsqueeze(0) - return sensitivity_im, output_im_shape + return sensitivity_ims, output_im_shape def __call__( # type: ignore - self, x: torch.Tensor, class_idx: Optional[Union[int, torch.Tensor]] = None, b_box: Optional[Sequence] = None - ): + self, + x: torch.Tensor, + b_box: Optional[Sequence] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: - x: image to test. Should be tensor consisting of 1 batch, can be 2- or 3D. - class_idx: classification label to check for changes. This could be the true - label, or it could be the predicted label, etc. Use ``None`` to use generate - the predicted model. - b_box: Bounding box on which to perform the analysis. The output image - will also match in size. There should be a minimum and maximum for - all dimensions except batch: ``[min1, max1, min2, max2,...]``. - * By default, the whole image will be used. Decreasing the size will - speed the analysis up, which might be useful for larger images. - * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). - * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. 
+ x: Image to use for inference. Should be a tensor consisting of 1 batch. + b_box: Bounding box on which to perform the analysis. The output image will be limited to this size. + There should be a minimum and maximum for all dimensions except batch: ``[min1, max1, min2, max2,...]``. + * By default, the whole image will be used. Decreasing the size will speed the analysis up, which might + be useful for larger images. + * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``. + * Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension. + Returns: + * Occlusion map: + * Shows the inference probabilities when the corresponding part of the image is occluded. + Hence, more -ve values imply that region was important in the decision process. + * The map will have shape ``BCHW(D)N``, where N is the number of classes to be inferred by the + network. Hence, the occlusion for class ``i`` can be seen with ``map[...,i]``. + * Most probable class: + * The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``). + Both images will be cropped if a bounding box is used, but voxel sizes will always match the input. """ with eval_mode(self.nn_module): # Check input arguments _check_input_image(x) - class_idx = _check_input_label(self.nn_module, class_idx, x) - # Generate sensitivity image - sensitivity_im, output_im_shape = self._compute_occlusion_sensitivity(x, class_idx, b_box) + # Generate sensitivity images + sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box) + + # Loop over image for each classification + for i in range(len(sensitivity_ims_list)): + + # upsample + if self.upsampler is not None: + if np.any(output_im_shape != x.shape[1:]): + img_spatial = tuple(output_im_shape[1:]) + sensitivity_ims_list[i] = self.upsampler(img_spatial)(sensitivity_ims_list[i]) + + # Convert list of tensors to tensor + sensitivity_ims = torch.stack(sensitivity_ims_list, dim=-1) - # upsampling and postprocessing - if self.upsampler is not None: - if np.any(output_im_shape != x.shape[1:]): - img_spatial = tuple(output_im_shape[1:]) - sensitivity_im = self.upsampler(img_spatial)(sensitivity_im) - if self.postprocessing: - sensitivity_im = self.postprocessing(sensitivity_im) + # The most probable class is the max in the classification dimension (last) + most_probable_class = sensitivity_ims.argmax(dim=-1) - # Squeeze and return - return sensitivity_ims, most_probable_class diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index 2803f826f2..bbb01f5c5e 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -12,17 +12,15 @@ from typing import Callable -import numpy as np import torch import torch.nn.functional as F -from monai.transforms import ScaleIntensity from monai.utils import InterpolateMode -__all__ = ["default_upsampler", "default_normalizer"] +__all__ = ["default_upsampler"] -def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: +def default_upsampler(spatial_size, align_corners=False) -> Callable[[torch.Tensor], torch.Tensor]: """ A linear interpolation method for upsampling the feature 
map. The output of this function is a callable `func`, @@ -30,19 +28,9 @@ def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: """ def up(x): + linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] interp_mode = linear_mode[len(spatial_size) - 1] - return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=False) + return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=align_corners) return up - - -def default_normalizer(x) -> np.ndarray: - """ - A linear intensity scaling by mapping the (min, max) to (1, 0). - """ - if isinstance(x, torch.Tensor): - x = x.detach().cpu().numpy() - scaler = ScaleIntensity(minv=1.0, maxv=0.0) - x = [scaler(x) for x in x] - return np.stack(x, axis=0) diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index 5ad68eabd8..ea21cd4fa8 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -18,9 +18,11 @@ from monai.visualize import OcclusionSensitivity device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3).to(device) +out_channels_2d = 4 +out_channels_3d = 3 +model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device) model_3d = DenseNet( - spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) + spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,) ).to(device) model_2d.eval() model_3d.eval() @@ -32,27 +34,23 @@ }, { "x": torch.rand(1, 1, 48, 64).to(device), - "class_idx": torch.tensor([[0]], dtype=torch.int64).to(device), "b_box": [-1, -1, 2, 40, 1, 62], }, + (1, 1, 39, 62, out_channels_2d), (1, 1, 39, 62), ] # 3D w/ bounding box and stride TEST_CASE_1 = [ - { - "nn_module": model_3d, - "n_batch": 10, - "stride": 2, - }, + {"nn_module": model_3d, "n_batch": 10, "stride": (2, 1, 2), "mask_size": (16, 15, 14)}, { "x": torch.rand(1, 1, 6, 6, 6).to(device), - "class_idx": None, "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], }, + (1, 1, 2, 6, 6, out_channels_3d), (1, 1, 2, 6, 6), ] -TEST_CASE_FAIL = [ # 2D should fail, since 3 stride values given +TEST_CASE_FAIL_0 = [ # 2D should fail, since 3 stride values given { "nn_module": model_2d, "n_batch": 10, @@ -60,23 +58,38 @@ }, { "x": torch.rand(1, 1, 48, 64).to(device), - "class_idx": None, "b_box": [-1, -1, 2, 3, -1, -1], }, ] +TEST_CASE_FAIL_1 = [ # 2D should fail, since stride is not a factor of image size + { + "nn_module": model_2d, + "stride": 3, + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + }, +] + class TestComputeOcclusionSensitivity(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) - def test_shape(self, init_data, call_data, expected_shape): + def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape): occ_sens = OcclusionSensitivity(**init_data) - result = occ_sens(**call_data) - self.assertTupleEqual(result.shape, expected_shape) + map, most_prob = occ_sens(**call_data) + self.assertTupleEqual(map.shape, map_expected_shape) + self.assertTupleEqual(most_prob.shape, most_prob_expected_shape) + # most probable class should be of type int, and should have min>=0, max
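Reviewer note (not part of the patch): a minimal usage sketch of the reworked two-output ``OcclusionSensitivity`` API added above, mirroring the 2D docstring example and ``TEST_CASE_0``. The shape comments assume ``out_channels=3`` and the default ``mask_size=15`` / ``stride=1``.

.. code-block:: python

    import torch
    from monai.networks.nets import densenet121
    from monai.visualize import OcclusionSensitivity

    model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3)
    model_2d.eval()

    # defaults mask_size=15, stride=1: both odd, so the parity check passes,
    # and stride 1 trivially divides the bounding-box-cropped image shape
    occ_sens = OcclusionSensitivity(nn_module=model_2d)

    x = torch.rand(1, 1, 48, 64)
    occ_map, most_probable_class = occ_sens(x, b_box=[-1, -1, 2, 40, 1, 62])

    # occ_map has shape BCHW(D)N -> (1, 1, 39, 62, 3); the map for class i is occ_map[..., i]
    # most_probable_class is occ_map.argmax(dim=-1) -> shape (1, 1, 39, 62)
    print(occ_map.shape, most_probable_class.shape)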
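Reviewer note (not part of the patch): ``default_normalizer`` (now defined in ``class_activation_maps.py`` and re-exported from ``monai.visualize``) maps (min, max) to (1, 0), so it inverts magnitudes as the new docstring warns. A tiny sketch of the expected behaviour:

.. code-block:: python

    import torch
    from monai.visualize import default_normalizer

    x = torch.tensor([[[0.0, 0.5, 1.0]]])  # a batch containing one item
    out = default_normalizer(x)            # numpy array, scaled per batch item
    # min -> 1, max -> 0, so the expected output is [[[1.0, 0.5, 0.0]]]
    print(out)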