Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,10 @@ def _reshape_maps(self, result_maps: list[Tensor]) -> Tensor:
reshaped_result_map = reshaped_result_map.reshape(batch_size, -1, num_channel)

if torch.isnan(reshaped_result_map).any() or torch.isinf(reshaped_result_map).any():
raise ValueError("Concatenated result is NaN or Inf.")
if torch.is_grad_enabled():
raise ValueError("Concatenated result is NaN or Inf.")
else:
warnings.warn("Concatenated result is NaN or Inf.")

all_reshaped_result_map.append(reshaped_result_map)

Expand Down Expand Up @@ -893,7 +896,10 @@ def get_cls_train_sample_per_image(
"""

if torch.isnan(cls_logits_per_image).any() or torch.isinf(cls_logits_per_image).any():
raise ValueError("NaN or Inf in predicted classification logits.")
if torch.is_grad_enabled():
raise ValueError("NaN or Inf in predicted classification logits.")
else:
warnings.warn("NaN or Inf in predicted classification logits.")

foreground_idxs_per_image = matched_idxs_per_image >= 0

Expand Down Expand Up @@ -973,7 +979,10 @@ def get_box_train_sample_per_image(
"""

if torch.isnan(box_regression_per_image).any() or torch.isinf(box_regression_per_image).any():
raise ValueError("NaN or Inf in predicted box regression.")
if torch.is_grad_enabled():
raise ValueError("NaN or Inf in predicted box regression.")
else:
warnings.warn("NaN or Inf in predicted box regression.")

foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
num_gt_box = targets_per_image[self.target_box_key].shape[0]
Expand Down
11 changes: 9 additions & 2 deletions monai/apps/detection/networks/retinanet_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from __future__ import annotations

import math
import warnings
from collections.abc import Callable, Sequence
from typing import Any, Dict

Expand Down Expand Up @@ -125,7 +126,10 @@ def forward(self, x: list[Tensor]) -> list[Tensor]:
cls_logits_maps.append(cls_logits)

if torch.isnan(cls_logits).any() or torch.isinf(cls_logits).any():
raise ValueError("cls_logits is NaN or Inf.")
if torch.is_grad_enabled():
raise ValueError("cls_logits is NaN or Inf.")
else:
warnings.warn("cls_logits is NaN or Inf.")

return cls_logits_maps

Expand Down Expand Up @@ -194,7 +198,10 @@ def forward(self, x: list[Tensor]) -> list[Tensor]:
box_regression_maps.append(box_regression)

if torch.isnan(box_regression).any() or torch.isinf(box_regression).any():
raise ValueError("box_regression is NaN or Inf.")
if torch.is_grad_enabled():
raise ValueError("box_regression is NaN or Inf.")
else:
warnings.warn("box_regression is NaN or Inf.")

return box_regression_maps

Expand Down