
Conversation

@ytl0623
Contributor

@ytl0623 ytl0623 commented Dec 22, 2025

Fixes #8603

Description

Refactors AsymmetricUnifiedFocalLoss and its sub-components (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to extend support from binary-only to multi-class segmentation, while also fixing mathematical logic errors and parameter-passing bugs.
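
As a rough illustration of the extended interface, a minimal usage sketch based on the parameters described in this PR (defaults and exact argument handling may differ from the final code):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Multi-class example: 3-channel logits, integer labels converted via to_onehot_y.
loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
y_pred = torch.randn(2, 3, 16, 16)            # (B, C, H, W) raw logits
y_true = torch.randint(0, 3, (2, 1, 16, 16))  # (B, 1, H, W) class indices
loss = loss_fn(y_pred, y_true)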

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@coderabbitai
Contributor

coderabbitai bot commented Dec 22, 2025

📝 Walkthrough

Walkthrough

Adds a use_softmax flag and multi-class support to AsymmetricFocalLoss, AsymmetricFocalTverskyLoss, and AsymmetricUnifiedFocalLoss. Prediction logits are auto-converted to probabilities (softmax for multi-class, sigmoid for binary), single-channel preds are expanded to two-channel probabilities, and to_onehot_y handling and shape validations are extended. Per-class losses are computed with explicit background/foreground paths and support NONE/MEAN/SUM reductions. AsymmetricUnifiedFocalLoss now wraps and exposes asy_focal_loss and asy_focal_tversky_loss, combining their per-class outputs. Tests added/rewritten for binary logits, binary 2‑ch, multiclass perfect/incorrect cases, CUDA, and shape errors.
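
For illustration, the single-channel expansion mentioned above could look roughly like this (a sketch, not the exact implementation):

import torch

y_pred = torch.randn(1, 1, 4, 4)                   # (B, 1, H, W) binary logits
prob_fg = torch.sigmoid(y_pred)                    # foreground probability
y_prob = torch.cat([1 - prob_fg, prob_fg], dim=1)  # (B, 2, H, W): [background, foreground]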

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 60.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
  • Title check: ✅ Passed. Title accurately summarizes the main change: adding sigmoid/softmax interface to AsymmetricUnifiedFocalLoss.
  • Description check: ✅ Passed. Description addresses the template with linked issue reference, clear objectives, and type selection; missing integration/documentation test checkboxes are secondary.
  • Linked Issues check: ✅ Passed. Changes fulfill #8603 objectives: sigmoid/softmax interface added, multi-class support enabled, mathematical logic and parameter bugs fixed, tests added.
  • Out of Scope Changes check: ✅ Passed. All changes directly address #8603: loss implementation refactoring for multi-class support, interface updates, and test coverage align with stated objectives.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

76-82: Add stacklevel=2 to warning.

Per static analysis and Python conventions, set stacklevel so the warning points to the caller.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

177-194: Logic is correct; consider documenting return value.

The forward method correctly passes logits to FocalLoss and probabilities to AsymmetricFocalTverskyLoss. Per coding guidelines, docstrings should document return values.

         """
         Args:
             y_pred: (BNH[WD]) Logits (raw scores).
             y_true: (BNH[WD]) Ground truth labels.
+
+        Returns:
+            torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
         """
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 15fd428 and c27945a.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

78-78: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


127-127: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

36-44: Tests already cover the unified focal loss implementation.

New tests were added to cover the changes. The PR indicates that test coverage has been implemented, so this concern can be closed.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

157-157: Replace Chinese comment with English.

-        self.use_softmax = use_softmax  # 儲存參數
+        self.use_softmax = use_softmax

112-114: Numerical instability when dice approaches 1.0.

When dice_class[:, i] equals 1.0, torch.pow(0, -self.gamma) produces infinity, causing NaN gradients.

Proposed fix
-                # Foreground classes: apply focal modulation
-                # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
-                loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
+                # Foreground classes: apply focal modulation
+                back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
+                loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

74-80: Background exclusion logic is correct.

The implementation properly handles the single-channel edge case.

Optional: Add stacklevel to warning
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between c27945a and 1fba9d3.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)

34-58: Constructor changes look good.

The include_background parameter addition aligns with MONAI conventions and is properly documented.


159-173: Component initialization is appropriate.

Both loss components are properly configured with consistent parameters.


181-188: Remove redundant to_onehot_y conversion concern.

Both focal_loss and asy_focal_tversky_loss independently apply to_onehot_y conversion when enabled. Since PyTorch's one-hot encoding creates new tensors (not in-place modifications), passing the same y_true reference to both components is safe and functionally correct. No conversion conflicts occur.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (2)

76-76: Add stacklevel=2 to warning.

Missing explicit stacklevel for proper traceback.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

Based on static analysis hints.


176-193: Document that y_pred must be logits.

The forward method expects raw logits since FocalLoss applies activation internally, and manual softmax/sigmoid is applied for the Tversky component. The docstring states this (line 179) but could be more explicit about the consequences of passing probabilities instead.

Consider adding a note in the docstring:

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            y_pred: (BNH[WD]) Logits (raw scores).
+            y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
+                Do not pass pre-activated inputs; activation is applied internally.
             y_true: (BNH[WD]) Ground truth labels.
         """
tests/losses/test_unified_focal_loss.py (1)

26-61: Add test coverage for edge cases.

Current tests only cover perfect predictions with zero loss. Missing coverage for:

  • Imperfect predictions (non-zero loss)
  • include_background=False scenarios
  • to_onehot_y=True with integer labels
  • Multi-class softmax with imperfect predictions
Suggested additional test cases
# Case 2: Binary with include_background=False
[
    {
        "use_softmax": False,
        "include_background": False,
    },
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    0.0,  # Should still be zero for perfect prediction
],
# Case 3: Multi-class with to_onehot_y=True (integer labels)
[
    {
        "use_softmax": True,
        "include_background": True,
        "to_onehot_y": True,
    },
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_neg], [logit_neg, logit_pos]]]]),
        "y_true": torch.tensor([[[[0, 2]]]]),  # Integer labels: class 0, class 2
    },
    0.0,
],

Do you want me to generate a complete test case addition or open an issue to track this?

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1fba9d3 and 39664ea.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (3)

184-187: LGTM: Correct activation choice for Tversky loss.

FocalLoss handles its own activation internally, so this manual conversion to probabilities for AsymmetricFocalTverskyLoss is correct. The activation choice (softmax vs sigmoid) properly follows the use_softmax flag.


89-116: Implementation correctly handles include_background with standard MONAI slicing pattern.

When include_background=False, channel index 0 is excluded from the calculation—the code does this via tensor slicing at lines 79-80 before the asymmetry loop. Once sliced, all remaining channels receive focal modulation; none are treated as background. The loss only supports binary segmentation, so asymmetry designates the first present channel as background and all others as foreground, which is the intended behavior per the documented design comment (lines 101-104).


160-174: Both composed losses independently transform y_true with their respective settings. Each applies its own non-destructive transformations (one-hot encoding creates new tensors; slicing creates new views), so no actual collision occurs. This is correct by design—composed losses should handle their own input transformations.

tests/losses/test_unified_focal_loss.py (2)

22-24: LGTM: High-confidence logits ensure clear test expectations.

Using ±10.0 logits produces near-perfect probabilities (~0.9999 and ~0.0001), making zero-loss expectations reasonable for perfect predictions.


77-89: LGTM: CUDA test correctly instantiates loss.

The test properly moves the loss module to CUDA (line 85), ensuring both model parameters and inputs are on the same device.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

74-80: Background exclusion correctly implemented.

The logic properly removes the first channel when include_background=False, consistent with FocalLoss. The single-channel warning is appropriate.

Optional: Add stacklevel to warning for better traceability
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 39664ea and 41dccad.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (2)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (8)
tests/losses/test_unified_focal_loss.py (3)

22-42: Binary test case is correct.

The high-confidence logits (±10.0) correctly produce near-perfect probabilities after sigmoid. The alignment between predictions and targets should yield near-zero loss.


66-70: Test structure is correct.

Parameterized test properly unpacks configuration and data, with appropriate numerical tolerances for floating-point comparison.


77-89: CUDA test correctly adapted to logits interface.

The test properly uses logits with use_softmax=False for binary segmentation and correctly moves both tensors and the loss module to CUDA.

monai/losses/unified_focal_loss.py (5)

19-19: Import is correct.

FocalLoss is properly imported from monai.losses for reuse in the unified loss.


34-58: Constructor properly extended with include_background.

The parameter is correctly documented, defaulted, and stored for use in the forward method, consistent with MONAI's loss interface patterns.


97-117: Asymmetric focal modulation correctly implemented.

Background class uses standard Dice loss while foreground classes apply focal modulation (1-dice)^(1-gamma). Clamping prevents numerical instability when dice approaches 1.0.
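
A compact sketch of that per-class asymmetry, assuming dice_class holds per-class Dice scores of shape (B, C); parameter values are illustrative:

import torch

def asymmetric_tversky_terms(dice_class: torch.Tensor, gamma: float = 0.5, eps: float = 1e-7) -> torch.Tensor:
    terms = []
    for i in range(dice_class.shape[1]):
        if i == 0:
            terms.append(1.0 - dice_class[:, i])               # background: plain Dice/Tversky loss
        else:
            fg = torch.clamp(1.0 - dice_class[:, i], min=eps)  # clamp to avoid 0 ** (-gamma)
            terms.append(fg * torch.pow(fg, -gamma))           # equals (1 - dice) ** (1 - gamma)
    return torch.stack(terms, dim=-1)                          # (B, C) per-class loss terms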


135-174: Composition pattern correctly implemented.

The constructor properly instantiates and configures both FocalLoss and AsymmetricFocalTverskyLoss components with shared parameters, enabling modular loss computation.


176-193: Forward method correctly combines losses.

FocalLoss operates on logits (with internal activation), while AsymmetricFocalTverskyLoss requires probabilities. The explicit softmax/sigmoid conversion for the Tversky component is correct, and the weighted combination is straightforward.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)

26-56: Add test case for include_background=False.

Test cases cover sigmoid/softmax modes correctly, but the include_background parameter (added per PR objectives) is only tested with True. Add a multi-class case with include_background=False to validate background channel exclusion.

Example test case
# Case 2: Multi-class with background excluded
[
    {
        "use_softmax": True,
        "include_background": False,
    },
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg]], [[logit_neg, logit_pos]], [[logit_neg, logit_neg]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 0.0]]]]),
    },
    0.0,
],
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 41dccad and ca81e4a.

📒 Files selected for processing (1)
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (1)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
🔇 Additional comments (3)
tests/losses/test_unified_focal_loss.py (3)

22-24: LGTM - Clear test constants.

Helper logits are well-defined for creating high-confidence predictions.


62-65: LGTM - Parameterized test structure correct.

Test method properly unpacks config and data dicts.


72-84: LGTM - CUDA test properly implemented.

Test correctly uses logits and moves both tensors and loss module to CUDA.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)

79-82: Add docstring.

Per coding guidelines, add a docstring describing that this test validates shape mismatch error handling.

🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)

26-68: Add at least one test with non-zero loss.

All test cases expect 0.0 loss with perfect predictions. Add a case with imperfect predictions (e.g., logits slightly off from ground truth) to verify the loss is actually computed, not just validating tensor shape compatibility.

Optional: Expand parameter coverage

Consider adding test cases that vary:

  • to_onehot_y=True with class-index format ground truth
  • weight, delta, gamma to non-default values
  • reduction modes (SUM, NONE)

These are optional and can be deferred.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between ca81e4a and c9002e0.

📒 Files selected for processing (1)
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (1)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
🔇 Additional comments (1)
tests/losses/test_unified_focal_loss.py (1)

22-24: LGTM—High-confidence logits for perfect-prediction tests.

Values create predictions very close to 0 or 1, suitable for validating near-zero loss on ideal inputs.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/losses/test_unified_focal_loss.py (1)

26-68: Test coverage is adequate for the new interface.

The three cases properly exercise binary sigmoid, multi-class softmax, and background exclusion. All use perfect predictions (loss=0.0), which validates the happy path.

Consider adding at least one test case with imperfect predictions (non-zero expected loss) to verify the actual loss computation, not just the zero-loss edge case. This would increase confidence in the refactored implementation.

monai/losses/unified_focal_loss.py (1)

134-173: Excellent refactoring using composition.

Replacing inline logic with composed FocalLoss and AsymmetricFocalTverskyLoss improves maintainability and reusability. Parameters are correctly forwarded to both components.

Per past review feedback, the use_softmax docstring (lines 151-152) could add brief guidance: softmax for mutually exclusive classes (standard multi-class), sigmoid for multi-label/overlapping classes.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between c9002e0 and 2f4657e.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (128-195)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (9)
tests/losses/test_unified_focal_loss.py (3)

22-24: LGTM.

Helper constants are well-commented and appropriate for generating high-confidence predictions in tests.


74-85: LGTM.

Docstring properly documents parameters. Parameterized test structure is clean and flexible.


92-104: LGTM.

CUDA test correctly uses the new API with sigmoid activation and validates GPU compatibility.

monai/losses/unified_focal_loss.py (6)

19-19: LGTM.

Import required for the new composition-based implementation.


34-58: LGTM.

Adding include_background parameter aligns with MONAI loss function conventions and enables proper multi-class segmentation support.


74-80: LGTM.

Background exclusion logic correctly follows the FocalLoss pattern, including the single-channel warning.


106-114: Asymmetry logic is correct.

Background channel (index 0 when include_background=True) uses standard Dice loss, while foreground channels use focal modulation. When include_background=False, all channels receive focal modulation since background was removed. Clamping prevents numerical instability.


118-125: LGTM.

Reduction logic correctly handles MEAN, SUM, and NONE cases with appropriate error for unsupported values.


175-195: LGTM.

Forward pass correctly handles different input requirements: logits for FocalLoss (which applies activation internally), probabilities for AsymmetricFocalTverskyLoss. The weighted combination is straightforward and matches the documented formula.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)

117-128: Add docstring for CUDA test.

Per coding guidelines and past review comment, add a docstring describing this test's purpose: validating CUDA compatibility with perfect predictions.

Suggested docstring
     def test_with_cuda(self):
+        """Test AsymmetricUnifiedFocalLoss CUDA compatibility with perfect predictions."""
         loss = AsymmetricUnifiedFocalLoss()
🧹 Nitpick comments (3)
tests/losses/test_unified_focal_loss.py (1)

25-93: Suggest adding imperfect prediction test cases.

All three cases test perfect predictions (loss=0.0). Add at least one case with misaligned logits/labels to verify the loss computes non-zero values correctly and gradients flow properly.

Example imperfect case
[  # Case 3: Imperfect prediction
    {"use_softmax": False, "include_background": True},
    {
        "y_pred": torch.tensor([[[[0.0, -2.0], [2.0, 0.0]]]]),  # Moderate confidence
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    # Expected: positive loss value (compute actual expected value)
],
monai/losses/unified_focal_loss.py (2)

60-86: LGTM: Background exclusion logic is correct.

The include_background handling properly slices channel 0 from both tensors and warns on single-channel edge cases. Shape validation and clipping are correctly placed.

Note: Static analysis flags line 83 for a long exception message (TRY003). Consider a custom exception class if this pattern recurs, but current usage is acceptable.


88-125: LGTM: Asymmetric focal Tversky logic is sound.

The per-class loss correctly applies standard Tversky to background (when included) and focal-modulated Tversky to foreground. Clamping prevents numerical instability. Reduction handling is complete.

Static analysis flags line 125 for a long exception message (TRY003). Consider extracting to a constant or custom exception if this pattern is reused.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 2f4657e and e63e36e.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (128-197)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (6)
tests/losses/test_unified_focal_loss.py (2)

22-23: LGTM: Clear test constants.

Module-level logit constants are well-named and appropriate for testing high-confidence predictions.


99-110: LGTM: Well-documented parameterized test.

Docstring and implementation are clear. Tolerance levels are appropriate.

monai/losses/unified_focal_loss.py (4)

19-19: LGTM: Import supports composition pattern.

FocalLoss import enables the refactored AsymmetricUnifiedFocalLoss to reuse existing focal loss implementation.


34-58: LGTM: Consistent API with include_background parameter.

Addition of include_background aligns with MONAI loss conventions. Default True preserves backward compatibility.


128-175: LGTM: Clean composition of focal components.

Refactoring to compose FocalLoss and AsymmetricFocalTverskyLoss eliminates code duplication and ensures consistent parameter handling. Docstrings clearly distinguish sigmoid vs. softmax use cases.


177-197: LGTM: Forward pass correctly combines loss components.

The focal loss operates on logits while the Tversky component operates on probabilities after explicit activation. Each component independently handles to_onehot_y and include_background, ensuring correct behavior. Weighted combination is straightforward.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

157-157: Remove Chinese comment.

-        self.use_softmax = use_softmax  # 儲存參數
+        self.use_softmax = use_softmax

This was previously flagged but remains in the code.

🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

118-125: Reduction logic is correct.

Standard reduction pattern implemented properly.

For consistency with MONAI style, consider extracting the long error message to a constant or shortening it (static analysis hint TRY003):

-        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+        raise ValueError(f"Unsupported reduction: {self.reduction}")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between e63e36e and 1d196dc.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)

19-19: LGTM.

Import required for the new composition-based implementation.


34-59: LGTM.

The include_background parameter is properly documented and maintains backward compatibility with True as default.


74-81: Background exclusion logic is correct.

The warning for single-channel predictions and slicing logic are appropriate.

However, add stacklevel=2 to the warning at line 76 for proper caller identification:

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

Based on static analysis hints.


159-173: Asymmetric gamma application is intentional and correct.

The Unified Focal Loss design intentionally exploits gamma asymmetry to enable simultaneous suppression and enhancement effects in its component losses. In FocalLoss, gamma down-weights easy-to-classify pixels, while in Focal Tversky Loss, gamma enhances rather than suppresses easy examples. Gamma controls weights for difficult-to-predict samples; distribution-based corrections apply sample-by-sample while region-based corrections apply class-by-class during macro-averaging. This composition pattern correctly implements the unified focal loss framework.


175-192: Forward implementation is correct.

The loss properly:

  1. Computes focal loss on logits
  2. Converts logits to probabilities for Tversky component via softmax or sigmoid
  3. Combines losses with configurable weighting

Test coverage includes both sigmoid and softmax activation paths with appropriate input dimensions.
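
Those three steps can be summarized in a small sketch (focal_loss and tversky_loss stand in for the composed sub-modules; this is illustrative, not the exact forward body):

import torch

def unified_forward(y_pred, y_true, focal_loss, tversky_loss, weight=0.5, use_softmax=False):
    focal = focal_loss(y_pred, y_true)                                            # FocalLoss consumes raw logits
    probs = torch.softmax(y_pred, dim=1) if use_softmax else torch.sigmoid(y_pred)
    tversky = tversky_loss(probs, y_true)                                          # Tversky component consumes probabilities
    return weight * focal + (1 - weight) * tversky                                 # configurable weighted combination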

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)

93-106: Add docstring.

Per coding guidelines, add a docstring describing the test purpose.

🔎 Proposed fix
     def test_with_cuda(self):
+        """Validate CUDA compatibility of AsymmetricUnifiedFocalLoss."""
         if not torch.cuda.is_available():

Based on coding guidelines.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1d196dc and f7cad77.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (2)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
  • LossReduction (253-264)
tests/losses/test_unified_focal_loss.py (3)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (229-298)
tests/test_utils.py (1)
  • assert_allclose (119-159)
monai/networks/nets/quicknat.py (1)
  • is_cuda (433-437)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

tests/losses/test_unified_focal_loss.py

83-83: Unused method argument: expected_val

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: build-docs
🔇 Additional comments (7)
tests/losses/test_unified_focal_loss.py (2)

21-66: Test case definitions are well-structured.

The test cases cover binary logits, 2-channel binary, and multi-class scenarios with appropriate shapes and parameter combinations. The use of 10.0/-10.0 logits ensures near-perfect probabilities for validation.


71-80: LGTM.

Tolerance of 1e-3 is appropriate given that logits of ±10.0 don't yield exact probabilities of 0.0/1.0.

monai/losses/unified_focal_loss.py (5)

34-60: LGTM.

The use_softmax parameter is properly integrated with clear documentation.


91-129: Loss calculations are correct.

The background dice and foreground focal-tversky computations align with the paper's formulation. The use of 1/gamma exponent for foreground classes properly implements the focal modulation.


196-226: Focal loss implementation is correct.

The asymmetric weighting (background focal, foreground standard CE) with delta balancing correctly addresses class imbalance.
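
A sketch of that weighting, assuming y_prob and y_true are (B, C, ...) probabilities and one-hot labels; delta and gamma defaults mirror the constructor values shown elsewhere in this PR:

import torch

def asymmetric_focal_terms(y_prob: torch.Tensor, y_true: torch.Tensor,
                           delta: float = 0.7, gamma: float = 0.5, eps: float = 1e-7) -> torch.Tensor:
    ce = -y_true * torch.log(torch.clamp(y_prob, min=eps))                # per-channel cross entropy
    back = (1 - delta) * torch.pow(1 - y_prob[:, :1], gamma) * ce[:, :1]  # background: focally suppressed
    fore = delta * ce[:, 1:]                                              # foreground: plain CE, delta-weighted
    return torch.cat([back, fore], dim=1)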


237-279: Composition pattern is well-executed.

Creating internal loss instances with shared parameters ensures consistency and avoids duplication in the forward pass.


281-298: Forward logic is sound.

The shape validation correctly handles edge cases (binary logits, to_onehot_y), and the weighted combination properly unifies focal and tversky losses.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

240-240: num_classes parameter is unused.

The num_classes parameter is stored at line 260 but never referenced. Either remove it or use it.

🔎 Proposed fix to remove unused parameter
     def __init__(
         self,
         to_onehot_y: bool = False,
-        num_classes: int = 2,
         weight: float = 0.5,
         gamma: float = 0.5,
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
     ):
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes: number of classes. Defaults to 2.
             weight: weight factor to balance between Focal Loss and Tversky Loss.

And remove self.num_classes = num_classes at line 260.

Also applies to: 260-260

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

83-83: Add stacklevel=2 to warning.

Per static analysis, add stacklevel=2 so the warning points to the caller.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

188-188: Add stacklevel=2 to warning.

Per static analysis, add stacklevel=2.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
🧹 Nitpick comments (6)
monai/losses/unified_focal_loss.py (6)

122-129: Unreachable fallback return.

The final return torch.mean(all_losses) at line 129 is unreachable for valid LossReduction values. Consider raising an error for invalid reductions or removing the redundant return.

🔎 Proposed fix
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(all_losses)
         if self.reduction == LossReduction.SUM.value:
             return torch.sum(all_losses)
         if self.reduction == LossReduction.NONE.value:
             return all_losses
-
-        return torch.mean(all_losses)
+        raise ValueError(f"Unsupported reduction: {self.reduction}")

158-159: Incomplete docstring for reduction parameter.

The reduction parameter docstring is missing its description.

-            reduction: {``"none"``, ``"mean"``, ``"sum"``}
-            use_softmax: whether to use softmax to transform logits. Defaults to False.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+            use_softmax: whether to use softmax to transform logits. Defaults to False.

175-202: Consider extracting shared preprocessing logic.

Lines 175-202 duplicate the preprocessing from AsymmetricFocalTverskyLoss (lines 70-99): single-channel handling, one-hot conversion, shape validation, and probability conversion. Extract to a shared helper to reduce duplication.
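
One hypothetical shape such a helper could take (names and exact behavior are illustrative; one_hot is MONAI's existing utility):

import warnings
import torch
from monai.networks.utils import one_hot

def prepare_inputs(y_pred: torch.Tensor, y_true: torch.Tensor, to_onehot_y: bool, use_softmax: bool):
    """Shared preprocessing: one-hot conversion, single-channel expansion, activation, shape check."""
    n_pred_ch = y_pred.shape[1]
    if to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)
    if n_pred_ch == 1:                                     # expand binary logits to two channels
        prob_fg = torch.sigmoid(y_pred)
        y_pred = torch.cat([1 - prob_fg, prob_fg], dim=1)
        y_true = torch.cat([1 - y_true, y_true], dim=1)
    else:
        y_pred = torch.softmax(y_pred, dim=1) if use_softmax else torch.sigmoid(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
    return y_pred, y_true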


293-294: Sub-losses preprocess inputs independently, causing duplicate work.

Both self.asy_focal_loss and self.asy_focal_tversky_loss independently apply sigmoid/softmax, one-hot encoding, and clamping to the same inputs. For performance, consider preprocessing once in this forward method and passing processed tensors to sub-losses configured to skip preprocessing.


62-68: Docstring should document return value and exceptions.

Per coding guidelines, docstrings should describe return value and raised exceptions.

         """
         Args:
             y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
             y_true: ground truth labels. Shape should match y_pred.
+
+        Returns:
+            torch.Tensor: Computed loss. Shape depends on reduction setting.
+
+        Raises:
+            ValueError: If y_true and y_pred shapes don't match after preprocessing.
         """

168-173: Docstring should document return value and exceptions.

Same as AsymmetricFocalTverskyLoss.forward - add Returns and Raises sections.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between f7cad77 and 4166faa.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)

288-291: Shape validation is correct but duplicated.

The shape check here is a reasonable early guard, though sub-losses will validate again after their own preprocessing. Acceptable as-is for fail-fast behavior.


296-298: LGTM!

Clean delegation to sub-losses with proper weighted combination.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

237-279: Remove unused num_classes parameter or document for backward compatibility.

The num_classes parameter is stored at line 260 but never used. It's not passed to internal loss instances and not referenced in the forward method. Either remove it or add a comment explaining why it's retained.

♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (2)

186-191: Add stacklevel=2 to warning.

The warnings.warn call should specify stacklevel=2 so the warning points to the caller's code.

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.


81-86: Add stacklevel=2 to warning.

The warnings.warn call should specify stacklevel=2 so the warning points to the caller's code.

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.

tests/losses/test_unified_focal_loss.py (2)

94-107: Add docstring.

Per coding guidelines, add a docstring describing that this test validates CUDA compatibility of the loss.

🔎 Proposed fix
     def test_with_cuda(self):
+        """Verify CUDA compatibility by running loss on GPU tensors when available."""
         if not torch.cuda.is_available():

83-87: Remove unused parameter and add docstring.

The expected_val parameter is unused. Remove it from the signature and update the test case accordingly. Also add a docstring per coding guidelines.

🔎 Proposed fix
     @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, args):
+        """Verify that wrong predictions yield high loss values."""
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")

Update TEST_CASE_MULTICLASS_WRONG at line 62 to remove the None value:

 TEST_CASE_MULTICLASS_WRONG = [
     {
         "y_pred": torch.tensor(
             [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
         ),
         "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
     },
-    None,
     {"use_softmax": True, "to_onehot_y": True},
 ]

Based on static analysis hint.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 4166faa and 45d9877.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides and are sensible and informative with regard to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of Google-style docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (229-298)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
tests/losses/test_unified_focal_loss.py

84-84: Unused method argument: expected_val

(ARG002)

monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.12)

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

81-86: Add stacklevel=2 to warning (still open from past reviews).

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.


186-191: Add stacklevel=2 to warning (still open from past reviews).

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)

110-129: Loss computation is correct but consider documenting asymmetric treatment.

Background uses standard Dice loss while foreground uses focal modulation. This asymmetry is intentional to prioritize foreground classes, but could benefit from an inline comment for future maintainers.

🔎 Optional documentation enhancement
         # Calculate losses separately for each class
-        # Background: Standard Dice Loss
+        # Background: Standard Dice Loss (no focal modulation to preserve sensitivity)
         back_dice = 1 - dice_class[:, 0]
 
-        # Foreground: Focal Tversky Loss
+        # Foreground: Focal Tversky Loss (focal modulation to down-weight easy examples)
         fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)

175-182: Optional: Extract duplicated single-channel handling to helper function.

The same single-channel auto-conversion logic appears in both AsymmetricFocalTverskyLoss (lines 69-78) and AsymmetricFocalLoss. Consider extracting to a shared helper if more losses adopt this pattern.


266-279: Consider exposing separate gamma parameters for the two loss components.

AsymmetricFocalLoss defaults to gamma=2.0 while AsymmetricFocalTverskyLoss defaults to gamma=0.75, but AsymmetricUnifiedFocalLoss forces both to use the same gamma value. This prevents users from independently tuning focal modulation for distribution-based (CE) vs region-based (Dice) objectives.

Not blocking, but consider adding gamma_focal and gamma_tversky parameters in a future revision if users request finer control.
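
If that direction is ever taken, a minimal sketch could look like the following. Note that gamma_focal and gamma_tversky are hypothetical parameter names (not part of the current API), and the component classes are assumed to keep the gamma/delta signatures discussed in this PR:

import torch
from monai.losses import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss


class UnifiedFocalWithSplitGammas(torch.nn.Module):
    """Hypothetical wrapper: each component keeps its own focal exponent."""

    def __init__(self, weight: float = 0.5, gamma_focal: float = 2.0, gamma_tversky: float = 0.75, delta: float = 0.7):
        super().__init__()
        self.weight = weight
        # gamma_focal / gamma_tversky are illustrative names, not existing MONAI arguments.
        self.focal = AsymmetricFocalLoss(gamma=gamma_focal, delta=delta)
        self.tversky = AsymmetricFocalTverskyLoss(gamma=gamma_tversky, delta=delta)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Same weighted combination as the unified loss, but with independently tuned exponents.
        return self.weight * self.tversky(y_pred, y_true) + (1 - self.weight) * self.focal(y_pred, y_true)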

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 45d9877 and 05dac9e.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)

69-78: Verify single-channel sigmoid conversion produces valid probability distribution.

The auto-handling converts single-channel logits to two-channel probabilities via torch.cat([1 - y_pred, y_pred], dim=1) after sigmoid. This assumes y_pred after sigmoid is the foreground probability. Confirm this matches user expectations and aligns with the rest of MONAI's binary segmentation conventions.
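
A quick standalone sanity check (plain PyTorch, independent of MONAI) confirms the expansion yields a valid two-class distribution when the single channel holds foreground logits:

import torch

logits = torch.randn(2, 1, 4, 4)        # (B, 1, H, W) foreground logits
fg = torch.sigmoid(logits)              # foreground probability
probs = torch.cat([1 - fg, fg], dim=1)  # (B, 2, H, W): [background, foreground]
assert torch.allclose(probs.sum(dim=1), torch.ones(2, 4, 4))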


91-99: LGTM - probability conversion logic is correct.

The is_already_prob flag prevents double conversion for single-channel inputs, and clamping protects against numerical instability.


288-291: Shape validation logic is sound.

The check for binary logits case (y_pred.shape[1] == 1 and not self.use_softmax) correctly allows shape mismatch when appropriate.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

240-240: Remove unused num_classes parameter.

The num_classes parameter is stored but never used in the implementation. It's not passed to internal loss instances and doesn't affect behavior. Either use it to validate inputs or remove it from the interface.

♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)

83-83: Add stacklevel=2 to warning for proper caller attribution.

The warning call should specify stacklevel=2 to point to the caller's code rather than this line.

🔎 Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

188-188: Add stacklevel=2 to warning for proper caller attribution.

The warning call should specify stacklevel=2 to point to the caller's code.

🔎 Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

293-296: Handle or reject reduction=NONE to prevent shape mismatch.

When reduction=NONE, AsymmetricFocalLoss returns shape (B, H, W, [D]) (per-pixel) while AsymmetricFocalTverskyLoss returns shape (B, C) (per-class). Line 296's addition will fail. Either document and reject NONE reduction with a runtime check, or ensure both losses return compatible shapes.

🔎 Proposed fix - reject NONE reduction
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
                     Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
             y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
         """
+        if self.reduction == LossReduction.NONE.value:
+            raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from component losses.")
+
         if y_pred.shape != y_true.shape:
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

51-53: Clarify sigmoid behavior for multi-channel inputs.

The docstring states sigmoid is used "for binary/multi-label" when use_softmax=False, but doesn't clarify that sigmoid is applied independently to each channel in multi-channel cases. This differs from the binary-only behavior, where a background channel is constructed from the single foreground channel. Consider adding: "For multi-channel inputs, sigmoid is applied per-channel independently (multi-label)."
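
A small standalone illustration of the difference (arbitrary random logits):

import torch

logits = torch.randn(1, 3, 2, 2)            # (B, C, H, W) with C=3
per_channel = torch.sigmoid(logits)         # multi-label: channels treated independently
exclusive = torch.softmax(logits, dim=1)    # mutually exclusive classes

print(per_channel.sum(dim=1))  # generally not 1 per pixel
print(exclusive.sum(dim=1))    # approximately 1 per pixel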


159-159: Enhance use_softmax documentation with usage guidance.

Similar to AsymmetricFocalTverskyLoss, the docstring should clarify when to use softmax (mutually exclusive classes) vs. sigmoid (multi-label/overlapping classes). Consider adding: "Use True for mutually exclusive multi-class segmentation, False for binary or multi-label scenarios."

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 05dac9e and cbed38d.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

175-200: Same redundant probability conversion issue as AsymmetricFocalTverskyLoss.

Lines 175-182 duplicate the auto-handle logic from AsymmetricFocalTverskyLoss with the same flaw: single-channel always uses sigmoid regardless of use_softmax setting.

Apply the same fix as suggested for AsymmetricFocalTverskyLoss to respect use_softmax after expanding to 2-channel.

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)

213-226: Reduction logic duplicates default fallback.

Lines 220-226 explicitly handle MEAN/SUM/NONE, then line 226 returns torch.mean(total_loss) as fallback. This fallback is unreachable if LossReduction enum is exhaustive.

🔎 Simplification

Remove redundant fallback or add a warning if an unknown reduction is encountered:

         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(total_loss)
         if self.reduction == LossReduction.SUM.value:
             return torch.sum(total_loss)
         if self.reduction == LossReduction.NONE.value:
             return total_loss
-        return torch.mean(total_loss)
+        raise ValueError(f"Unsupported reduction: {self.reduction}")

Same applies to AsymmetricFocalTverskyLoss lines 122-129.


288-291: Shape validation allows mismatch only for binary logits, but one-hot conversion happens downstream.

Lines 288-291 permit shape mismatch if is_binary_logits (C=1 with sigmoid) or if to_onehot_y=True. However, the internal losses perform one-hot conversion independently. If y_true has mismatched shape and to_onehot_y=False, the internal losses will raise ValueError at their shape checks (lines 89, 194).

This validation is redundant; the internal losses already enforce shape compatibility.

🔎 Simplification

Remove this check and let internal losses handle validation:

-        if y_pred.shape != y_true.shape:
-            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
-            if not self.to_onehot_y and not is_binary_logits:
-                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-

Or add a comment explaining why this pre-validation is needed.


89-89: Long exception messages flagged by static analysis.

Lines 89, 194, and 291 embed long f-string messages directly in ValueError. Ruff (TRY003) suggests defining exception classes or message constants for long messages.

For consistency with MONAI conventions, verify if other loss modules use inline messages or constants. If this pattern is acceptable project-wide, ignore the hint.

Also applies to: 194-194, 291-291

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between cbed38d and b08de65.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

89-89: Avoid specifying long messages outside the exception class

(TRY003)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)

266-279: Gamma parameters have opposite semantics in AsymmetricFocalLoss vs AsymmetricFocalTverskyLoss.

AsymmetricFocalLoss uses gamma directly: torch.pow(1 - y_pred, gamma), while AsymmetricFocalTverskyLoss uses its reciprocal: torch.pow(1 - dice_class, 1/gamma). Per the paper, Focal Tversky's optimal gamma=4/3 enhances loss (contrary to Focal loss which suppresses). Passing the same gamma=0.5 to both produces mismatched behaviors and may not match the paper's unified formulation intent.


114-115: The focal modulation formula is correct—it properly focuses on hard examples, not the reverse.

With gamma = 0.75 (default), the exponent 1/gamma = 1.333 > 1, which makes hard examples (low Dice values) contribute more to the loss than easy examples, not less. When you raise numbers to a power greater than 1, small values (easy examples where 1 - dice is small) decrease more than large values (hard examples where 1 - dice is large), so easy examples are down-weighted relative to hard examples. This is standard focal behavior and matches the docstring: "focal exponent value to down-weight easy foreground examples." The Unified Focal Loss specifies γ < 1 increases focusing on harder examples, and MONAI's reparameterization using 1/gamma as the exponent achieves this—gamma = 0.75 yields exponent 1.333, which focuses on hard examples correctly.

Likely an incorrect or invalid review comment.
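
A quick numeric check of that reparameterization, as a standalone sketch using the 1/gamma exponent described above:

import torch

gamma = 0.75                                   # default; exponent becomes 1/gamma (roughly 1.33)
one_minus_dice = torch.tensor([0.1, 0.9])      # easy example vs. hard example
modulated = torch.pow(one_minus_dice, 1 / gamma)
print(modulated)                               # roughly [0.046, 0.869]
print(modulated / one_minus_dice)              # the easy example is down-weighted far more than the hard one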

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

176-183: Duplicate issue: single-channel incompatible with use_softmax=True.

Same issue as AsymmetricFocalTverskyLoss. Apply the same fix to handle single-channel inputs with use_softmax=True.


69-78: Single-channel input incompatible with use_softmax=True.

Line 70's condition and not self.use_softmax skips auto-expansion when use_softmax=True. A 1-channel tensor passed to torch.softmax at line 95 produces valid probabilities but remains 1-channel, causing shape mismatches downstream. Either document that single-channel requires use_softmax=False, or expand single-channel logits to 2-class before softmax.

🔎 Proposed fix
-        if y_pred.shape[1] == 1 and not self.use_softmax:
+        if y_pred.shape[1] == 1:
+            if self.use_softmax:
+                # Expand to 2-class logits for softmax: [logit] -> [-logit, logit]
+                y_pred = torch.cat([-y_pred, y_pred], dim=1)
+                is_already_prob = False
+            else:
-            y_pred = torch.sigmoid(y_pred)
-            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            is_already_prob = True
+                y_pred = torch.sigmoid(y_pred)
+                y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+                is_already_prob = True
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

90-90: Consider extracting exception message to a constant.

Ruff flags the inline message. Extract to a module-level constant if this pattern recurs.

Based on static analysis hint.


195-195: Consider extracting exception message to a constant.

Same static analysis hint as line 90.

Based on static analysis hint.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between b08de65 and 7a100d9.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

90-90: Avoid specifying long messages outside the exception class

(TRY003)


195-195: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (6)
monai/losses/unified_focal_loss.py (6)

92-100: Probability conversion logic is correct.

The conditional activation and epsilon clamping properly handle both softmax and sigmoid cases while preventing redundant conversions.


111-130: Per-class loss computation handles edge cases correctly.

The positive exponent 1/self.gamma at line 116 avoids infinity when dice approaches 1.0, addressing the numerical stability concern from past reviews. The (B, C) output shape with reduction=NONE is intentional for the wrapper's spatial averaging.


197-203: Probability conversion and clamping are correct.

Consistent with AsymmetricFocalTverskyLoss implementation.


207-225: Asymmetric focal loss computation is mathematically sound.

Background focal modulation and foreground cross-entropy are correctly weighted and concatenated. The per-pixel (B, C, H, W, [D]) output with reduction=NONE properly complements AsymmetricFocalTverskyLoss.


292-300: Shape alignment correctly resolves reduction=NONE mismatch.

Lines 295-297 average the per-pixel focal loss spatially to produce (B, C), matching the Tversky loss shape. This resolves the shape mismatch concern from past reviews. The weighted combination at line 300 now operates on compatible tensors.
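
Shape-wise, the alignment amounts to the following standalone sketch (dummy tensors standing in for the two reduction="none" outputs):

import torch

focal_per_pixel = torch.rand(2, 3, 8, 8)   # (B, C, H, W) per-pixel focal loss
tversky_per_class = torch.rand(2, 3)       # (B, C) per-class Tversky loss

spatial_dims = tuple(range(2, focal_per_pixel.ndim))
focal_aligned = focal_per_pixel.mean(dim=spatial_dims)   # -> (B, C)

weight = 0.5
combined = weight * tversky_per_class + (1 - weight) * focal_aligned
print(combined.shape)   # torch.Size([2, 3])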


303-311: Final reduction logic is clean and correct.

All three reduction modes are handled explicitly with a fallback to MEAN. The implementation properly applies reduction to the combined (B, C) loss tensor.

Copy link
Member

@ericspod ericspod left a comment


Hi @ytl0623 thanks for this contribution! Please look at the comments from Coderabbit, I think some of the logic does need to be double checked. I have gone over the code briefly and mentioned a few things, but once you've looked at the Coderabbit reviews I can revisit things.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

296-308: Legacy preprocessing conflicts with component loss auto-handling.

Lines 296-308 apply one-hot transformations before passing to component losses, but those losses already handle single-channel expansion and to_onehot_y internally (lines 69-87, 176-192). This causes double transformation.

Example: single-channel y_pred is one-hotted at line 297, then component loss tries to sigmoid it at line 71, treating probabilities as logits.

Proposed fix

Remove preprocessing and rely on component losses:

-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Around line 246-256: The docstring for the UnifiedFocalLoss constructor/method
(the block documenting to_onehot_y, num_classes, weight, gamma, delta,
reduction, use_softmax) is missing a "Raises" section; update that docstring to
include documented exceptions (e.g., ValueError when input/target shape or
num_classes mismatch, ValueError for invalid weight/gamma/delta ranges,
TypeError for wrong tensor types) and be specific about the conditions that
trigger each exception so callers know what is validated.
- Around line 287-291: The second, unconditional raise after checking shapes is
unreachable and redundant; remove the duplicate raise so that only the
conditional raise using y_true/y_pred shapes remains (keep the existing
is_binary_logits and self.to_onehot_y checks and the single ValueError that
reports the mismatched shapes). Ensure you delete the extra `raise
ValueError(f"ground truth has different shape ({y_true.shape}) from input
({y_pred.shape})")` that follows the conditional so only the intended
conditional path raises.
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 7a100d9 and 2066793.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (4)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/losses/focal_loss.py (1)
  • forward (119-201)
monai/losses/dice.py (2)
  • forward (131-229)
  • forward (249-256)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

90-90: Avoid specifying long messages outside the exception class

(TRY003)


195-195: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

176-183: Single-channel logic forces sigmoid regardless of use_softmax.

Same issue as AsymmetricFocalTverskyLoss. When use_softmax=True with single-channel input, sigmoid is applied and subsequent softmax is bypassed.

Proposed fix
-        if y_pred.shape[1] == 1 and not self.use_softmax:
+        if y_pred.shape[1] == 1:
+            y_pred = torch.cat([-y_pred, y_pred], dim=1)
-            y_pred = torch.sigmoid(y_pred)
-            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            is_already_prob = True
+            is_already_prob = False
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
         else:
             is_already_prob = False

Likely an incorrect or invalid review comment.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @tests/losses/test_unified_focal_loss.py:
- Around line 79-83: The test parameter `expected_val` in test_wrong_prediction
is unused; rename it to `_` (or remove it) in the test signature to avoid the
unused-variable warning. Update the method definition for test_wrong_prediction
that consumes parameters from TEST_CASE_MULTICLASS_WRONG and keep the rest of
the test body (instantiating AsymmetricUnifiedFocalLoss and asserting the loss)
unchanged so only the parameter name is adjusted.
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

110-129: Consider simplifying reduction logic with elif chain.

The reduction handling is correct but uses multiple independent if statements. Using elif would make the mutually exclusive nature clearer and prevent unnecessary condition checks.

♻️ Proposed refactor
 # Apply reduction
 if self.reduction == LossReduction.MEAN.value:
     return torch.mean(all_losses)
-if self.reduction == LossReduction.SUM.value:
+elif self.reduction == LossReduction.SUM.value:
     return torch.sum(all_losses)
-if self.reduction == LossReduction.NONE.value:
+elif self.reduction == LossReduction.NONE.value:
     return all_losses
-
-return torch.mean(all_losses)
+else:
+    return torch.mean(all_losses)

89-89: Descriptive error messages are beneficial here.

Ruff flags detailed exception messages as a style issue (TRY003). However, in a library context, descriptive error messages significantly improve debugging and user experience. The current messages clearly explain what went wrong and why, which is more valuable than brevity.

Also applies to: 194-194, 304-310, 317-319

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 2066793 and ba15ce3.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
🧬 Code graph analysis (2)
monai/losses/unified_focal_loss.py (3)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/losses/focal_loss.py (1)
  • forward (119-201)
monai/networks/utils.py (1)
  • one_hot (170-220)
tests/losses/test_unified_focal_loss.py (2)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (227-343)
tests/test_utils.py (1)
  • assert_allclose (119-159)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

89-89: Avoid specifying long messages outside the exception class

(TRY003)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


304-307: Avoid specifying long messages outside the exception class

(TRY003)


310-310: Avoid specifying long messages outside the exception class

(TRY003)


317-319: Avoid specifying long messages outside the exception class

(TRY003)

tests/losses/test_unified_focal_loss.py

80-80: Unused method argument: expected_val

(ARG002)

🔇 Additional comments (11)
tests/losses/test_unified_focal_loss.py (3)

22-67: Test cases are well-structured and comprehensive.

The new test constants cover binary (logits and 2-channel) and multi-class scenarios effectively. The use of extreme logit values (±10) ensures clear separation for perfect prediction testing.


85-88: Error message assertion updated correctly.

The regex now matches the updated error message format from the implementation.


90-101: CUDA test updated appropriately.

The explicit skip when CUDA is unavailable is clearer than decorator-based approaches. The binary logits test case with matching predictions should produce low loss values.

monai/losses/unified_focal_loss.py (8)

34-61: Parameter addition aligns with PR objectives.

The use_softmax parameter enables multi-class support while maintaining backward compatibility with default False for binary/multi-label cases. Docstring is comprehensive.


69-77: Single-channel expansion logic is correct.

The automatic expansion of single-channel predictions to 2-class probabilities with is_already_prob tracking prevents double conversion downstream. Ground truth alignment with one-hot conversion ensures shape consistency.


143-166: Parameter addition is consistent across loss classes.

The use_softmax parameter follows the same pattern as AsymmetricFocalTverskyLoss, maintaining API consistency.


175-182: Single-channel expansion conditional is appropriate.

The check and not self.use_softmax prevents expansion when softmax is requested, which is correct since softmax on a single channel would be meaningless for multi-class scenarios. For use_softmax=True, the user must provide multi-channel logits.


217-224: Verify reduction behavior consistency.

The MEAN reduction computes torch.mean(torch.sum(all_ce, dim=1)), which sums over classes then averages over batch. This differs from AsymmetricFocalTverskyLoss which applies torch.mean(all_losses) (averaging over both batch and classes). Confirm whether this asymmetry is intentional.

For AsymmetricFocalTverskyLoss:

  • MEAN: torch.mean(all_losses) → averages over (B, C)
  • SUM: torch.sum(all_losses) → sums over (B, C)
  • NONE: returns all_losses with shape (B, C)

For AsymmetricFocalLoss:

  • MEAN: torch.mean(torch.sum(all_ce, dim=1)) → sums over C, then averages over B
  • SUM: torch.sum(all_ce) → sums over (B, C, spatial)
  • NONE: returns all_ce with shape (B, C, spatial)

If the intent is for both losses to produce comparable shapes for combining in AsymmetricUnifiedFocalLoss, this inconsistency could lead to unexpected behavior.
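
In practice the two MEAN conventions differ only by a factor of the class count, as this standalone sketch with random stand-in values shows:

import torch

x = torch.rand(4, 3)                               # stand-in per-class losses, shape (B, C)
mean_all = torch.mean(x)                           # Tversky-style MEAN: average over B and C
mean_after_class_sum = torch.mean(x.sum(dim=1))    # Focal-style MEAN: sum over C, then average over B
print(mean_all, mean_after_class_sum)              # the second is C times the first (here 3x)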


272-286: Internal loss composition architecture is sound.

Creating internal loss instances with reduction=NONE enables flexible combination logic while reusing existing implementations. This aligns with the PR objective to avoid reimplementing core focal logic.


299-322: Shape validation and one-hot conversion logic are correct.

The validation gracefully handles binary logits and index-based ground truth, providing clear error messages when shape mismatches cannot be resolved. The index bounds check (line 316) prevents out-of-bounds errors during one-hot conversion. The internal losses will skip redundant conversion since y_true will already have the correct shape.


323-343: Loss combination and reduction logic are correct.

The spatial averaging of asy_focal_loss output (line 328) properly aligns both loss components to shape (B, C) before weighted combination. The reduction handling is clean and explicit.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Around line 69-77: AsymmetricFocalTverskyLoss unconditionally expands
single-channel predictions to two channels, causing a shape mismatch when
AsymmetricFocalLoss only expands if not self.use_softmax; change the expansion
guard in AsymmetricFocalTverskyLoss (the block handling y_pred.shape[1] == 1) to
only perform sigmoid/concatenation and one_hot conversion when not
self.use_softmax (mirror the condition used in AsymmetricFocalLoss), and ensure
is_already_prob is set consistently; also add a unit test for
AsymmetricUnifiedFocalLoss with use_softmax=True and single-channel inputs that
verifies both sub-losses produce compatible shapes and the combined loss
computes without shape errors.
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

196-202: Probability conversion duplicates logic from AsymmetricFocalTverskyLoss.

Lines 196-202 are identical to lines 91-99 in AsymmetricFocalTverskyLoss. Consider extracting this into a shared helper method to reduce duplication.

♻️ Refactor: Extract shared probability conversion logic

Add a helper method in the module:

def _convert_to_probabilities(y_pred: torch.Tensor, use_softmax: bool, epsilon: float) -> torch.Tensor:
    """Convert logits to probabilities using softmax or sigmoid."""
    if use_softmax:
        y_pred = torch.softmax(y_pred, dim=1)
    else:
        y_pred = torch.sigmoid(y_pred)
    return torch.clamp(y_pred, epsilon, 1.0 - epsilon)

Then replace lines 196-202 and 91-99 with a call to this helper.


89-89: Optional: Static analysis flags long exception messages.

Ruff (TRY003) suggests avoiding long messages constructed in exception calls. Consider extracting these into constants or shortening them, though this is a low-priority style improvement.

Based on static analysis hints.

♻️ Example refactor for line 304-307
+SHAPE_MISMATCH_ERROR = (
+    "Ground truth has different shape ({y_true_shape}) from input ({y_pred_shape}), "
+    "and this mismatch cannot be resolved by `to_onehot_y` or binary expansion."
+)
+
 if not is_binary_logits and not is_target_needs_onehot:
-    raise ValueError(
-        f"Ground truth has different shape ({y_true.shape}) from input ({y_pred.shape}), "
-        "and this mismatch cannot be resolved by `to_onehot_y` or binary expansion."
-    )
+    raise ValueError(SHAPE_MISMATCH_ERROR.format(y_true_shape=y_true.shape, y_pred_shape=y_pred.shape))

Also applies to: 194-194, 304-307, 310-310, 317-319

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between ba15ce3 and 8368ef2.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (227-343)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

89-89: Avoid specifying long messages outside the exception class

(TRY003)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


304-307: Avoid specifying long messages outside the exception class

(TRY003)


310-310: Avoid specifying long messages outside the exception class

(TRY003)


317-319: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (8)
tests/losses/test_unified_focal_loss.py (3)

22-67: Test case definitions are well-structured.

The four test case constants cover binary (logits and 2-channel) and multiclass scenarios with appropriate expected values and parameters.


72-77: Appropriate tolerance for logits-to-probability conversion.

The relaxed tolerance of atol=1e-3, rtol=1e-3 correctly accounts for numerical precision in sigmoid/softmax operations.


90-101: CUDA test now properly guards against unavailable devices.

The early return when CUDA is unavailable prevents test failures on CPU-only systems.

monai/losses/unified_focal_loss.py (5)

91-99: Probability conversion logic correctly guards against double-conversion.

The is_already_prob flag prevents applying softmax/sigmoid when probabilities were already computed during single-channel expansion.


110-129: Asymmetric treatment of background vs. foreground is intentional.

Background uses standard Dice loss (line 112), foreground applies focal Tversky with exponent 1/gamma (line 115). Per-class concatenation and reduction correctly produce shape (B, C) for NONE reduction.
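
For reference, the asymmetric per-class computation reduces to something like this standalone sketch (dice_class is a stand-in tensor with channel 0 as background):

import torch

dice_class = torch.rand(2, 3)   # stand-in per-class Dice/Tversky scores, channel 0 = background
gamma = 0.75

back_loss = 1 - dice_class[:, :1]                         # background: plain Dice term
fore_loss = torch.pow(1 - dice_class[:, 1:], 1 / gamma)   # foreground: focal Tversky modulation
all_losses = torch.cat([back_loss, fore_loss], dim=1)     # (B, C), as returned with reduction="none"
print(all_losses.shape)   # torch.Size([2, 3])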


204-224: Cross-entropy formulation correctly implements asymmetric focal loss.

Background applies focal modulation with gamma (line 207), foreground uses standard cross-entropy (line 210). The reduction logic at lines 217-224 correctly handles different reduction modes, summing over classes before batch reduction for MEAN/default cases.


272-285: Internal loss composition with NONE reduction enables proper alignment.

Creating asy_focal_loss and asy_focal_tversky_loss with reduction=NONE allows AsymmetricUnifiedFocalLoss.forward to align shapes and combine losses before final reduction. The instances are correctly exposed as public members.


326-331: Spatial alignment correctly handles shape mismatch between loss components.

Averaging asy_focal_loss over spatial dims (lines 327-328) produces shape (B, C) to match asy_focal_tversky_loss, enabling weighted combination at line 331.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Around line 122-129: The reduction for AsymmetricFocalTverskyLoss is
inconsistent with AsymmetricFocalLoss; change the final reduction logic to first
sum over the class dimension per sample and then apply batch reduction.
Concretely, compute per_sample_losses = torch.sum(all_losses, dim=1) and then if
reduction == LossReduction.MEAN.value return torch.mean(per_sample_losses), elif
reduction == LossReduction.SUM.value return torch.sum(per_sample_losses), elif
reduction == LossReduction.NONE.value return all_losses, else return
torch.mean(per_sample_losses); update the reduction branch in
AsymmetricFocalTverskyLoss (function/method where all_losses is used) so
semantics match AsymmetricFocalLoss.
- Around line 69-77: The single-channel expansion in AsymmetricUnifiedFocalLoss
currently hardcodes a 2-class expansion (y_pred -> [1 - y, y]) which breaks when
self.num_classes > 2; update the branch that handles y_pred.shape[1] == 1 to
expand y_pred to self.num_classes channels (not always 2) so its shape matches
y_true (which the parent one-hot encodes with self.num_classes). Concretely: in
the y_pred single-channel handling, compute probabilities (sigmoid or softmax as
appropriate) and then build/concatenate a tensor with self.num_classes channels
(e.g., one-vs-rest probabilities or repeated/zero-filled extra channels that sum
to 1) so y_pred.shape[1] == self.num_classes; also ensure any sub-loss
constructors or calls in AsymmetricUnifiedFocalLoss receive self.num_classes
instead of defaulting to 2.
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

175-183: Code duplication: extract single-channel expansion to helper.

Lines 175-183 duplicate logic from AsymmetricFocalTverskyLoss (lines 69-77). Extract to a shared helper method to reduce duplication and improve maintainability.

♻️ Example helper method

Add a module-level helper:

def _expand_single_channel(y_pred: torch.Tensor, y_true: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, bool]:
    """Expand single-channel predictions to 2-channel probabilities."""
    if y_pred.shape[1] == 1:
        y_pred = torch.sigmoid(y_pred)
        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
        if y_true.shape[1] == 1:
            y_true = one_hot(y_true, num_classes=2)
        return y_pred, y_true, True
    return y_pred, y_true, False

Then use it in both classes:

y_pred, y_true, is_already_prob = _expand_single_channel(y_pred, y_true)

89-89: Optional: Extract exception messages to avoid TRY003 warnings.

Static analysis flags long exception messages. Consider extracting to module-level constants for consistency with Python best practices.

Example refactor
# Module-level constants
_ERR_SHAPE_MISMATCH = "ground truth has different shape ({}) from input ({})"
_ERR_SHAPE_MISMATCH_UNIFIED = "Ground truth has different shape ({}) from input ({}), and this mismatch cannot be resolved by `to_onehot_y` or binary expansion."

# Then use in exceptions:
raise ValueError(_ERR_SHAPE_MISMATCH.format(y_true.shape, y_pred.shape))

Also applies to: 194-194, 304-310, 317-319

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 8368ef2 and 2b72d1b.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

89-89: Avoid specifying long messages outside the exception class

(TRY003)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


304-307: Avoid specifying long messages outside the exception class

(TRY003)


310-310: Avoid specifying long messages outside the exception class

(TRY003)


317-319: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)

91-96: Probability conversion logic is correct.

Properly guards against double conversion using the is_already_prob flag.


217-224: Reduction logic is correct for focal loss.

Properly sums over classes before reducing over batch dimension for MEAN mode.


272-285: Internal loss instances configured correctly.

Using NONE reduction allows proper shape alignment in the forward method.


326-343: Loss combination correct but depends on consistent sub-loss outputs.

Shape alignment and weighted combination logic are correct. However, this depends on both sub-losses returning (B, C) tensors when using NONE reduction. The inconsistent reduction logic flagged earlier must be fixed.


299-322: Missing validation for num_classes and prediction channel mismatch.

The validation at lines 316-319 only checks ground truth indices when to_onehot_y=True, but doesn't validate that num_classes matches y_pred.shape[1]. If these mismatch (e.g., num_classes=3 but 2 prediction channels), no error is raised—the loss silently operates on the actual channel count. Consider adding:

if y_pred.shape[1] > 1 and y_pred.shape[1] != self.num_classes:
    raise ValueError(
        f"Prediction channels ({y_pred.shape[1]}) must match num_classes ({self.num_classes})."
    )

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Line 339: The combination line uses focal_aligned and asy_focal_tversky_loss
but there is a shape mismatch caused by inconsistent reduction behavior: ensure
both sub-losses produced with reduction=NONE produce the same shape (either
(B,C) or (B,C,H,W,...)). Decide the intended shape for AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss (refer to AsymmetricFocalLoss and the variable
asy_focal_loss/asy_focal_tversky_loss) and make reductions consistent: if both
should be (B,C) remove the spatial averaging step that reduces spatial dims (the
code producing focal_aligned after avg); if both should be full spatial maps,
change the sub-loss reduction to keep spatial dims and compute combined_loss
per-pixel before any final spatial reduction. Update the reduction=NONE handling
and any .mean()/sum() calls so focal_aligned and asy_focal_tversky_loss share
identical shapes before computing combined_loss.
- Around line 331-336: The code assumes asy_focal_loss has spatial dims and
blindly averages dims 2+ into focal_aligned, but AsymmetricFocalLoss
(reduction=NONE) can return (B, C) so that averaging corrupts results; update
the logic in unified_focal_loss (around where asy_focal_loss and focal_aligned
are computed) to check the tensor rank before reducing—if asy_focal_loss.ndim >
2 compute spatial_dims = list(range(2, asy_focal_loss.ndim)) and mean over them,
otherwise leave asy_focal_loss unchanged (or explicitly squeeze/validate shape)
so focal_aligned is correct for both per-element and already-averaged outputs
from AsymmetricFocalLoss.
- Around line 181-188: The current detection of already-probabilistic
predictions only checks the single-channel case (y_pred.shape[1] == 1) and
misses multi-channel inputs that are already probabilities; update the logic in
UnifiedFocalLoss where y_pred/y_true are normalized to detect multi-channel
probabilistic inputs too by: if y_pred has >1 channel, test whether values lie
in [0,1] and channel-wise sums are ~1 (e.g., torch.all((y_pred>=0)&(y_pred<=1))
and torch.allclose(y_pred.sum(dim=1), torch.ones_like(...), atol=1e-4)), set
is_already_prob=True and skip applying softmax/sigmoid; also mirror handling for
y_true (skip one_hot conversion if it already appears one-hot/multi-channel);
use the same approach/pattern as in AsymmetricFocalTverskyLoss to locate the
change around the y_pred/y_true preprocessing block and the is_already_prob
flag.
- Around line 88-89: y_true is being one-hot encoded with n_pred_ch (the y_pred
channel count) which can conflict with the configured self.num_classes; update
the logic in the block that currently checks y_true.shape[1] and calls one_hot
to instead validate that n_pred_ch == self.num_classes and raise a clear error
if they differ, and if they match use self.num_classes when calling one_hot
(i.e., replace n_pred_ch with self.num_classes and add a pre-check comparing
n_pred_ch and self.num_classes so you don’t silently encode with the wrong class
count).
- Around line 77-78: The code expands y_pred to 2 channels when y_pred.shape[1]
== 1 but still calls one_hot(y_true, num_classes=self.num_classes), which can
mismatch if self.num_classes != 2; update the logic in the unified focal loss
preprocessing (around the y_pred/y_true handling) to ensure one_hot uses
num_classes=2 when y_pred was expanded from 1 channel (i.e., detect
y_pred.shape[1]==1 and call one_hot(y_true, num_classes=2)), or alternatively
validate/assert self.num_classes==2 before expanding y_pred and raising a clear
error if not.
- Around line 307-315: The current shape validation incorrectly treats C==1 with
use_softmax the same as binary logits; first add an explicit check for the
invalid combination (y_pred.shape[1] == 1 and self.use_softmax) and raise a
ValueError explaining that softmax with a single channel is invalid, then
keep/clarify the existing mismatch resolution: compute is_binary_logits as
(y_pred.shape[1] == 1 and not self.use_softmax) and is_target_needs_onehot as
(self.to_onehot_y and y_true.shape[1] == 1), and only allow shape mismatch if
is_binary_logits or is_target_needs_onehot; if neither holds, raise the existing
mismatch ValueError. Ensure you reference and update the validation block around
is_binary_logits/is_target_needs_onehot to implement these checks.
- Line 186: The one_hot call uses self.num_classes even when predictions were
expanded from a single channel; update the logic around y_true = one_hot(...) to
pass num_classes=2 when y_pred.shape[1] == 1 (i.e., predictions were expanded to
two channels) and otherwise use self.num_classes so the label one-hot size
matches the expanded y_pred; locate the usage of y_pred, y_true and the one_hot
call in unified_focal_loss.py and adjust the argument accordingly.
- Around line 196-197: The code checks y_true.shape[1] against n_pred_ch but
calls one_hot with num_classes=n_pred_ch, causing inconsistency with the class
count stored on the object; change the one_hot call to use self.num_classes (and
ensure you still validate against n_pred_ch) so replace the num_classes argument
with self.num_classes while keeping the if-condition comparing y_true.shape[1]
to n_pred_ch, updating the one_hot invocation near the y_true handling in
unified_focal_loss.py.
- Around line 72-80: The code incorrectly sets is_already_prob only when
y_pred.shape[1] == 1, causing multi-channel probability inputs passed to
UnifiedFocalLoss.forward to be treated as logits and re-activated; modify
forward (and its signature) to accept an explicit flag (e.g., is_prob_input:
bool = False) or detect probabilities by checking y_pred value range (all values
in [0,1] per channel) and set is_already_prob accordingly, then guard the
activation/softmax/sigmoid block (the lines operating on y_pred and the
subsequent softmax/sigmoid calls) to skip re-activation when is_already_prob is
True, and ensure one_hot conversion for y_true (the one_hot(y_true,
num_classes=self.num_classes) call) still happens when needed.
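
A minimal sketch of the value-range heuristic described in the last item; the helper name _looks_like_probabilities and the tolerance are illustrative, not part of the PR:

import torch

def _looks_like_probabilities(y_pred: torch.Tensor, atol: float = 1e-4) -> bool:
    # Sketch: treat multi-channel input as probabilities if all values lie in [0, 1]
    # and the channel-wise sums are approximately 1.
    in_unit_range = bool(torch.all((y_pred >= 0) & (y_pred <= 1)))
    channel_sums = y_pred.sum(dim=1)
    sums_to_one = torch.allclose(channel_sums, torch.ones_like(channel_sums), atol=atol)
    return in_unit_range and sums_to_one

# Usage inside forward (sketch): skip softmax/sigmoid when probabilities are detected.
# if y_pred.shape[1] > 1 and _looks_like_probabilities(y_pred):
#     is_already_prob = True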
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)

124-132: Simplify reduction logic: default case duplicates MEAN.

Lines 132 repeats the MEAN reduction. Use else instead of the explicit default case for cleaner code.

♻️ Simplification
     if self.reduction == LossReduction.MEAN.value:
         return torch.mean(torch.sum(all_losses, dim=1))
     elif self.reduction == LossReduction.SUM.value:
         return torch.sum(all_losses)
     elif self.reduction == LossReduction.NONE.value:
         return all_losses
     else:
-        return torch.mean(torch.sum(all_losses, dim=1))
+        raise ValueError(f"Unsupported reduction: {self.reduction}")

Or simply:

     if self.reduction == LossReduction.MEAN.value:
         return torch.mean(torch.sum(all_losses, dim=1))
-    elif self.reduction == LossReduction.SUM.value:
+    if self.reduction == LossReduction.SUM.value:
         return torch.sum(all_losses)
-    elif self.reduction == LossReduction.NONE.value:
-        return all_losses
-    else:
-        return torch.mean(torch.sum(all_losses, dim=1))
+    if self.reduction == LossReduction.NONE.value:
+        return all_losses
+    return torch.mean(torch.sum(all_losses, dim=1))

223-230: Simplify reduction logic.

Same as AsymmetricFocalTverskyLoss: default case duplicates MEAN.

♻️ Simplification
     if self.reduction == LossReduction.MEAN.value:
         return torch.mean(torch.sum(all_ce, dim=1))
     if self.reduction == LossReduction.SUM.value:
         return torch.sum(all_ce)
     if self.reduction == LossReduction.NONE.value:
         return all_ce
-    
-    return torch.mean(torch.sum(all_ce, dim=1))
+    return torch.mean(torch.sum(all_ce, dim=1))  # default to MEAN

Or raise an error for unsupported reduction values.


278-293: Redundant conversions: each sub-loss re-processes inputs independently.

Both asy_focal_loss and asy_focal_tversky_loss receive the same y_pred and y_true in forward (lines 331-332). Each will independently perform single-channel expansion, one-hot conversion, and probability transformations. This redundancy wastes computation and risks inconsistency if their logic diverges.

Consider performing input preparation once in AsymmetricUnifiedFocalLoss.forward and passing normalized tensors to sub-losses, or add a flag to skip preprocessing in sub-losses.
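
A rough sketch of the first option, assuming the sub-losses gain an is_prob_input-style switch so they skip their own activation and one-hot handling; the weighted combination shown is illustrative and one_hot refers to monai.networks.utils.one_hot:

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # Sketch: normalize inputs once, then feed both sub-losses the same tensors.
    if not self.is_prob_input:
        y_pred = torch.softmax(y_pred, dim=1) if self.use_softmax else torch.sigmoid(y_pred)
    if self.to_onehot_y and y_true.shape[1] == 1:
        y_true = one_hot(y_true, num_classes=self.num_classes)
    focal = self.asy_focal_loss(y_pred, y_true)            # sub-losses constructed with is_prob_input=True
    tversky = self.asy_focal_tversky_loss(y_pred, y_true)
    return self.weight * focal + (1 - self.weight) * tversky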

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 2b72d1b and cf53eae.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/losses/focal_loss.py (1)
  • forward (119-201)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

92-92: Avoid specifying long messages outside the exception class

(TRY003)


200-200: Avoid specifying long messages outside the exception class

(TRY003)


312-315: Avoid specifying long messages outside the exception class

(TRY003)


318-318: Avoid specifying long messages outside the exception class

(TRY003)


325-327: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)

320-330: Good index validation before one-hot conversion.

This correctly validates that ground truth indices don't exceed num_classes before converting to one-hot. This prevents silent errors.


342-350: Reduction logic is correct and consistent.

The final reduction properly handles MEAN, SUM, and NONE cases with a sensible default.


118-118: Implementation is correct—foreground classes intentionally receive focal weighting.

The [:, 1:] slicing applies focal gamma to all foreground classes (multi-class safe), which matches the Unified Focal Loss paper's asymmetric design: background gets standard loss suppression while foreground retains focal suppression to avoid under-weighting rare classes. Tests and docstring ("prioritizes foreground classes") confirm this is intended behavior.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Around line 210-227: In forward of unified_focal_loss.py there’s a
double-activation when handling single-channel inputs: the code applies
torch.sigmoid to y_pred then expands to 2 channels and later code (which assumes
channels==2) may apply sigmoid/softmax again. Fix by recording the original
channel count (e.g., was_single_channel = (y_pred.shape[1] == 1)) before any
conversion, only apply torch.sigmoid when was_single_channel and
self.is_prob_input is False, and then set or propagate a local flag (or set
self.is_prob_input equivalent for the remainder of the method) so downstream
activation/softmax logic skips re-applying activation to the
already-probabilistic expanded y_pred; also ensure y_true one-hot conversion
uses the same was_single_channel check to match the expanded y_pred.
- Around line 75-92: The single-channel branch applies torch.sigmoid to y_pred
when not self.is_prob_input and then expands to two channels, but later the
multi-channel path still checks not self.is_prob_input and may apply
sigmoid/softmax again; fix this by marking the prediction as already converted
to probabilities after the single-channel conversion (e.g., set a local flag
like converted_to_prob=True or temporarily set a variable used by the later
check), and update the later conversion condition (the block guarded by not
self.is_prob_input and y_pred.shape[1] != 1) to skip conversion if the
prediction was already converted; references: y_pred, self.is_prob_input, the
single-channel sigmoid application, the expansion via torch.cat, and the later
multi-channel conversion check.
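
A sketch of the single-conversion flow described above, written as a standalone helper so the flag handling is explicit; the helper name and signature are illustrative:

import torch
from monai.networks.utils import one_hot

def _activate_once(y_pred, y_true, use_softmax, is_prob_input, to_onehot_y, num_classes):
    # Sketch: activate at most once, then expand single-channel predictions
    # without re-applying sigmoid/softmax further down.
    already_prob = is_prob_input
    if y_pred.shape[1] == 1:
        if not already_prob:
            y_pred = torch.sigmoid(y_pred)
            already_prob = True  # mark as converted so the block below is skipped
        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
        num_classes = 2  # expanded binary case
    if not already_prob:
        y_pred = torch.softmax(y_pred, dim=1) if use_softmax else torch.sigmoid(y_pred)
    if to_onehot_y and y_true.shape[1] == 1:
        y_true = one_hot(y_true, num_classes=num_classes)
    return y_pred, y_true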
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

68-158: Consider extracting shared preprocessing logic.

AsymmetricFocalTverskyLoss.forward and AsymmetricFocalLoss.forward have ~70 lines of nearly identical input preprocessing (single-channel handling, one-hot conversion, probability conversion). This violates DRY and increases maintenance burden.

Extract a shared helper method or mixin for input preprocessing.

Also applies to: 203-284


151-158: Default case duplicates MEAN reduction.

Both sub-losses have identical logic in if/elif/else chains where the final else and MEAN branch do the same thing. Minor redundancy but harmless.

Could simplify to:

if self.reduction == LossReduction.SUM.value:
    return torch.sum(...)
if self.reduction == LossReduction.NONE.value:
    return ...
return torch.mean(...)  # default to MEAN

Also applies to: 276-284

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between cf53eae and bf80ed6.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (287-434)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

79-81: Avoid specifying long messages outside the exception class

(TRY003)


97-100: Avoid specifying long messages outside the exception class

(TRY003)


112-114: Avoid specifying long messages outside the exception class

(TRY003)


118-118: Avoid specifying long messages outside the exception class

(TRY003)


214-216: Avoid specifying long messages outside the exception class

(TRY003)


232-235: Avoid specifying long messages outside the exception class

(TRY003)


247-249: Avoid specifying long messages outside the exception class

(TRY003)


253-253: Avoid specifying long messages outside the exception class

(TRY003)


330-332: Avoid specifying long messages outside the exception class

(TRY003)


378-378: Avoid specifying long messages outside the exception class

(TRY003)


382-385: Avoid specifying long messages outside the exception class

(TRY003)


394-397: Avoid specifying long messages outside the exception class

(TRY003)


402-404: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (pytype)
🔇 Additional comments (22)
tests/losses/test_unified_focal_loss.py (17)

22-30: LGTM: Binary logits test case is well-defined.

Clear test case with strong logits (±10) ensuring near-perfect predictions after sigmoid.


32-40: LGTM: Binary 2-channel test case properly exercises softmax path.

Tests the multi-channel binary case with use_softmax=True and to_onehot_y=True.


42-58: LGTM: Multiclass perfect prediction case is comprehensive.

3-class setup with correct one-hot ground truth indices (0, 1, 2) and matching logits.


60-70: LGTM: Wrong prediction case validates high loss behavior.

All predictions favor class 1 while GT is class 0 — good adversarial test.


72-80: LGTM: Probability input test case validates is_prob_input flag.

Probabilities sum to 1 across channels, testing bypass of sigmoid/softmax conversion.


82-90: LGTM: Shape mismatch case for error validation.

Tests 4×4 vs 2×2 spatial mismatch to trigger shape validation error.


95-108: LGTM: Perfect prediction test handles both "small" checks and exact values.

Logic correctly branches on expected_val == "small" vs numeric expectations.


109-114: LGTM: Wrong prediction test validates high loss.

Threshold of 1.0 is reasonable for completely wrong predictions.


116-120: LGTM: Shape mismatch error test updated with correct regex.

Matches the error message from the implementation.


122-126: LGTM: Validates use_softmax with single channel error.

Tests the specific edge case where softmax cannot be applied to C=1 input.


128-135: LGTM: num_classes mismatch validation test.

Tests 3-channel input with num_classes=4 to trigger channel/class mismatch error.


137-141: LGTM: Single channel binary constraint test.

Validates that single-channel input enforces num_classes=2.


143-162: LGTM: Reduction modes comprehensively tested.

Tests MEAN (scalar), SUM (scalar, larger than mean), and NONE (shape [B, C]).


164-176: LGTM: CUDA test properly guarded with skipTest.

Validates both device placement and reasonable loss magnitude.


178-188: LGTM: Gradient flow test ensures backpropagation works.

Checks grad is not None and all values are finite.


190-200: LGTM: Batch processing test validates multi-sample batches.

Random inputs with batch_size=4, checks scalar output and finiteness.


202-216: LGTM: Multiclass one-hot conversion test is well-structured.

Uses 3-class predictions with integer class indices as ground truth.

monai/losses/unified_focal_loss.py (5)

34-66: LGTM: Constructor parameters well-documented.

New num_classes, use_softmax, and is_prob_input parameters align with PR objectives.


139-158: LGTM: Per-class loss computation and reduction logic.

Background Dice loss and focal Tversky foreground losses correctly concatenated and reduced.


287-360: LGTM: Unified loss initialization with proper sub-loss configuration.

Sub-losses correctly initialized with reduction=LossReduction.NONE for per-element aggregation.


376-404: LGTM: Input validation is comprehensive.

Covers dimensionality, softmax+single-channel, shape mismatches, and class index bounds.


406-433: LGTM: Loss combination and reduction logic is correct.

Spatial averaging aligns focal loss to (B, C), weighted combination, and proper reduction handling.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (4)

45-58: Docstring missing Raises section.

Per coding guidelines, docstrings should describe raised exceptions. This method can raise ValueError for invalid num_classes or channel mismatches.


147-154: Unreachable default case.

Line 154 is unreachable since LossReduction(reduction) in __init__ validates the input. Consider raising ValueError for unexpected values or removing the fallback.

Suggested fix
         elif self.reduction == LossReduction.NONE.value:
             return all_losses
         else:
-            return torch.mean(torch.sum(all_losses, dim=1))
+            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

247-253: Same redundant condition as in AsymmetricFocalTverskyLoss.

The y_pred.shape[1] != 1 check is always true after the single-channel expansion at lines 207-221.


396-400: Redundant preprocessing in sub-losses.

Both asy_focal_loss and asy_focal_tversky_loss independently perform sigmoid/softmax conversion, single-channel expansion, and one-hot encoding on the same inputs. Consider preprocessing once in AsymmetricUnifiedFocalLoss.forward and passing is_prob_input=True with pre-expanded tensors to sub-losses.

tests/losses/test_unified_focal_loss.py (1)

90-210: Consider adding 3D (volumetric) input test.

All test cases use 4D tensors (2D spatial). Since the implementation supports 5D (3D spatial), consider adding a test case with shape like (B, C, D, H, W).
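
A possible smoke test for the volumetric path, following the style of the existing cases; shapes and constructor arguments are illustrative and assume the default reduction returns a scalar:

import unittest
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

class TestVolumetricInput(unittest.TestCase):
    def test_5d_input(self):
        # Sketch: (B, C, D, H, W) input with integer labels converted via to_onehot_y.
        y_pred = torch.randn(2, 2, 4, 4, 4)
        y_true = torch.randint(0, 2, (2, 1, 4, 4, 4)).float()
        loss = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=2)
        result = loss(y_pred, y_true)
        self.assertEqual(result.ndim, 0)
        self.assertTrue(torch.isfinite(result))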

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between bf80ed6 and f0772ed.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (279-424)
monai/losses/unified_focal_loss.py (3)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/losses/focal_loss.py (1)
  • forward (119-201)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

79-79: Avoid specifying long messages outside the exception class

(TRY003)


95-98: Avoid specifying long messages outside the exception class

(TRY003)


110-110: Avoid specifying long messages outside the exception class

(TRY003)


114-114: Avoid specifying long messages outside the exception class

(TRY003)


210-210: Avoid specifying long messages outside the exception class

(TRY003)


226-229: Avoid specifying long messages outside the exception class

(TRY003)


241-241: Avoid specifying long messages outside the exception class

(TRY003)


245-245: Avoid specifying long messages outside the exception class

(TRY003)


322-322: Avoid specifying long messages outside the exception class

(TRY003)


368-368: Avoid specifying long messages outside the exception class

(TRY003)


372-375: Avoid specifying long messages outside the exception class

(TRY003)


384-387: Avoid specifying long messages outside the exception class

(TRY003)


392-394: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (9)
monai/losses/unified_focal_loss.py (2)

116-122: Redundant condition and potential semantic issue.

After the single-channel block (lines 76-90), y_pred.shape[1] is always ≥ 2. The y_pred.shape[1] != 1 check is always true here.

More importantly: with use_softmax=False on multi-channel input, sigmoid is applied per-channel independently. This produces values that don't sum to 1, which may be unintended for exclusive multi-class segmentation.
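
A quick illustration of the point, with made-up logits:

import torch

logits = torch.tensor([[0.5, 1.0, -0.2]])        # one pixel, three channels
print(torch.sigmoid(logits).sum(dim=1))          # ~1.80: per-channel sigmoid is not a distribution
print(torch.softmax(logits, dim=1).sum(dim=1))   # 1.00: softmax sums to one across channels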


402-423: LGTM.

The spatial averaging alignment and weighted combination logic is correct. The reduction handling follows the expected pattern.

tests/losses/test_unified_focal_loss.py (7)

22-87: Well-structured test cases covering key scenarios.

Good coverage of binary/multi-class, logits/probabilities, and error conditions.


92-112: LGTM.

Parameterized tests with sensible thresholds for perfect vs. wrong predictions.


114-139: Good validation coverage.

Error message patterns correctly match the implementation's ValueError messages.


141-160: LGTM.

Reduction mode tests correctly verify output shapes and the expected relationship between SUM and MEAN.


162-174: LGTM.

Proper CUDA availability check with skipTest fallback.


176-186: Essential gradient flow test.

Properly verifies gradients are computed and finite.


188-210: LGTM.

Batch processing and one-hot conversion tests provide good coverage for realistic usage patterns.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

305-327: Docstring incomplete: missing Raises section and public members.

Document the ValueError exceptions and the exposed public members asy_focal_loss and asy_focal_tversky_loss.

🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Line 275: Add a guard that enforces gamma > 0 in the constructors for both the
UnifiedFocalLoss-related class that uses self.gamma (referenced where back_ce =
torch.pow(1 - y_pred[:, 0], self.gamma) * ...) and the
AsymmetricFocalTverskyLoss class (which uses 1/gamma). In each __init__ (look
for the class constructors named around
UnifiedFocalLoss/AsymmetricFocalTverskyLoss), raise a ValueError with a clear
message if gamma <= 0 so callers cannot pass zero or negative gamma; keep the
rest of logic unchanged.
- Around line 34-58: Update the __init__ docstring for UnifiedFocalLoss to
include a "Raises" section that lists the exceptions the class/forward may
raise: e.g., ValueError for invalid num_classes in __init__, ValueError for
mismatched shapes between y_pred and y in forward, ValueError when
probabilities/logits expectations (is_prob_input/use_softmax) are violated, and
any TypeError if inputs are not tensors; reference the class name
UnifiedFocalLoss and its forward method so readers can locate where these errors
originate.
- Around line 176-197: Add a "Raises" section to the UnifiedFocalLoss __init__
docstring that documents the ValueError exceptions thrown by the forward method:
specify that forward raises ValueError if use_softmax and is_prob_input are both
True (invalid input mode), if the prediction and target tensors have mismatched
shapes or channel count that doesn't match num_classes, and if invalid
configuration values are provided (e.g., unsupported reduction string or
inconsistent hyperparameters); keep each condition short and reference the
forward method and the parameters (use_softmax, is_prob_input, num_classes,
reduction) so callers know when these ValueErrors can occur.
- Line 348: The initializer sets self.weight without validating its range; add a
check in the UnifiedFocalLoss (or its __init__) where self.weight is assigned to
ensure weight is a numeric value between 0 and 1 inclusive, and raise a
ValueError with a clear message like "weight must be between 0 and 1" if it is
outside that range or not a number; keep the assignment to self.weight only
after the validation passes.
- Line 62: The delta assignment in UnifiedFocalLoss currently lacks validation;
add a check in the UnifiedFocalLoss initializer after receiving delta to ensure
0 <= delta <= 1 and raise a ValueError with a clear message if out of range;
follow the same pattern used in AsymmetricFocalLoss (the delta validation
at/around line 201) so both classes validate delta consistently.
- Line 148: The expression computing fore_dice uses 1 / self.gamma and will
raise if gamma == 0; add validation in the UnifiedFocalLoss.__init__ to require
gamma > 0 (e.g., if self.gamma <= 0: raise ValueError("gamma must be > 0")) so
instantiation fails early with a clear error; update any constructor docstring
or parameter validation calls in __init__ accordingly and ensure subsequent code
(e.g., the fore_dice computation in forward using self.gamma) can assume gamma
is positive.
- Around line 285-294: The MEAN reduction in AsymmetricFocalLoss is inconsistent
with AsymmetricFocalTverskyLoss: change the MEAN branch that currently returns
torch.mean(all_ce) to match the per-sample semantics used in
AsymmetricFocalTverskyLoss (i.e., reduce class/spatial dims per sample then
average over the batch) or explicitly document and test the difference;
specifically update the code handling self.reduction == LossReduction.MEAN.value
to compute per-sample sums of all_ce and then torch.mean over those sums
(matching AsymmetricFocalTverskyLoss which uses torch.mean(torch.sum(all_losses,
dim=1))), and add/adjust unit tests and docstring to reflect the chosen
semantics.
🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (5)

79-119: Binary expansion and one-hot logic is correct, but duplicated.

Lines 79-119 are nearly identical to AsymmetricFocalLoss lines 218-258. Consider extracting to a shared helper method to reduce duplication.


218-258: Identical code duplicated from AsymmetricFocalTverskyLoss.

Lines 218-258 duplicate lines 79-119 from AsymmetricFocalTverskyLoss. Extract to a shared helper method (e.g., _prepare_inputs) to eliminate duplication.


353-370: Exposes sub-losses as public API without documentation.

self.asy_focal_loss and self.asy_focal_tversky_loss are now part of the public API but undocumented. Add a docstring note or consider prefixing with _ if they should remain internal.


410-415: Redundant validation: sub-losses already validate indices.

Lines 410-415 pre-validate to_onehot_y indices, but both sub-losses perform the same check in their forward methods. Remove this duplication.

♻️ Proposed removal
-        # Pre-process y_true if needed (will be done inside sub-losses, but validate here)
-        if self.to_onehot_y and y_true.shape[1] == 1:
-            # Check indices validity before conversion
-            if torch.max(y_true) >= self.num_classes:
-                raise ValueError(
-                    f"Ground truth contains class indices >= {self.num_classes}, which exceeds num_classes."
-                )
-
         # Get losses from sub-losses

82-84: Static analysis: long exception messages (TRY003).

Ruff flags long messages in exception constructors as a style issue. MONAI may prefer detailed messages for user clarity. Consider extracting to constants if consistency is desired, but not critical.

Also applies to: 101-104, 116-118, 122-122, 221-223, 240-243, 255-257, 261-261, 340-342, 388-388, 392-395, 404-407, 413-415

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between f0772ed and a89727d.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

82-84: Avoid specifying long messages outside the exception class

(TRY003)


101-104: Avoid specifying long messages outside the exception class

(TRY003)


116-118: Avoid specifying long messages outside the exception class

(TRY003)


122-122: Avoid specifying long messages outside the exception class

(TRY003)


221-223: Avoid specifying long messages outside the exception class

(TRY003)


240-243: Avoid specifying long messages outside the exception class

(TRY003)


255-257: Avoid specifying long messages outside the exception class

(TRY003)


261-261: Avoid specifying long messages outside the exception class

(TRY003)


340-342: Avoid specifying long messages outside the exception class

(TRY003)


388-388: Avoid specifying long messages outside the exception class

(TRY003)


392-395: Avoid specifying long messages outside the exception class

(TRY003)


404-407: Avoid specifying long messages outside the exception class

(TRY003)


413-415: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)

75-76: Good practice: tracking probability conversion state.

The local variable is_converted_to_prob prevents redundant activation application.


155-162: Reduction logic correctly implements NONE/MEAN/SUM.

The implementation properly handles per-class losses with appropriate aggregation for each reduction mode.


339-342: Good validation: prevents invalid softmax configuration.

Correctly catches misconfiguration early.


423-433: Correct shape alignment and weighted combination.

Properly reduces focal loss spatial dimensions to match Tversky loss shape before combining. The weighted sum implements the unified loss correctly.


34-163: Add tests for parameter validation edge cases.

Tests lack coverage for: gamma=0 (causes division error in 1 / self.gamma), gamma < 0, delta out of valid range [0, 1], and epsilon ≤ 0. These parameters currently have no bounds validation in __init__.
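
A sketch of what such tests could look like; they assume bounds checks are added to __init__ and would fail against the current code:

import unittest
from monai.losses import AsymmetricUnifiedFocalLoss

class TestParameterValidation(unittest.TestCase):
    def test_non_positive_gamma_rejected(self):
        with self.assertRaises(ValueError):
            AsymmetricUnifiedFocalLoss(gamma=0.0)  # would divide by zero in the Tversky term

    def test_delta_out_of_range_rejected(self):
        with self.assertRaises(ValueError):
            AsymmetricUnifiedFocalLoss(delta=1.5)  # delta is a weight expected in [0, 1]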

Comment on lines 34 to 54
def __init__(
self,
to_onehot_y: bool = False,
num_classes: int = 2,
delta: float = 0.7,
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
is_prob_input: bool = False,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
num_classes: number of classes. Defaults to 2.
delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7.
gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75.
epsilon: a small value to prevent division by zero. Defaults to 1e-7.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
use_softmax: whether to use softmax to transform original logits into probabilities.
If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label).
Defaults to False.
is_prob_input: whether input is already probabilities (vs logits). Defaults to False.
"""
Copy link
Contributor


🛠️ Refactor suggestion | 🟠 Major

Docstring incomplete: missing Raises section.

The docstring should document the exceptions raised in the forward method (ValueError for shape mismatches, invalid num_classes, etc.).

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 34 - 58, Update the __init__
docstring for UnifiedFocalLoss to include a "Raises" section that lists the
exceptions the class/forward may raise: e.g., ValueError for invalid num_classes
in __init__, ValueError for mismatched shapes between y_pred and y in forward,
ValueError when probabilities/logits expectations (is_prob_input/use_softmax)
are violated, and any TypeError if inputs are not tensors; reference the class
name UnifiedFocalLoss and its forward method so readers can locate where these
errors originate.

super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.delta = delta
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Missing validation: delta should be in valid range.

delta is a weight parameter but lacks range validation. Add validation to ensure 0 <= delta <= 1.

✅ Proposed fix
 self.delta = delta
+if not 0 <= self.delta <= 1:
+    raise ValueError(f"delta must be in [0, 1], got {self.delta}")
 self.gamma = gamma

Apply similar validation in AsymmetricFocalLoss (line 201).

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.delta = delta
self.delta = delta
if not 0 <= self.delta <= 1:
raise ValueError(f"delta must be in [0, 1], got {self.delta}")
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 62, The delta assignment in
UnifiedFocalLoss currently lacks validation; add a check in the UnifiedFocalLoss
initializer after receiving delta to ensure 0 <= delta <= 1 and raise a
ValueError with a clear message if out of range; follow the same pattern used in
AsymmetricFocalLoss (the delta validation at/around line 201) so both classes
validate delta consistently.

Comment on lines 176 to 197
def __init__(
self,
to_onehot_y: bool = False,
num_classes: int = 2,
delta: float = 0.7,
gamma: float = 2,
gamma: float = 2.0,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
is_prob_input: bool = False,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
num_classes: number of classes. Defaults to 2.
delta: weight for the foreground classes. Defaults to 0.7.
gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
epsilon: a small value to prevent calculation errors. Defaults to 1e-7.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
use_softmax: whether to use softmax to transform logits. Defaults to False.
is_prob_input: whether input is already probabilities (vs logits). Defaults to False.
"""
Copy link
Contributor


🛠️ Refactor suggestion | 🟠 Major

Docstring incomplete: missing Raises section.

Document the ValueError exceptions raised in forward method.

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 176 - 197, Add a "Raises"
section to the UnifiedFocalLoss __init__ docstring that documents the ValueError
exceptions thrown by the forward method: specify that forward raises ValueError
if use_softmax and is_prob_input are both True (invalid input mode), if the
prediction and target tensors have mismatched shapes or channel count that
doesn't match num_classes, and if invalid configuration values are provided
(e.g., unsupported reduction string or inconsistent hyperparameters); keep each
condition short and reference the forward method and the parameters
(use_softmax, is_prob_input, num_classes, reduction) so callers know when these
ValueErrors can occur.

cross_entropy = -y_true * torch.log(y_pred)

# Background (Channel 0): Focal Loss
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Gamma=0 causes division by zero in Tversky loss.

While this line uses torch.pow multiplicatively (safe with gamma=0), the AsymmetricFocalTverskyLoss uses 1/gamma which will fail. For consistency and to prevent confusion, validate gamma > 0 in both classes.

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 275, Add a guard that enforces
gamma > 0 in the constructors for both the UnifiedFocalLoss-related class that
uses self.gamma (referenced where back_ce = torch.pow(1 - y_pred[:, 0],
self.gamma) * ...) and the AsymmetricFocalTverskyLoss class (which uses
1/gamma). In each __init__ (look for the class constructors named around
UnifiedFocalLoss/AsymmetricFocalTverskyLoss), raise a ValueError with a clear
message if gamma <= 0 so callers cannot pass zero or negative gamma; keep the
rest of logic unchanged.
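
A minimal version of that guard, placed in the constructor alongside the delta check suggested elsewhere in this review; message wording is illustrative:

def __init__(self, gamma: float = 0.75, delta: float = 0.7) -> None:
    # Sketch: validate before use so `1 / self.gamma` in forward cannot divide by zero.
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")
    if not 0.0 <= delta <= 1.0:
        raise ValueError(f"delta must be in [0, 1], got {delta}")
    self.gamma = gamma
    self.delta = delta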

self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.weight = weight
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Missing validation: weight should be in [0, 1].

weight balances two losses but lacks range validation.

✅ Proposed fix
 self.weight = weight
+if not 0 <= self.weight <= 1:
+    raise ValueError(f"weight must be in [0, 1], got {self.weight}")
 self.use_softmax = use_softmax
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.weight = weight
self.weight = weight
if not 0 <= self.weight <= 1:
raise ValueError(f"weight must be in [0, 1], got {self.weight}")
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 348, The initializer sets
self.weight without validating its range; add a check in the UnifiedFocalLoss
(or its __init__) where self.weight is assigned to ensure weight is a numeric
value between 0 and 1 inclusive, and raise a ValueError with a clear message
like "weight must be between 0 and 1" if it is outside that range or not a
number; keep the assignment to self.weight only after the validation passes.

Signed-off-by: ytl0623 <david89062388@gmail.com>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
monai/losses/unified_focal_loss.py (3)

296-298: Incorrect one_hot conversion of predictions - critical bug.

Line 297 converts y_pred logits to one-hot, which is wrong. Predictions should remain as logits/probabilities, only y_true labels should be one-hot encoded.

This corrupts predictions for single-channel input.

🐛 Proposed fix
         if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=self.num_classes)

Note: y_pred conversion should be handled by sub-losses (lines 70-72, 176-178).


300-301: Inverted validation logic - critical bug.

Line 300 raises an error whenever torch.max(y_true) != self.num_classes - 1, so any valid label map whose maximum index falls below num_classes - 1 is rejected.

For valid labels in [0, 1, ..., num_classes-1], the only requirement is that no index reaches num_classes; the maximum need not equal num_classes - 1 exactly.

🐛 Proposed fix
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
+        if torch.max(y_true) >= self.num_classes:
+            raise ValueError(f"Ground truth contains class indices >= {self.num_classes}. Expected indices in [0, {self.num_classes - 1}].")

280-330: Complete forward docstring.

Missing Returns and Raises sections. Document the shape handling and sub-loss combination.

📝 Enhanced docstring
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
                Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
        y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).

    Returns:
        torch.Tensor: Combined weighted loss. Shape depends on reduction:
            - "none": (B, C)
            - "mean" or "sum": scalar

    Raises:
        ValueError: When ground truth shape differs from prediction shape (outside valid cases).
        ValueError: When input shape is not 4D or 5D.
        ValueError: When ground truth contains invalid class indices.
    """
🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Around line 169-174: Update the forward method's docstring in
unified_focal_loss.py to include explicit Returns and Raises sections: under
Returns describe the torch.Tensor loss and its shape behavior for reductions
("none" -> (B, C, spatial_dims...), "mean"/"sum" -> scalar), and under Raises
document a ValueError when y_true shape differs from y_pred; keep the existing
Args description and ensure the docstring remains in the same triple-quoted
block for the forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) method.
- Around line 62-67: The forward method's docstring in unified_focal_loss.py is
missing Returns and Raises sections; update the forward(self, y_pred:
torch.Tensor, y_true: torch.Tensor) -> torch.Tensor docstring to describe the
returned torch.Tensor (including shape behavior for reductions: "none" -> (B,
C), "mean" or "sum" -> scalar) and document that a ValueError is raised when
y_true shape differs from y_pred; place these descriptions under "Returns:" and
"Raises:" in the forward docstring for clarity.
- Around line 287-291: The duplicate ValueError makes the binary-logits
exception unreachable; update the shape-mismatch check in unified_focal_loss
(the block using y_pred, y_true, is_binary_logits and self.to_onehot_y) so it
only raises when neither self.to_onehot_y nor is_binary_logits is true — i.e.,
compute is_binary_logits = (y_pred.shape[1] == 1 and not self.use_softmax) and
then raise a single ValueError only if not (self.to_onehot_y or
is_binary_logits), removing the redundant unconditional raise.
- Line 116: The focal exponent is inverted: replace the use of 1/self.gamma with
self.gamma so that fore_dice is computed as (1 - dice_class[:, 1:])**self.gamma;
update the expression where fore_dice is assigned (variable fore_dice, using
dice_class and self.gamma) to raise (1 - dice_class[:, 1:]) to the power
self.gamma instead of 1/self.gamma.
🧹 Nitpick comments (5)
monai/losses/unified_focal_loss.py (2)

130-130: Remove redundant fallback return.

Line 130 duplicates the MEAN reduction at line 124, making it unreachable.

♻️ Simplify reduction logic
-    if self.reduction == LossReduction.MEAN.value:
-        return torch.mean(all_losses)
-    if self.reduction == LossReduction.SUM.value:
-        return torch.sum(all_losses)
-    if self.reduction == LossReduction.NONE.value:
-        return all_losses
-
-    return torch.mean(all_losses)
+    if self.reduction == LossReduction.MEAN.value:
+        return torch.mean(all_losses)
+    elif self.reduction == LossReduction.SUM.value:
+        return torch.sum(all_losses)
+    else:  # LossReduction.NONE
+        return all_losses

225-225: Remove redundant fallback.

Duplicates line 219, unreachable code.

♻️ Simplify
-    if self.reduction == LossReduction.MEAN.value:
-        return torch.mean(torch.sum(all_ce, dim=1))
-    if self.reduction == LossReduction.SUM.value:
-        return torch.sum(all_ce)
-    if self.reduction == LossReduction.NONE.value:
-        return all_ce
-
-    return torch.mean(torch.sum(all_ce, dim=1))
+    if self.reduction == LossReduction.MEAN.value:
+        return torch.mean(torch.sum(all_ce, dim=1))
+    elif self.reduction == LossReduction.SUM.value:
+        return torch.sum(all_ce)
+    else:  # LossReduction.NONE
+        return all_ce
tests/losses/test_unified_focal_loss.py (3)

83-87: Remove unused parameter.

expected_val is not used in the test body (static analysis flagged ARG002).

♻️ Proposed fix
     @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, _expected_val, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")

94-107: Use unittest.skipTest instead of print+return.

Line 96 prints and returns. Use self.skipTest("CUDA not available") for proper test skipping.

♻️ Proposed fix
     def test_with_cuda(self):
         if not torch.cuda.is_available():
-            print("CUDA not available, skipping test_with_cuda")
-            return
+            self.skipTest("CUDA not available")

         loss = AsymmetricUnifiedFocalLoss(use_softmax=False, to_onehot_y=False)

74-107: Consider additional test coverage.

Current tests cover basic scenarios well. Consider adding:

  • Different reduction modes (reduction="none", reduction="sum")
  • Access to sub-losses (loss.asy_focal_loss, loss.asy_focal_tversky_loss)
  • Edge case: multi-channel y_true with to_onehot_y=False

Not blocking, but would improve robustness.
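
A sketch covering the first two suggestions; the expected shape for reduction="none" follows the (B, C) behaviour described in this review and should be re-checked against the final semantics:

import unittest
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

class TestReductionAndSubLosses(unittest.TestCase):
    def test_none_reduction_and_sub_loss_access(self):
        y_pred = torch.randn(2, 2, 8, 8)
        y_true = torch.randint(0, 2, (2, 1, 8, 8)).float()
        loss = AsymmetricUnifiedFocalLoss(to_onehot_y=True, reduction="none")
        result = loss(y_pred, y_true)
        self.assertEqual(result.shape[0], 2)  # per-sample values retained
        self.assertIsNotNone(loss.asy_focal_loss)           # exposed sub-losses
        self.assertIsNotNone(loss.asy_focal_tversky_loss)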

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between a89727d and 3434aff.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (228-330)
monai/losses/unified_focal_loss.py (3)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/losses/focal_loss.py (1)
  • forward (119-201)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
tests/losses/test_unified_focal_loss.py

84-84: Unused method argument: expected_val

(ARG002)

monai/losses/unified_focal_loss.py

90-90: Avoid specifying long messages outside the exception class

(TRY003)


195-195: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: packaging
  • GitHub Check: build-docs
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)

219-219: Verify reduction semantics for MEAN.

torch.mean(torch.sum(all_ce, dim=1)) sums over the class dimension and then averages the remaining batch and spatial elements, so the result scales with the number of classes. This differs from the typical focal loss reduction, which averages over all elements.

Confirm this per-class sum → batch mean is intentional.
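
A toy comparison of the two conventions (the tensor here just stands in for the per-class cross-entropy map, it is not taken from the loss code):

# Toy illustration: per-class sum then mean vs. mean over all elements.
import torch

all_ce = torch.ones(2, 3, 4, 4)  # assumed shape (B, C, H, W) with B=2, C=3

per_class_sum_then_mean = torch.mean(torch.sum(all_ce, dim=1))  # -> 3.0
mean_over_all_elements = torch.mean(all_ce)                     # -> 1.0

# the first convention is larger by a factor of C (here 3)
print(per_class_sum_then_mean.item(), mean_over_all_elements.item())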


310-318: Verify shape compatibility between sub-losses.

Line 310: asy_focal_loss with reduction=NONE returns (B, C, H, W, ...).
Line 311: asy_focal_tversky_loss with reduction=NONE returns (B, C) per line 128.

Line 315 averages spatial dims of focal_loss to get (B, C), then line 318 combines. Confirm asy_focal_tversky_loss already returns (B, C) without spatial dims.
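
A minimal shape sketch of what the combination assumes (shapes taken from the comment above, not verified against the code; the 0.5 weights are placeholders):

# Assumed shapes: focal term (B, C, H, W), Tversky term (B, C).
import torch

B, C, H, W = 2, 3, 4, 4
focal_term = torch.rand(B, C, H, W)   # AsymmetricFocalLoss with reduction="none"
tversky_term = torch.rand(B, C)       # AsymmetricFocalTverskyLoss with reduction="none"

# averaging the focal term over spatial dims yields (B, C), so the two can be combined
focal_bc = focal_term.mean(dim=tuple(range(2, focal_term.ndim)))
combined = 0.5 * focal_bc + 0.5 * tversky_term
print(focal_bc.shape, combined.shape)  # torch.Size([2, 3]) torch.Size([2, 3])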

Comment on lines 62 to +67
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
        y_true: ground truth labels. Shape should match y_pred.
    """

🛠️ Refactor suggestion | 🟠 Major

Complete the forward method docstring.

Missing Returns and Raises sections. Per coding guidelines, describe return value and raised exceptions.

📝 Enhanced docstring
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
        y_true: ground truth labels. Shape should match y_pred.

    Returns:
        torch.Tensor: Computed loss. Shape depends on reduction:
            - "none": (B, C)
            - "mean" or "sum": scalar

    Raises:
        ValueError: When ground truth shape differs from prediction shape.
    """
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 62 - 67, The forward
method's docstring in unified_focal_loss.py is missing Returns and Raises
sections; update the forward(self, y_pred: torch.Tensor, y_true: torch.Tensor)
-> torch.Tensor docstring to describe the returned torch.Tensor (including shape
behavior for reductions: "none" -> (B, C), "mean" or "sum" -> scalar) and
document that a ValueError is raised when y_true shape differs from y_pred;
place these descriptions under "Returns:" and "Raises:" in the forward docstring
for clarity.

        # Foreground: Focal Tversky Loss
        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
        return loss

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Unified Focal Loss Generalising Dice and Cross Entropy-based Losses Handle Class Imbalanced Medical Image Segmentation Michael Yeung focal Tversky formula

💡 Result:

The Focal Tversky loss is defined as (1 − TI)^γ, where the Tversky index (TI) for a class is
TI = TP / (TP + α·FP + β·FN),
with TP, FP, FN the true positives, false positives and false negatives (often α+β=1; e.g. α=0.3, β=0.7). Thus the Focal Tversky loss:
FTL = (1 − TI)^γ,
with γ controlling focus on hard examples (γ=1 → Tversky loss). [1][2]

Sources:

  • Yeung et al., "Unified Focal loss…" (PMC), Eq.12 and surrounding text. [1]
  • Abraham & Khan, "A novel focal Tversky loss…" (original Focal Tversky proposal). [2]

🏁 Script executed:

# Find and examine the unified_focal_loss.py file
fd unified_focal_loss.py -type f

Repository: Project-MONAI/MONAI

Length of output: 182


🏁 Script executed:

# Read the file to see the actual implementation around line 116
cat -n monai/losses/unified_focal_loss.py | sed -n '100,130p'

Repository: Project-MONAI/MONAI

Length of output: 182


🏁 Script executed:

# Check the full context of the AsymmetricFocalTverskyLoss class to understand gamma usage
cat -n monai/losses/unified_focal_loss.py | sed -n '1,150p'

Repository: Project-MONAI/MONAI

Length of output: 182


Fix focal loss exponent formula.

Line 116 uses torch.pow(1 - dice_class[:, 1:], 1 / self.gamma) but should use gamma directly as the exponent, not its reciprocal. Per Yeung et al.'s "Unified Focal Loss" (Eq. 12), the Focal Tversky loss is defined as (1 − TI)^γ, where γ controls focus on hard examples.
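
A small numerical check of the two exponents (the dice values and gamma below are made up for illustration):

# Illustration only: with gamma < 1 the reciprocal changes the focusing behaviour.
import torch

dice_class = torch.tensor([[0.95, 0.40]])  # assumed per-class Tversky indices, shape (B, C)
gamma = 0.75

reciprocal_exponent = torch.pow(1 - dice_class[:, 1:], 1 / gamma)  # current code, exponent ~1.33
paper_exponent = torch.pow(1 - dice_class[:, 1:], gamma)           # Eq. 12: (1 - TI) ** gamma

print(reciprocal_exponent.item(), paper_exponent.item())  # ~0.51 vs ~0.68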

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 116, The focal exponent is
inverted: replace the use of 1/self.gamma with self.gamma so that fore_dice is
computed as (1 - dice_class[:, 1:])**self.gamma; update the expression where
fore_dice is assigned (variable fore_dice, using dice_class and self.gamma) to
raise (1 - dice_class[:, 1:]) to the power self.gamma instead of 1/self.gamma.

Comment on lines 169 to +173
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: prediction logits or probabilities.
        y_true: ground truth labels.
    """

🛠️ Refactor suggestion | 🟠 Major

Complete forward docstring.

Missing Returns and Raises sections.

📝 Enhanced docstring
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: prediction logits or probabilities.
        y_true: ground truth labels.

    Returns:
        torch.Tensor: Computed loss. Shape depends on reduction:
            - "none": (B, C, spatial_dims...)
            - "mean" or "sum": scalar

    Raises:
        ValueError: When ground truth shape differs from prediction shape.
    """
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 169 - 174, Update the
forward method's docstring in unified_focal_loss.py to include explicit Returns
and Raises sections: under Returns describe the torch.Tensor loss and its shape
behavior for reductions ("none" -> (B, C, spatial_dims...), "mean"/"sum" ->
scalar), and under Raises document a ValueError when y_true shape differs from
y_pred; keep the existing Args description and ensure the docstring remains in
the same triple-quoted block for the forward(self, y_pred: torch.Tensor, y_true:
torch.Tensor) method.

Comment on lines +24 to +28
TEST_CASE_BINARY_LOGITS = [
    {"y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
    0.0,
    {"use_softmax": False, "to_onehot_y": False},
]

⚠️ Potential issue | 🔴 Critical

Test will fail due to bugs in implementation.

Given the critical bugs in AsymmetricUnifiedFocalLoss.forward():

  1. Line 297 incorrectly converts y_pred to one-hot
  2. Line 291 always raises on shape mismatch

This test case will fail. Fix implementation bugs first.

Signed-off-by: ytl0623 <david89062388@gmail.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

297-302: Validation logic error.

Lines 301-302 validate torch.max(y_true) != self.num_classes - 1, but this check occurs AFTER y_true is converted to one-hot on line 299. For one-hot tensors, max value is always 1, not num_classes - 1.

The validation should either:

  1. Move before the one-hot conversion, or
  2. Check the shape instead of max value
🐛 Proposed fix

Option 1: Validate before one-hot conversion

         if y_pred.shape[1] == 1:
+            if torch.max(y_true) > self.num_classes - 1:
+                raise ValueError(f"Ground truth contains class indices > {self.num_classes - 1}")
             y_pred = one_hot(y_pred, num_classes=self.num_classes)
             y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")

Option 2: Validate shape consistency

         if y_pred.shape[1] == 1:
             y_pred = one_hot(y_pred, num_classes=self.num_classes)
             y_true = one_hot(y_true, num_classes=self.num_classes)
 
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
+        if y_true.shape[1] != self.num_classes:
+            raise ValueError(f"Ground truth channels ({y_true.shape[1]}) don't match num_classes ({self.num_classes})")
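
The underlying point is easy to reproduce with toy tensors (this snippet is only an illustration and is not part of the codebase):

# After one-hot conversion the max value is 1 regardless of num_classes,
# so comparing torch.max(y_true) against num_classes - 1 can no longer succeed.
import torch
from monai.networks.utils import one_hot

num_classes = 3
y_true = torch.tensor([[[[0, 1], [2, 1]]]], dtype=torch.float32)  # class indices, shape (B, 1, H, W)

print(torch.max(y_true).item())                              # 2.0 -> passes the intended check
y_true_onehot = one_hot(y_true, num_classes=num_classes)
print(y_true_onehot.shape, torch.max(y_true_onehot).item())  # (1, 3, 2, 2), 1.0 -> check now fails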
🤖 Fix all issues with AI agents
In @monai/losses/unified_focal_loss.py:
- Around lines 266-279: The wrapper AsymmetricUnifiedFocalLoss performs an explicit
one-hot conversion in its forward path while the child losses (asy_focal_loss
and asy_focal_tversky_loss) are also constructed with to_onehot_y=True, causing
double conversion and shape errors. Fix by centralizing the conversion in the
wrapper: keep the wrapper's one-hot conversion and change the initialization of
AsymmetricFocalLoss and AsymmetricFocalTverskyLoss to to_onehot_y=False (so they
assume inputs are already one-hot), or alternatively remove the wrapper
conversion and leave the sub-losses with to_onehot_y=True. The first option is
preferred (do the conversion once in AsymmetricUnifiedFocalLoss and set the
sub-losses' to_onehot_y=False).
- Around lines 288-292: The code currently always raises a ValueError when
y_pred.shape != y_true.shape because of a duplicate raise; update the conditional in
the shape check so that it only raises when neither binary-logits nor
to_onehot_y handling applies: compute is_binary_logits
= (y_pred.shape[1] == 1 and not self.use_softmax), then raise the ValueError only if
not self.to_onehot_y and not is_binary_logits, otherwise allow
execution to continue to the binary-logits or one-hot conversion paths;
reference the y_pred, y_true, is_binary_logits, to_onehot_y, and use_softmax
symbols when locating the change.

In @tests/losses/test_unified_focal_loss.py:
- Around line 83-87: The test method test_wrong_prediction currently accepts an
unused parameter expected_val; update its signature to remove expected_val so it
only takes (self, input_data, args), and ensure the parameterized.expand call
using TEST_CASE_MULTICLASS_WRONG still supplies just the two needed values
(input_data and args) or adjust the test case tuple accordingly; keep the body
using AsymmetricUnifiedFocalLoss and the existing assertion unchanged.
🧹 Nitpick comments (4)
tests/losses/test_unified_focal_loss.py (1)

94-107: Use self.skipTest() instead of early return.

The test should use the standard unittest pattern for skipping tests when CUDA is unavailable. Also consider removing the debug print statement on line 105.

♻️ Proposed fix
     def test_with_cuda(self):
         if not torch.cuda.is_available():
-            print("CUDA not available, skipping test_with_cuda")
-            return
+            self.skipTest("CUDA not available")
 
         loss = AsymmetricUnifiedFocalLoss(use_softmax=False, to_onehot_y=False)
         # Binary logits case on GPU
         i = torch.tensor([[[[10.0, 0], [0, 10.0]]], [[[10.0, 0], [0, 10.0]]]]).cuda()
         j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]).cuda()
 
         output = loss(i, j)
-        print(f"CUDA Output: {output.item()}")
         self.assertTrue(output.is_cuda)
         self.assertLess(output.item(), 1.0)
monai/losses/unified_focal_loss.py (3)

110-129: Redundant return statement.

Line 129 returns torch.mean(all_losses) as a fallback, but all reduction cases are already covered by lines 122-127. This is unreachable code.

♻️ Proposed fix
         # Apply reduction
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(all_losses)
         if self.reduction == LossReduction.SUM.value:
             return torch.sum(all_losses)
         if self.reduction == LossReduction.NONE.value:
             return all_losses
-
-        return torch.mean(all_losses)

Or use if-elif-else for clarity:

         # Apply reduction
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(all_losses)
-        if self.reduction == LossReduction.SUM.value:
+        elif self.reduction == LossReduction.SUM.value:
             return torch.sum(all_losses)
-        if self.reduction == LossReduction.NONE.value:
+        else:  # LossReduction.NONE
             return all_losses
-
-        return torch.mean(all_losses)

204-226: Redundant return statement.

Line 226 is unreachable as all reduction cases are covered. Same issue as in AsymmetricFocalTverskyLoss.

♻️ Proposed fix
         # Apply reduction
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(total_loss)
-        if self.reduction == LossReduction.SUM.value:
+        elif self.reduction == LossReduction.SUM.value:
             return torch.sum(total_loss)
-        if self.reduction == LossReduction.NONE.value:
+        else:  # LossReduction.NONE
             return total_loss
-        return torch.mean(total_loss)

83-83: Add stacklevel=2 to warnings.

The warnings.warn() calls should include stacklevel=2 to show the calling code location in the warning message, making debugging easier.

♻️ Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

Apply to all three warnings.warn calls in the file (lines 83, 188, 307).

Also applies to: 188-188, 307-307

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 3434aff and 377fc47.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
tests/losses/test_unified_focal_loss.py (2)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (229-316)
tests/test_utils.py (1)
  • assert_allclose (119-159)
🪛 Ruff (0.14.10)
tests/losses/test_unified_focal_loss.py

84-84: Unused method argument: expected_val

(ARG002)

monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: build-docs
  • GitHub Check: packaging
🔇 Additional comments (2)
tests/losses/test_unified_focal_loss.py (1)

22-71: LGTM: Test case constants are well-structured.

The four test scenarios cover binary logits, binary 2-channel, multi-class perfect, and multi-class wrong predictions appropriately.

monai/losses/unified_focal_loss.py (1)

69-78: LGTM: Single-channel binary handling is correct.

The automatic expansion of single-channel binary predictions to 2-channel probability distributions is well-implemented.
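
A minimal sketch of that expansion, assuming sigmoid activation on the single foreground channel (not copied from the PR):

# Sketch of single-channel -> two-channel expansion under the sigmoid assumption.
import torch

logits = torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]])  # (B, 1, H, W) binary logits
fg = torch.sigmoid(logits)                                  # foreground probability
probs = torch.cat([1.0 - fg, fg], dim=1)                    # (B, 2, H, W): [background, foreground]

print(probs.shape)             # torch.Size([1, 2, 2, 2])
print(probs.sum(dim=1).max())  # channels sum to 1 at every voxel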

Comment on lines +266 to +279
        self.asy_focal_loss = AsymmetricFocalLoss(
            gamma=self.gamma,
            delta=self.delta,
            use_softmax=self.use_softmax,
            to_onehot_y=to_onehot_y,
            reduction=reduction,
        )
        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
            gamma=self.gamma,
            delta=self.delta,
            use_softmax=self.use_softmax,
            to_onehot_y=to_onehot_y,
            reduction=reduction,
        )

⚠️ Potential issue | 🟠 Major

Potential double one-hot conversion.

Lines 297-299 convert single-channel inputs to one-hot, then lines 305-309 may convert again if to_onehot_y=True. Additionally, the sub-losses on lines 266-279 are initialized with to_onehot_y, so they will also attempt one-hot conversion.

This could lead to redundant conversions or shape mismatches. Consider:

  1. Performing one-hot conversion only in the wrapper (lines 297-309), then passing to_onehot_y=False to sub-losses, or
  2. Letting sub-losses handle all conversions and removing lines 297-309
🔧 Recommended fix

Option 1: Handle all conversions in wrapper

         self.asy_focal_loss = AsymmetricFocalLoss(
             gamma=self.gamma,
             delta=self.delta,
             use_softmax=self.use_softmax,
-            to_onehot_y=to_onehot_y,
+            to_onehot_y=False,  # Wrapper handles one-hot conversion
             reduction=reduction,
         )
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
             gamma=self.gamma,
             delta=self.delta,
             use_softmax=self.use_softmax,
-            to_onehot_y=to_onehot_y,
+            to_onehot_y=False,  # Wrapper handles one-hot conversion
             reduction=reduction,
         )

Option 2: Let sub-losses handle conversions

         if y_pred.shape != y_true.shape:
             is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
             if not self.to_onehot_y and not is_binary_logits:
                 raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
         if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
 
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

Also applies to: 297-309

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 266 - 279, the wrapper
AsymmetricUnifiedFocalLoss performs an explicit one-hot conversion in its forward
path while the child losses (asy_focal_loss and asy_focal_tversky_loss) are also
constructed with to_onehot_y=True, causing double conversion and shape errors;
fix by centralizing the conversion in the wrapper: keep the wrapper's one-hot
conversion and change the initialization of AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss to to_onehot_y=False (so they assume inputs are
already one-hot), or alternatively remove the wrapper conversion and leave the
sub-losses with to_onehot_y=True; the first option is preferred (do the conversion
once in AsymmetricUnifiedFocalLoss and set the sub-losses' to_onehot_y=False).

Comment on lines 288 to 292
        if y_pred.shape != y_true.shape:
            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
            if not self.to_onehot_y and not is_binary_logits:
                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

⚠️ Potential issue | 🔴 Critical

Fix duplicate raise statement.

Line 292 unconditionally raises the same ValueError as line 291, making the conditional logic on lines 289-291 pointless. This prevents binary logits from being processed correctly.

🐛 Proposed fix
         if y_pred.shape != y_true.shape:
             is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
             if not self.to_onehot_y and not is_binary_logits:
                 raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

This allows binary logits (single-channel with use_softmax=False) and cases where to_onehot_y=True to proceed despite shape mismatch.
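
To make the accepted and rejected cases concrete, here is a standalone restatement of the intended condition (illustrative only, not the class itself; the flags mirror the constructor arguments discussed above):

# Standalone restatement of the intended shape check.
def shape_check(y_pred_shape, y_true_shape, use_softmax, to_onehot_y):
    if y_pred_shape != y_true_shape:
        is_binary_logits = y_pred_shape[1] == 1 and not use_softmax
        if not to_onehot_y and not is_binary_logits:
            raise ValueError(f"ground truth {y_true_shape} differs from input {y_pred_shape}")

shape_check((2, 3, 4, 4), (2, 1, 4, 4), use_softmax=True, to_onehot_y=True)    # allowed: labels one-hot encoded later
shape_check((2, 1, 4, 4), (2, 1, 4, 4), use_softmax=False, to_onehot_y=False)  # allowed: shapes already match
# shape_check((2, 3, 4, 4), (2, 1, 4, 4), use_softmax=True, to_onehot_y=False)  # would raise ValueError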

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        if y_pred.shape != y_true.shape:
-            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
-            if not self.to_onehot_y and not is_binary_logits:
-                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
+        if y_pred.shape != y_true.shape:
+            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
+            if not self.to_onehot_y and not is_binary_logits:
+                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
🧰 Tools
🪛 Ruff (0.14.10)

291-291: Avoid specifying long messages outside the exception class

(TRY003)


292-292: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 288 - 292, the code
currently always raises a ValueError when y_pred.shape != y_true.shape due to a
duplicate raise; update the conditional in the shape check so that it only raises
when neither binary-logits nor to_onehot_y handling applies: compute
is_binary_logits = (y_pred.shape[1] == 1 and not self.use_softmax), then raise the
ValueError only if not self.to_onehot_y and not is_binary_logits, otherwise allow
execution to continue to the binary-logits or one-hot conversion paths; reference
the y_pred, y_true, is_binary_logits, to_onehot_y, and use_softmax symbols when
locating the change.

Comment on lines +83 to +87
    @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
    def test_wrong_prediction(self, input_data, expected_val, args):
        loss_func = AsymmetricUnifiedFocalLoss(**args)
        result = loss_func(**input_data)
        self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")

⚠️ Potential issue | 🟡 Minor

Remove unused parameter.

The expected_val parameter is not used in this test method. Remove it from the signature or use it for validation.

🔧 Proposed fix
-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, _expected_val, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")

Alternatively, remove the parameter entirely if not needed:

 # Update TEST_CASE_MULTICLASS_WRONG to remove the None value
 TEST_CASE_MULTICLASS_WRONG = [
     {
         "y_pred": torch.tensor(
             [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
         ),
         "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
     },
     {"use_softmax": True, "to_onehot_y": True},
 ]

-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
-    def test_wrong_prediction(self, input_data, expected_val, args):
-        loss_func = AsymmetricUnifiedFocalLoss(**args)
-        result = loss_func(**input_data)
-        self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
+    @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
+    def test_wrong_prediction(self, input_data, _expected_val, args):
+        loss_func = AsymmetricUnifiedFocalLoss(**args)
+        result = loss_func(**input_data)
+        self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
🧰 Tools
🪛 Ruff (0.14.10)

84-84: Unused method argument: expected_val

(ARG002)

🤖 Prompt for AI Agents
In @tests/losses/test_unified_focal_loss.py around lines 83 - 87, The test
method test_wrong_prediction currently accepts an unused parameter expected_val;
update its signature to remove expected_val so it only takes (self, input_data,
args), and ensure the parameterized.expand call using TEST_CASE_MULTICLASS_WRONG
still supplies just the two needed values (input_data and args) or adjust the
test case tuple accordingly; keep the body using AsymmetricUnifiedFocalLoss and
the existing assertion unchanged.

@ytl0623 ytl0623 marked this pull request as draft January 9, 2026 01:03


Development

Successfully merging this pull request may close these issues.

Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss

2 participants