From d01eb1de21ceea3e4774d61e1817cfe23568fbe6 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Fri, 5 May 2023 13:41:28 +0000 Subject: [PATCH] Make RetinaNet throw errors for NaN only when training Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- .../apps/detection/networks/retinanet_detector.py | 15 ++++++++++++--- .../apps/detection/networks/retinanet_network.py | 11 +++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/monai/apps/detection/networks/retinanet_detector.py b/monai/apps/detection/networks/retinanet_detector.py index 4dbec014dc..a0573d6cd1 100644 --- a/monai/apps/detection/networks/retinanet_detector.py +++ b/monai/apps/detection/networks/retinanet_detector.py @@ -617,7 +617,10 @@ def _reshape_maps(self, result_maps: list[Tensor]) -> Tensor: reshaped_result_map = reshaped_result_map.reshape(batch_size, -1, num_channel) if torch.isnan(reshaped_result_map).any() or torch.isinf(reshaped_result_map).any(): - raise ValueError("Concatenated result is NaN or Inf.") + if torch.is_grad_enabled(): + raise ValueError("Concatenated result is NaN or Inf.") + else: + warnings.warn("Concatenated result is NaN or Inf.") all_reshaped_result_map.append(reshaped_result_map) @@ -893,7 +896,10 @@ def get_cls_train_sample_per_image( """ if torch.isnan(cls_logits_per_image).any() or torch.isinf(cls_logits_per_image).any(): - raise ValueError("NaN or Inf in predicted classification logits.") + if torch.is_grad_enabled(): + raise ValueError("NaN or Inf in predicted classification logits.") + else: + warnings.warn("NaN or Inf in predicted classification logits.") foreground_idxs_per_image = matched_idxs_per_image >= 0 @@ -973,7 +979,10 @@ def get_box_train_sample_per_image( """ if torch.isnan(box_regression_per_image).any() or torch.isinf(box_regression_per_image).any(): - raise ValueError("NaN or Inf in predicted box regression.") + if torch.is_grad_enabled(): + raise ValueError("NaN or Inf in predicted box regression.") + else: + warnings.warn("NaN or Inf in predicted box regression.") foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] num_gt_box = targets_per_image[self.target_box_key].shape[0] diff --git a/monai/apps/detection/networks/retinanet_network.py b/monai/apps/detection/networks/retinanet_network.py index 7d6e341833..af8e73ba2d 100644 --- a/monai/apps/detection/networks/retinanet_network.py +++ b/monai/apps/detection/networks/retinanet_network.py @@ -40,6 +40,7 @@ from __future__ import annotations import math +import warnings from collections.abc import Callable, Sequence from typing import Any, Dict @@ -125,7 +126,10 @@ def forward(self, x: list[Tensor]) -> list[Tensor]: cls_logits_maps.append(cls_logits) if torch.isnan(cls_logits).any() or torch.isinf(cls_logits).any(): - raise ValueError("cls_logits is NaN or Inf.") + if torch.is_grad_enabled(): + raise ValueError("cls_logits is NaN or Inf.") + else: + warnings.warn("cls_logits is NaN or Inf.") return cls_logits_maps @@ -194,7 +198,10 @@ def forward(self, x: list[Tensor]) -> list[Tensor]: box_regression_maps.append(box_regression) if torch.isnan(box_regression).any() or torch.isinf(box_regression).any(): - raise ValueError("box_regression is NaN or Inf.") + if torch.is_grad_enabled(): + raise ValueError("box_regression is NaN or Inf.") + else: + warnings.warn("box_regression is NaN or Inf.") return box_regression_maps