Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .visualizer import default_normalizer, default_upsampler # isort:skip
from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks
from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer
from .img2tensorboard import (
add_animated_gif,
add_animated_gif_no_channels,
make_animated_gif_summary,
plot_2d_or_3d_image,
)
from .occlusion_sensitivity import OcclusionSensitivity
from .visualizer import default_upsampler
22 changes: 19 additions & 3 deletions monai/visualize/class_activation_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,30 @@
import warnings
from typing import Callable, Dict, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.utils import eval_mode, train_mode
from monai.transforms import ScaleIntensity
from monai.utils import ensure_tuple
from monai.visualize import default_normalizer, default_upsampler
from monai.visualize.visualizer import default_upsampler

__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks"]
__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"]


def default_normalizer(x) -> np.ndarray:
"""
A linear intensity scaling by mapping the (min, max) to (1, 0).

N.B.: This will flip magnitudes (i.e., smallest will become biggest and vice versa).
"""
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
scaler = ScaleIntensity(minv=1.0, maxv=0.0)
x = [scaler(x) for x in x]
return np.stack(x, axis=0)


class ModelWithHooks:
Expand Down Expand Up @@ -221,7 +236,8 @@ def __init__(
N dimensional linear (bilinear, trilinear, etc.) depending on num spatial
dimensions of input.
postprocessing: a callable that applies on the upsampled output image.
default is normalising between 0 and 1.
Default is normalizing between min=1 and max=0 (i.e., largest input will become 0 and
smallest input will become 1).
"""
super().__init__(
nn_module=nn_module,
Expand Down
Loading