From 8fbec82b295bfaf87920770be830987177320642 Mon Sep 17 00:00:00 2001 From: bala93 Date: Sat, 1 Jun 2024 22:11:42 -0400 Subject: [PATCH 01/50] Initial commit -- Adding calibration loss specific to segmentation --- docs/source/losses.rst | 5 ++ monai/losses/__init__.py | 1 + monai/losses/segcalib.py | 124 +++++++++++++++++++++++++++++++++++++++ tests/test_nacl_loss.py | 108 ++++++++++++++++++++++++++++++++++ 4 files changed, 238 insertions(+) create mode 100644 monai/losses/segcalib.py create mode 100644 tests/test_nacl_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index ba794af3eb..c008f77be8 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -93,6 +93,11 @@ Segmentation Losses .. autoclass:: SoftDiceclDiceLoss :members: +`NACLLoss` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: NACLLoss + :members: + Registration Losses ------------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index e937b53fa4..fdf47de07b 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -44,3 +44,4 @@ from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss +from .segcalib import NACLLoss \ No newline at end of file diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py new file mode 100644 index 0000000000..c7646144fe --- /dev/null +++ b/monai/losses/segcalib.py @@ -0,0 +1,124 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +import math + +def get_gaussian_kernel_2d(ksize=3, sigma=1): + x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() + mean = (ksize - 1)/2. + variance = sigma**2. + gaussian_kernel = (1./(2.*math.pi*variance + 1e-16)) * torch.exp( + -torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance + 1e-16) + ) + return gaussian_kernel / torch.sum(gaussian_kernel) + +class get_svls_filter_2d(torch.nn.Module): + def __init__(self, ksize=3, sigma=1, channels=0): + super(get_svls_filter_2d, self).__init__() + gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) + neighbors_sum = (1 - gkernel[1,1]) + 1e-16 + gkernel[int(ksize/2), int(ksize/2)] = neighbors_sum + self.svls_kernel = gkernel / neighbors_sum + svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) + svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) + padding = int(ksize/2) + self.svls_layer = torch.nn.Conv2d(in_channels=channels, out_channels=channels, + kernel_size=ksize, groups=channels, + bias=False, padding=padding, padding_mode='replicate') + self.svls_layer.weight.data = svls_kernel_2d + self.svls_layer.weight.requires_grad = False + def forward(self, x): + return self.svls_layer(x) / self.svls_kernel.sum() + +class NACLLoss(_Loss): + """Add marginal penalty to logits: + CE + alpha * max(0, max(l^n) - l^n - margin) + """ + def __init__(self, + classes=None, + kernel_size=3, + kernel_ops='mean', + distance_type='l1', + is_softmax=False, + alpha=0.1, + ignore_index=-100, + sigma=1, + schedule=""): + + super().__init__() + assert schedule in ("", "add", "multiply", "step") + + self.distance_type = distance_type + + self.alpha = alpha + self.ignore_index = ignore_index + + self.is_softmax = is_softmax + + self.nc = classes + self.ks = kernel_size + self.kernel_ops = kernel_ops + self.cross_entropy = nn.CrossEntropyLoss() + if kernel_ops == 'gaussian': + self.svls_layer = get_svls_filter_2d(ksize=kernel_size, sigma=sigma, channels=classes) + + def get_constr_target(self, mask): + + mask = mask.unsqueeze(1) ## unfold works for 4d. + + bs, _, h, w = mask.shape + unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks),padding=self.ks // 2) + + rmask = [] + + if self.kernel_ops == 'mean': + umask = unfold(mask.float()) + + for ii in range(self.nc): + rmask.append(torch.sum(umask == ii,1)/self.ks**2) + + if self.kernel_ops == 'gaussian': + + oh_labels = F.one_hot(mask[:,0].to(torch.int64), num_classes = self.nc).contiguous().permute(0,3,1,2).float() + rmask = self.svls_layer(oh_labels) + + return rmask + + rmask = torch.stack(rmask,dim=1) + rmask = rmask.reshape(bs, self.nc, h, w) + + return rmask + + + def forward(self, inputs, targets, imgs): + + loss_ce = self.cross_entropy(inputs, targets) + + utargets = self.get_constr_target(targets, imgs) + + if self.is_softmax: + inputs = F.softmax(inputs, dim=1) + + if self.distance_type == 'l1': + loss_conf = torch.abs(utargets - inputs).mean() + + if self.distance_type == 'l2': + loss_conf = (torch.abs(utargets - inputs)**2).mean() + + loss = loss_ce + self.alpha * loss_conf + + return loss, loss_ce, loss_conf diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py new file mode 100644 index 0000000000..fc42336584 --- /dev/null +++ b/tests/test_nacl_loss.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import NACLLoss + +TEST_CASES = [ + [ # shape: (2, 2, 3), (2, 2, 3) + {"classes": 2}, + { + "inputs": torch.tensor( + [ + [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], + [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], + ] + ), + "targets": torch.tensor( + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), + }, + 3.3611, # the result equals to -1 + np.log(1 + np.exp(1)) + ], + [ # shape: (2, 2, 3), (2, 2, 3) + {"classes": 2, "kernel_ops": "gaussian"}, + { + "inputs": torch.tensor( + [ + [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], + [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], + ] + ), + "targets": torch.tensor( + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), + }, + 3.3963, # the result equals to -1 + np.log(1 + np.exp(1)) + ], + [ # shape: (2, 2, 3), (2, 2, 3) + {"classes": 2, "distance_type": "l2"}, + { + "inputs": torch.tensor( + [ + [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], + [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], + ] + ), + "targets": torch.tensor( + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), + }, + 3.3459, # the result equals to -1 + np.log(1 + np.exp(1)) + ], + [ # shape: (2, 2, 3), (2, 2, 3) + {"classes": 2, "alpha": 0.2}, + { + "inputs": torch.tensor( + [ + [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], + [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], + ] + ), + "targets": torch.tensor( + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), + }, + 3.3836, # the result equals to -1 + np.log(1 + np.exp(1)) + ], +] + + +class TestNACLLoss(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + loss = NACLLoss(**input_param) + result = loss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 23b897bf050257b03e7be8d9dfcb8ba5d60af1cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:13:11 +0000 Subject: [PATCH 02/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/__init__.py | 2 +- monai/losses/segcalib.py | 48 ++++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index fdf47de07b..51e1f7797e 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -44,4 +44,4 @@ from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss -from .segcalib import NACLLoss \ No newline at end of file +from .segcalib import NACLLoss diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index c7646144fe..35a5f2800c 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -21,7 +21,7 @@ def get_gaussian_kernel_2d(ksize=3, sigma=1): xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() mean = (ksize - 1)/2. variance = sigma**2. - gaussian_kernel = (1./(2.*math.pi*variance + 1e-16)) * torch.exp( + gaussian_kernel = (1./(2.*math.pi*variance + 1e-16)) * torch.exp( -torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance + 1e-16) ) return gaussian_kernel / torch.sum(gaussian_kernel) @@ -58,15 +58,15 @@ def __init__(self, ignore_index=-100, sigma=1, schedule=""): - + super().__init__() assert schedule in ("", "add", "multiply", "step") - + self.distance_type = distance_type - + self.alpha = alpha self.ignore_index = ignore_index - + self.is_softmax = is_softmax self.nc = classes @@ -77,20 +77,20 @@ def __init__(self, self.svls_layer = get_svls_filter_2d(ksize=kernel_size, sigma=sigma, channels=classes) def get_constr_target(self, mask): - - mask = mask.unsqueeze(1) ## unfold works for 4d. - + + mask = mask.unsqueeze(1) ## unfold works for 4d. + bs, _, h, w = mask.shape - unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks),padding=self.ks // 2) - + unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks),padding=self.ks // 2) + rmask = [] - - if self.kernel_ops == 'mean': + + if self.kernel_ops == 'mean': umask = unfold(mask.float()) - + for ii in range(self.nc): rmask.append(torch.sum(umask == ii,1)/self.ks**2) - + if self.kernel_ops == 'gaussian': oh_labels = F.one_hot(mask[:,0].to(torch.int64), num_classes = self.nc).contiguous().permute(0,3,1,2).float() @@ -100,24 +100,24 @@ def get_constr_target(self, mask): rmask = torch.stack(rmask,dim=1) rmask = rmask.reshape(bs, self.nc, h, w) - + return rmask - + def forward(self, inputs, targets, imgs): - + loss_ce = self.cross_entropy(inputs, targets) - + utargets = self.get_constr_target(targets, imgs) - + if self.is_softmax: inputs = F.softmax(inputs, dim=1) - + if self.distance_type == 'l1': - loss_conf = torch.abs(utargets - inputs).mean() - - if self.distance_type == 'l2': - loss_conf = (torch.abs(utargets - inputs)**2).mean() + loss_conf = torch.abs(utargets - inputs).mean() + + if self.distance_type == 'l2': + loss_conf = (torch.abs(utargets - inputs)**2).mean() loss = loss_ce + self.alpha * loss_conf From b2ec62b62f7c94d495639fbb126a9dc2312040ae Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 1 Jun 2024 20:37:36 -0700 Subject: [PATCH 03/50] Update __init__.py Fix the order of loss function in import statement --- monai/losses/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 51e1f7797e..58963360f2 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -38,10 +38,10 @@ from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss from .perceptual import PerceptualLoss +from .segcalib import NACLLoss from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss -from .segcalib import NACLLoss From 93ee114a739bcc929c5a1c75afb3545ea2dc5410 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 1 Jun 2024 20:50:32 -0700 Subject: [PATCH 04/50] Update segcalib.py Fixed the formatting and minor issues. --- monai/losses/segcalib.py | 137 +++++++++++++++++++++++++-------------- 1 file changed, 89 insertions(+), 48 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 35a5f2800c..29732ad21a 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -9,116 +9,157 @@ # See the License for the specific language governing permissions and # limitations under the License. + import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.loss import _Loss import math +import warnings +from monai.utils import pytorch_after + def get_gaussian_kernel_2d(ksize=3, sigma=1): x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize) y_grid = x_grid.t() xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() - mean = (ksize - 1)/2. - variance = sigma**2. - gaussian_kernel = (1./(2.*math.pi*variance + 1e-16)) * torch.exp( - -torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance + 1e-16) - ) + mean = (ksize - 1) / 2.0 + variance = sigma**2.0 + gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( + -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) + ) return gaussian_kernel / torch.sum(gaussian_kernel) + class get_svls_filter_2d(torch.nn.Module): def __init__(self, ksize=3, sigma=1, channels=0): super(get_svls_filter_2d, self).__init__() gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) - neighbors_sum = (1 - gkernel[1,1]) + 1e-16 - gkernel[int(ksize/2), int(ksize/2)] = neighbors_sum + neighbors_sum = (1 - gkernel[1, 1]) + 1e-16 + gkernel[int(ksize / 2), int(ksize / 2)] = neighbors_sum self.svls_kernel = gkernel / neighbors_sum svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) - padding = int(ksize/2) - self.svls_layer = torch.nn.Conv2d(in_channels=channels, out_channels=channels, - kernel_size=ksize, groups=channels, - bias=False, padding=padding, padding_mode='replicate') + padding = int(ksize / 2) + self.svls_layer = torch.nn.Conv2d( + in_channels=channels, + out_channels=channels, + kernel_size=ksize, + groups=channels, + bias=False, + padding=padding, + padding_mode="replicate", + ) self.svls_layer.weight.data = svls_kernel_2d self.svls_layer.weight.requires_grad = False + def forward(self, x): return self.svls_layer(x) / self.svls_kernel.sum() + class NACLLoss(_Loss): - """Add marginal penalty to logits: - CE + alpha * max(0, max(l^n) - l^n - margin) """ - def __init__(self, - classes=None, - kernel_size=3, - kernel_ops='mean', - distance_type='l1', - is_softmax=False, - alpha=0.1, - ignore_index=-100, - sigma=1, - schedule=""): + Murugesan, Balamurali, et al. + "Trust your neighbours: Penalty-based constraints for model calibration." + International Conference on Medical Image Computing and Computer-Assisted Intervention, 2023. + https://arxiv.org/abs/2303.06268 + """ + + def __init__(self, classes, kernel_size=3, kernel_ops="mean", distance_type="l1", alpha=0.1, sigma=1): + """ + Args: + classes: number of classes + kernel_size: size of the spatial kernel + kenel_ops: type of kernel operation (mean/gaussian) + distance_type: l1/l2 distance between spatial kernel and predicted logits + alpha: weightage between cross entropy and logit constraint + sigma: sigma if the kernel type is gaussian + """ super().__init__() - assert schedule in ("", "add", "multiply", "step") - self.distance_type = distance_type + if kernel_ops not in ["mean", "gaussian"]: + raise ValueError("Kernel ops must be either mean or gaussian") - self.alpha = alpha - self.ignore_index = ignore_index + if distance_type not in ["l1", "l2"]: + raise ValueError("Distance type must be either L1 or L2") - self.is_softmax = is_softmax + self.kernel_ops = kernel_ops + self.distance_type = distance_type + self.alpha = alpha self.nc = classes self.ks = kernel_size - self.kernel_ops = kernel_ops self.cross_entropy = nn.CrossEntropyLoss() - if kernel_ops == 'gaussian': + + if kernel_ops == "gaussian": self.svls_layer = get_svls_filter_2d(ksize=kernel_size, sigma=sigma, channels=classes) + self.old_pt_ver = not pytorch_after(1, 10) + + def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute CrossEntropy loss for the input logits and target. + Will remove the channel dim according to PyTorch CrossEntropyLoss: + https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. + + """ + n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + if n_pred_ch != n_target_ch and n_target_ch == 1: + target = torch.squeeze(target, dim=1) + target = target.long() + elif self.old_pt_ver: + warnings.warn( + f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. " + "Using argmax (as a workaround) to convert target to a single channel." + ) + target = torch.argmax(target, dim=1) + elif not torch.is_floating_point(target): + target = target.to(dtype=input.dtype) + + return self.cross_entropy(input, target) # type: ignore[no-any-return] + def get_constr_target(self, mask): - mask = mask.unsqueeze(1) ## unfold works for 4d. + mask = mask.unsqueeze(1) ## unfold works for 4d. bs, _, h, w = mask.shape - unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks),padding=self.ks // 2) + unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks), padding=self.ks // 2) rmask = [] - if self.kernel_ops == 'mean': + if self.kernel_ops == "mean": umask = unfold(mask.float()) for ii in range(self.nc): - rmask.append(torch.sum(umask == ii,1)/self.ks**2) + rmask.append(torch.sum(umask == ii, 1) / self.ks**2) - if self.kernel_ops == 'gaussian': + if self.kernel_ops == "gaussian": - oh_labels = F.one_hot(mask[:,0].to(torch.int64), num_classes = self.nc).contiguous().permute(0,3,1,2).float() + oh_labels = ( + F.one_hot(mask[:, 0].to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() + ) rmask = self.svls_layer(oh_labels) return rmask - rmask = torch.stack(rmask,dim=1) + rmask = torch.stack(rmask, dim=1) rmask = rmask.reshape(bs, self.nc, h, w) return rmask + def forward(self, inputs, targets): - def forward(self, inputs, targets, imgs): - - loss_ce = self.cross_entropy(inputs, targets) - - utargets = self.get_constr_target(targets, imgs) + loss_ce = self.ce(inputs, targets) - if self.is_softmax: - inputs = F.softmax(inputs, dim=1) + utargets = self.get_constr_target(targets) - if self.distance_type == 'l1': + if self.distance_type == "l1": loss_conf = torch.abs(utargets - inputs).mean() - if self.distance_type == 'l2': - loss_conf = (torch.abs(utargets - inputs)**2).mean() + if self.distance_type == "l2": + loss_conf = (torch.abs(utargets - inputs) ** 2).mean() loss = loss_ce + self.alpha * loss_conf - return loss, loss_ce, loss_conf + return loss # , loss_ce, loss_conf From 42e732bbf88302d8efbfc78063abf7e0d09eda67 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 1 Jun 2024 22:14:37 -0700 Subject: [PATCH 05/50] Update segcalib.py Reorder import statements --- monai/losses/segcalib.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 29732ad21a..d796a35014 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings +import math import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.loss import _Loss -import math -import warnings + from monai.utils import pytorch_after From 187053d154374ec9649ab10a990c126160c6e8c3 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sun, 2 Jun 2024 09:21:49 -0700 Subject: [PATCH 06/50] Update segcalib.py Reorder system imports. --- monai/losses/segcalib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index d796a35014..ee8b67a828 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings import math +import warnings import torch import torch.nn as nn From 1d27ec5c5de88815d7b1b95d6eba0acd4379aa47 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sun, 2 Jun 2024 11:33:37 -0700 Subject: [PATCH 07/50] Update segcalib.py isort inferred one --- monai/losses/segcalib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index ee8b67a828..19347af4b5 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import warnings From d4991340657564bdee7892d8909aa14a060bd67b Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sun, 2 Jun 2024 18:45:31 -0700 Subject: [PATCH 08/50] Update segcalib.py Updated class name according to standard --- monai/losses/segcalib.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 19347af4b5..05d5c3f3d6 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -34,9 +34,9 @@ def get_gaussian_kernel_2d(ksize=3, sigma=1): return gaussian_kernel / torch.sum(gaussian_kernel) -class get_svls_filter_2d(torch.nn.Module): +class GaussianFilter(torch.nn.Module): def __init__(self, ksize=3, sigma=1, channels=0): - super(get_svls_filter_2d, self).__init__() + super(GaussianFilter, self).__init__() gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) neighbors_sum = (1 - gkernel[1, 1]) + 1e-16 gkernel[int(ksize / 2), int(ksize / 2)] = neighbors_sum @@ -96,7 +96,7 @@ def __init__(self, classes, kernel_size=3, kernel_ops="mean", distance_type="l1" self.cross_entropy = nn.CrossEntropyLoss() if kernel_ops == "gaussian": - self.svls_layer = get_svls_filter_2d(ksize=kernel_size, sigma=sigma, channels=classes) + self.svls_layer = GaussianFilter(ksize=kernel_size, sigma=sigma, channels=classes) self.old_pt_ver = not pytorch_after(1, 10) From 1e3f7559e20d2789828ce7415a3a14fd905ddb85 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sun, 2 Jun 2024 18:55:06 -0700 Subject: [PATCH 09/50] Update segcalib.py Fixed the inline comment --- monai/losses/segcalib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 05d5c3f3d6..44272bef8f 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -124,7 +124,7 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: def get_constr_target(self, mask): - mask = mask.unsqueeze(1) ## unfold works for 4d. + mask = mask.unsqueeze(1) # unfold works for 4d. bs, _, h, w = mask.shape unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks), padding=self.ks // 2) From 9dedfba17a83886a472e96f723bfa2ea4d8739e8 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Tue, 4 Jun 2024 08:16:49 -0700 Subject: [PATCH 10/50] Update segcalib.py Add datatype in functions and classes --- monai/losses/segcalib.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 44272bef8f..ba13af104a 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -22,7 +22,7 @@ from monai.utils import pytorch_after -def get_gaussian_kernel_2d(ksize=3, sigma=1): +def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize) y_grid = x_grid.t() xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() @@ -35,7 +35,7 @@ def get_gaussian_kernel_2d(ksize=3, sigma=1): class GaussianFilter(torch.nn.Module): - def __init__(self, ksize=3, sigma=1, channels=0): + def __init__(self, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: super(GaussianFilter, self).__init__() gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) neighbors_sum = (1 - gkernel[1, 1]) + 1e-16 @@ -68,7 +68,15 @@ class NACLLoss(_Loss): https://arxiv.org/abs/2303.06268 """ - def __init__(self, classes, kernel_size=3, kernel_ops="mean", distance_type="l1", alpha=0.1, sigma=1): + def __init__( + self, + classes, + kernel_size: int = 3, + kernel_ops: str = "mean", + distance_type: str = "l1", + alpha: float = 0.1, + sigma: float = 1.0, + ) -> torch.Tensor: """ Args: classes: number of classes @@ -122,8 +130,7 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: return self.cross_entropy(input, target) # type: ignore[no-any-return] - def get_constr_target(self, mask): - + def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: mask = mask.unsqueeze(1) # unfold works for 4d. bs, _, h, w = mask.shape @@ -138,7 +145,6 @@ def get_constr_target(self, mask): rmask.append(torch.sum(umask == ii, 1) / self.ks**2) if self.kernel_ops == "gaussian": - oh_labels = ( F.one_hot(mask[:, 0].to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() ) @@ -151,8 +157,7 @@ def get_constr_target(self, mask): return rmask - def forward(self, inputs, targets): - + def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss_ce = self.ce(inputs, targets) utargets = self.get_constr_target(targets) From 59959ce310dc036510f4499176935de0d917a47c Mon Sep 17 00:00:00 2001 From: Balamurali Date: Fri, 14 Jun 2024 06:32:10 -0700 Subject: [PATCH 11/50] Update monai/losses/segcalib.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Balamurali --- monai/losses/segcalib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index ba13af104a..565f6ca821 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -170,4 +170,4 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss = loss_ce + self.alpha * loss_conf - return loss # , loss_ce, loss_conf + return loss From cf1d044f817bbd8c810eeff161bd02e08fd85b28 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Fri, 14 Jun 2024 06:33:30 -0700 Subject: [PATCH 12/50] Update monai/losses/segcalib.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Balamurali --- monai/losses/segcalib.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 565f6ca821..c786c9e02d 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -163,10 +163,9 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: utargets = self.get_constr_target(targets) if self.distance_type == "l1": - loss_conf = torch.abs(utargets - inputs).mean() - - if self.distance_type == "l2": - loss_conf = (torch.abs(utargets - inputs) ** 2).mean() + loss_conf = utargets.sub(inputs).abs_().mean() + elif self.distance_type == "l2": + loss_conf = utargets.sub(inputs).pow_(2).abs_().mean() loss = loss_ce + self.alpha * loss_conf From 0926851268901f2df668f7ab7ea01041d257a862 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Jun 2024 13:33:48 +0000 Subject: [PATCH 13/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/segcalib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index c786c9e02d..8748005ed5 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -169,4 +169,4 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss = loss_ce + self.alpha * loss_conf - return loss + return loss From 5317706d70212aa50e5b3ef50385c701668122de Mon Sep 17 00:00:00 2001 From: Balamurali Date: Fri, 14 Jun 2024 22:58:39 -0700 Subject: [PATCH 14/50] Update segcalib.py Add model description. --- monai/losses/segcalib.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 8748005ed5..6d74f2f1ae 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -62,6 +62,10 @@ def forward(self, x): class NACLLoss(_Loss): """ + Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. + NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions + to match a soft class proportion of surrounding pixel. + Murugesan, Balamurali, et al. "Trust your neighbours: Penalty-based constraints for model calibration." International Conference on Medical Image Computing and Computer-Assisted Intervention, 2023. From 3155433c12c137429358bbec77e9718f4cce18a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Jun 2024 05:59:01 +0000 Subject: [PATCH 15/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/segcalib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 6d74f2f1ae..038f0f4713 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -64,7 +64,7 @@ class NACLLoss(_Loss): """ Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions - to match a soft class proportion of surrounding pixel. + to match a soft class proportion of surrounding pixel. Murugesan, Balamurali, et al. "Trust your neighbours: Penalty-based constraints for model calibration." From 7c121a07857c4d59703a7fd37b63ac52dac67b04 Mon Sep 17 00:00:00 2001 From: bala93 Date: Sat, 3 Aug 2024 11:22:55 -0400 Subject: [PATCH 16/50] Add specific to gaussian for both 2d and 3d --- monai/losses/segcalib.py | 179 +++++++++++++++++++++++---------------- tests/test_nacl_loss.py | 166 +++++++++++++++++++++++------------- 2 files changed, 210 insertions(+), 135 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 038f0f4713..586c5d24f5 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -34,27 +34,69 @@ def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: return gaussian_kernel / torch.sum(gaussian_kernel) +def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: + x_coord = torch.arange(ksize) + x_grid_2d = x_coord.repeat(ksize).view(ksize, ksize) + x_grid = x_coord.repeat(ksize * ksize).view(ksize, ksize, ksize) + y_grid_2d = x_grid_2d.t() + y_grid = y_grid_2d.repeat(ksize, 1).view(ksize, ksize, ksize) + z_grid = y_grid_2d.repeat(1, ksize).view(ksize, ksize, ksize) + xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float() + mean = (ksize - 1) / 2.0 + variance = sigma**2.0 + gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( + -torch.sum((xyz_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) + ) + return gaussian_kernel / torch.sum(gaussian_kernel) + + class GaussianFilter(torch.nn.Module): - def __init__(self, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: + def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: super(GaussianFilter, self).__init__() - gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) - neighbors_sum = (1 - gkernel[1, 1]) + 1e-16 - gkernel[int(ksize / 2), int(ksize / 2)] = neighbors_sum - self.svls_kernel = gkernel / neighbors_sum - svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) - svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) - padding = int(ksize / 2) - self.svls_layer = torch.nn.Conv2d( - in_channels=channels, - out_channels=channels, - kernel_size=ksize, - groups=channels, - bias=False, - padding=padding, - padding_mode="replicate", - ) - self.svls_layer.weight.data = svls_kernel_2d - self.svls_layer.weight.requires_grad = False + + if dim == 2: + gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) + neighbors_sum = (1 - gkernel[1, 1]) + 1e-16 + gkernel[int(ksize / 2), int(ksize / 2)] = neighbors_sum + self.svls_kernel = gkernel / neighbors_sum + + svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) + svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) + padding = int(ksize / 2) + + self.svls_layer = torch.nn.Conv2d( + in_channels=channels, + out_channels=channels, + kernel_size=ksize, + groups=channels, + bias=False, + padding=padding, + padding_mode="replicate", + ) + self.svls_layer.weight.data = svls_kernel_2d + self.svls_layer.weight.requires_grad = False + + if dim == 3: + gkernel = get_gaussian_kernel_3d(ksize=ksize, sigma=sigma) + neighbors_sum = 1 - gkernel[1, 1, 1] + gkernel[1, 1, 1] = neighbors_sum + self.svls_kernel = gkernel / neighbors_sum + + svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize, ksize) + svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1, 1) + padding = int(ksize / 2) + + self.svls_layer = torch.nn.Conv3d( + in_channels=channels, + out_channels=channels, + kernel_size=ksize, + groups=channels, + bias=False, + padding=padding, + padding_mode="replicate", + ) + self.svls_layer.weight.data = svls_kernel_3d + self.svls_layer.weight.requires_grad = False def forward(self, x): return self.svls_layer(x) / self.svls_kernel.sum() @@ -64,7 +106,7 @@ class NACLLoss(_Loss): """ Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions - to match a soft class proportion of surrounding pixel. + to match a soft class proportion of surrounding pixel. Murugesan, Balamurali, et al. "Trust your neighbours: Penalty-based constraints for model calibration." @@ -74,95 +116,84 @@ class NACLLoss(_Loss): def __init__( self, - classes, + classes: int, + dim: int, kernel_size: int = 3, - kernel_ops: str = "mean", distance_type: str = "l1", alpha: float = 0.1, sigma: float = 1.0, ) -> torch.Tensor: """ Args: - classes: number of classes + classes: number of classes kernel_size: size of the spatial kernel - kenel_ops: type of kernel operation (mean/gaussian) distance_type: l1/l2 distance between spatial kernel and predicted logits alpha: weightage between cross entropy and logit constraint - sigma: sigma if the kernel type is gaussian + sigma: sigma of gaussian """ super().__init__() - if kernel_ops not in ["mean", "gaussian"]: - raise ValueError("Kernel ops must be either mean or gaussian") + if dim not in [2, 3]: + raise ValueError("Supoorts 2d and 3d") if distance_type not in ["l1", "l2"]: raise ValueError("Distance type must be either L1 or L2") - self.kernel_ops = kernel_ops + self.nc = classes + self.dim = dim + self.cross_entropy = nn.CrossEntropyLoss() self.distance_type = distance_type self.alpha = alpha - - self.nc = classes self.ks = kernel_size - self.cross_entropy = nn.CrossEntropyLoss() - if kernel_ops == "gaussian": - self.svls_layer = GaussianFilter(ksize=kernel_size, sigma=sigma, channels=classes) + self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes) self.old_pt_ver = not pytorch_after(1, 10) - def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute CrossEntropy loss for the input logits and target. - Will remove the channel dim according to PyTorch CrossEntropyLoss: - https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. - - """ - n_pred_ch, n_target_ch = input.shape[1], target.shape[1] - if n_pred_ch != n_target_ch and n_target_ch == 1: - target = torch.squeeze(target, dim=1) - target = target.long() - elif self.old_pt_ver: - warnings.warn( - f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. " - "Using argmax (as a workaround) to convert target to a single channel." - ) - target = torch.argmax(target, dim=1) - elif not torch.is_floating_point(target): - target = target.to(dtype=input.dtype) - - return self.cross_entropy(input, target) # type: ignore[no-any-return] + # def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + # """ + # Compute CrossEntropy loss for the input logits and target. + # Will remove the channel dim according to PyTorch CrossEntropyLoss: + # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. + + # """ + # n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + # if n_pred_ch != n_target_ch and n_target_ch == 1: + # target = torch.squeeze(target, dim=1) + # target = target.long() + # elif self.old_pt_ver: + # warnings.warn( + # f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. " + # "Using argmax (as a workaround) to convert target to a single channel." + # ) + # target = torch.argmax(target, dim=1) + # elif not torch.is_floating_point(target): + # target = target.to(dtype=input.dtype) + + # return self.cross_entropy(input, target) # type: ignore[no-any-return] def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: - mask = mask.unsqueeze(1) # unfold works for 4d. - - bs, _, h, w = mask.shape - unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks), padding=self.ks // 2) + + if self.dim == 2: - rmask = [] - - if self.kernel_ops == "mean": - umask = unfold(mask.float()) - - for ii in range(self.nc): - rmask.append(torch.sum(umask == ii, 1) / self.ks**2) - - if self.kernel_ops == "gaussian": oh_labels = ( - F.one_hot(mask[:, 0].to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() + F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() ) rmask = self.svls_layer(oh_labels) - return rmask - - rmask = torch.stack(rmask, dim=1) - rmask = rmask.reshape(bs, self.nc, h, w) + if self.dim == 3: + oh_labels = ( + F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() + ) + rmask = self.svls_layer(oh_labels) + return rmask + def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - loss_ce = self.ce(inputs, targets) + loss_ce = self.cross_entropy(inputs, targets) utargets = self.get_constr_target(targets) @@ -173,4 +204,4 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss = loss_ce + self.alpha * loss_conf - return loss + return loss \ No newline at end of file diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index fc42336584..52207bd426 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -21,76 +21,120 @@ TEST_CASES = [ [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 2}, + {"classes": 3, "dim": 2}, { - "inputs": torch.tensor( - [ - [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], - [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], - ] - ), - "targets": torch.tensor( - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] - ), + "inputs": torch.tensor([[[[0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722]], + [[0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568]], + [[0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913]]]]), + "targets": torch.tensor([[[1, 1, 1, 1], + [1, 1, 1, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]]), }, - 3.3611, # the result equals to -1 + np.log(1 + np.exp(1)) + 1.1850, # the result equals to -1 + np.log(1 + np.exp(1)) ], [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 2, "kernel_ops": "gaussian"}, + {"classes": 3, "dim": 2}, { - "inputs": torch.tensor( - [ - [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], - [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], - ] - ), - "targets": torch.tensor( - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] - ), + "inputs": torch.tensor([[[[0.0411, 0.3386, 0.8352, 0.2741], + [0.4821, 0.0519, 0.2561, 0.9391], + [0.5954, 0.4184, 0.9160, 0.7977], + [0.0588, 0.9156, 0.1307, 0.9914]], + [[0.8481, 0.0892, 0.2044, 0.8470], + [0.3558, 0.3569, 0.0979, 0.4491], + [0.0876, 0.0929, 0.4040, 0.8384], + [0.5313, 0.3927, 0.4165, 0.1107]], + [[0.7993, 0.6938, 0.3151, 0.8728], + [0.7332, 0.4111, 0.3862, 0.9988], + [0.2622, 0.5002, 0.1905, 0.1644], + [0.6354, 0.0047, 0.1649, 0.7112]]]]), + + "targets": torch.tensor([[[1, 2, 0, 1], + [0, 2, 1, 2], + [0, 0, 2, 1], + [1, 1, 1, 2]]]), }, - 3.3963, # the result equals to -1 + np.log(1 + np.exp(1)) + 1.0375, # the result equals to -1 + np.log(1 + np.exp(1)) ], [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 2, "distance_type": "l2"}, + {"classes": 3, "dim": 3}, { - "inputs": torch.tensor( - [ - [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], - [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], - ] - ), - "targets": torch.tensor( - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] - ), - }, - 3.3459, # the result equals to -1 + np.log(1 + np.exp(1)) - ], - [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 2, "alpha": 0.2}, - { - "inputs": torch.tensor( - [ - [[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], - [[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], - ] - ), - "targets": torch.tensor( - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] - ), - }, - 3.3836, # the result equals to -1 + np.log(1 + np.exp(1)) + "inputs": torch.tensor([[[[[0.5977, 0.2767, 0.0591, 0.1675], + [0.4835, 0.3778, 0.8406, 0.3065], + [0.6047, 0.2860, 0.9742, 0.2013], + [0.9128, 0.8368, 0.6711, 0.4384]], + [[0.9797, 0.1863, 0.5584, 0.6652], + [0.2272, 0.2004, 0.7914, 0.4224], + [0.5097, 0.8818, 0.2581, 0.3495], + [0.1054, 0.5483, 0.3732, 0.3587]], + [[0.3060, 0.7066, 0.7922, 0.4689], + [0.1733, 0.8902, 0.6704, 0.2037], + [0.8656, 0.5561, 0.2701, 0.0092], + [0.1866, 0.7714, 0.6424, 0.9791]], + [[0.5067, 0.3829, 0.6156, 0.8985], + [0.5192, 0.8347, 0.2098, 0.2260], + [0.8887, 0.3944, 0.6400, 0.5345], + [0.1207, 0.3763, 0.5282, 0.7741]]], + [[[0.8499, 0.4759, 0.1964, 0.5701], + [0.3190, 0.1238, 0.2368, 0.9517], + [0.0797, 0.6185, 0.0135, 0.8672], + [0.4116, 0.1683, 0.1355, 0.0545]], + [[0.7533, 0.2658, 0.5955, 0.4498], + [0.9500, 0.2317, 0.2825, 0.9763], + [0.1493, 0.1558, 0.3743, 0.8723], + [0.1723, 0.7980, 0.8816, 0.0133]], + [[0.8426, 0.2666, 0.2077, 0.3161], + [0.1725, 0.8414, 0.1515, 0.2825], + [0.4882, 0.5159, 0.4120, 0.1585], + [0.2551, 0.9073, 0.7691, 0.9898]], + [[0.4633, 0.8717, 0.8537, 0.2899], + [0.3693, 0.7953, 0.1183, 0.4596], + [0.0087, 0.7925, 0.0989, 0.8385], + [0.8261, 0.6920, 0.7069, 0.4464]]], + [[[0.0110, 0.1608, 0.4814, 0.6317], + [0.0194, 0.9669, 0.3259, 0.0028], + [0.5674, 0.8286, 0.0306, 0.5309], + [0.3973, 0.8183, 0.0238, 0.1934]], + [[0.8947, 0.6629, 0.9439, 0.8905], + [0.0072, 0.1697, 0.4634, 0.0201], + [0.7184, 0.2424, 0.0820, 0.7504], + [0.3937, 0.1424, 0.4463, 0.5779]], + [[0.4123, 0.6227, 0.0523, 0.8826], + [0.0051, 0.0353, 0.3662, 0.7697], + [0.4867, 0.8986, 0.2510, 0.5316], + [0.1856, 0.2634, 0.9140, 0.9725]], + [[0.2041, 0.4248, 0.2371, 0.7256], + [0.2168, 0.5380, 0.4538, 0.7007], + [0.9013, 0.2623, 0.0739, 0.2998], + [0.1366, 0.5590, 0.2952, 0.4592]]]]]), + + "targets": torch.tensor([[[[0, 1, 0, 1], + [1, 2, 1, 0], + [2, 1, 1, 1], + [1, 1, 0, 1]], + [[2, 1, 0, 2], + [1, 2, 0, 2], + [1, 0, 1, 1], + [1, 1, 0, 0]], + [[1, 0, 2, 1], + [0, 2, 2, 1], + [1, 0, 1, 1], + [0, 0, 2, 1]], + [[2, 1, 1, 0], + [1, 0, 0, 2], + [1, 0, 2, 1], + [2, 1, 0, 1]]]]), + }, + 1.1504, # the result equals to -1 + np.log(1 + np.exp(1)) ], ] From 24efd85e4b9cdf5d0688f42fbe3f6679dea82bf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Aug 2024 15:23:26 +0000 Subject: [PATCH 17/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/segcalib.py | 17 ++++++++--------- tests/test_nacl_loss.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 586c5d24f5..5b4de7916f 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -12,7 +12,6 @@ from __future__ import annotations import math -import warnings import torch import torch.nn as nn @@ -63,7 +62,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: i svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) padding = int(ksize / 2) - + self.svls_layer = torch.nn.Conv2d( in_channels=channels, out_channels=channels, @@ -85,7 +84,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: i svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize, ksize) svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1, 1) padding = int(ksize / 2) - + self.svls_layer = torch.nn.Conv3d( in_channels=channels, out_channels=channels, @@ -106,7 +105,7 @@ class NACLLoss(_Loss): """ Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions - to match a soft class proportion of surrounding pixel. + to match a soft class proportion of surrounding pixel. Murugesan, Balamurali, et al. "Trust your neighbours: Penalty-based constraints for model calibration." @@ -125,7 +124,7 @@ def __init__( ) -> torch.Tensor: """ Args: - classes: number of classes + classes: number of classes kernel_size: size of the spatial kernel distance_type: l1/l2 distance between spatial kernel and predicted logits alpha: weightage between cross entropy and logit constraint @@ -174,7 +173,7 @@ def __init__( # return self.cross_entropy(input, target) # type: ignore[no-any-return] def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: - + if self.dim == 2: oh_labels = ( @@ -187,8 +186,8 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: oh_labels = ( F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() ) - rmask = self.svls_layer(oh_labels) - + rmask = self.svls_layer(oh_labels) + return rmask @@ -204,4 +203,4 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss = loss_ce + self.alpha * loss_conf - return loss \ No newline at end of file + return loss diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 52207bd426..19a2ef6336 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -57,7 +57,7 @@ [0.7332, 0.4111, 0.3862, 0.9988], [0.2622, 0.5002, 0.1905, 0.1644], [0.6354, 0.0047, 0.1649, 0.7112]]]]), - + "targets": torch.tensor([[[1, 2, 0, 1], [0, 2, 1, 2], [0, 0, 2, 1], @@ -116,7 +116,7 @@ [0.2168, 0.5380, 0.4538, 0.7007], [0.9013, 0.2623, 0.0739, 0.2998], [0.1366, 0.5590, 0.2952, 0.4592]]]]]), - + "targets": torch.tensor([[[[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], From dccde477906742ff237861696168ff2377ffdb20 Mon Sep 17 00:00:00 2001 From: bala93 Date: Sat, 3 Aug 2024 16:16:38 -0400 Subject: [PATCH 18/50] Add mean loss and resolve formatting --- monai/losses/segcalib.py | 76 ++++++++-- tests/test_nacl_loss.py | 292 ++++++++++++++++++++++++--------------- 2 files changed, 249 insertions(+), 119 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 5b4de7916f..38dc97f7d6 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -12,6 +12,7 @@ from __future__ import annotations import math +import warnings import torch import torch.nn as nn @@ -21,6 +22,16 @@ from monai.utils import pytorch_after +def get_mean_kernel_2d(ksize: int = 3) -> torch.Tensor: + mean_kernel = torch.ones([ksize, ksize]) / (ksize**2) + return mean_kernel + + +def get_mean_kernel_3d(ksize: int = 3) -> torch.Tensor: + mean_kernel = torch.ones([ksize, ksize, ksize]) / (ksize**3) + return mean_kernel + + def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize) y_grid = x_grid.t() @@ -101,6 +112,50 @@ def forward(self, x): return self.svls_layer(x) / self.svls_kernel.sum() +class MeanFilter(torch.nn.Module): + def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor: + super(MeanFilter, self).__init__() + + if dim == 2: + self.svls_kernel = get_mean_kernel_2d(ksize=ksize) + svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) + svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) + padding = int(ksize / 2) + + self.svls_layer = torch.nn.Conv2d( + in_channels=channels, + out_channels=channels, + kernel_size=ksize, + groups=channels, + bias=False, + padding=padding, + padding_mode="replicate", + ) + self.svls_layer.weight.data = svls_kernel_2d + self.svls_layer.weight.requires_grad = False + + if dim == 3: + self.svls_kernel = get_mean_kernel_3d(ksize=ksize) + svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize) + svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1) + padding = int(ksize / 2) + + self.svls_layer = torch.nn.Conv3d( + in_channels=channels, + out_channels=channels, + kernel_size=ksize, + groups=channels, + bias=False, + padding=padding, + padding_mode="replicate", + ) + self.svls_layer.weight.data = svls_kernel_3d + self.svls_layer.weight.requires_grad = False + + def forward(self, x): + return self.svls_layer(x) / self.svls_kernel.sum() + + class NACLLoss(_Loss): """ Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. @@ -118,6 +173,7 @@ def __init__( classes: int, dim: int, kernel_size: int = 3, + kernel_ops: str = "mean", distance_type: str = "l1", alpha: float = 0.1, sigma: float = 1.0, @@ -133,6 +189,9 @@ def __init__( super().__init__() + if kernel_ops not in ["mean", "gaussian"]: + raise ValueError("Kernel ops must be either mean or gaussian") + if dim not in [2, 3]: raise ValueError("Supoorts 2d and 3d") @@ -146,7 +205,10 @@ def __init__( self.alpha = alpha self.ks = kernel_size - self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes) + if kernel_ops == "mean": + self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes) + if kernel_ops == "gaussian": + self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes) self.old_pt_ver = not pytorch_after(1, 10) @@ -173,24 +235,16 @@ def __init__( # return self.cross_entropy(input, target) # type: ignore[no-any-return] def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: - if self.dim == 2: - - oh_labels = ( - F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() - ) + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() rmask = self.svls_layer(oh_labels) if self.dim == 3: - - oh_labels = ( - F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() - ) + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() rmask = self.svls_layer(oh_labels) return rmask - def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss_ce = self.cross_entropy(inputs, targets) diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 19a2ef6336..1a4772dcb8 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -20,127 +20,203 @@ from monai.losses import NACLLoss TEST_CASES = [ - [ # shape: (2, 2, 3), (2, 2, 3) + [ {"classes": 3, "dim": 2}, { - "inputs": torch.tensor([[[[0.1498, 0.1158, 0.3996, 0.3730], - [0.2155, 0.1585, 0.8541, 0.8579], - [0.6640, 0.2424, 0.0774, 0.0324], - [0.0580, 0.2180, 0.3447, 0.8722]], - [[0.3908, 0.9366, 0.1779, 0.1003], - [0.9630, 0.6118, 0.4405, 0.7916], - [0.5782, 0.9515, 0.4088, 0.3946], - [0.7860, 0.3910, 0.0324, 0.9568]], - [[0.0759, 0.0238, 0.5570, 0.1691], - [0.2703, 0.7722, 0.1611, 0.6431], - [0.8051, 0.6596, 0.4121, 0.1125], - [0.5283, 0.6746, 0.5528, 0.7913]]]]), - "targets": torch.tensor([[[1, 1, 1, 1], - [1, 1, 1, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]]]), + "inputs": torch.tensor( + [ + [ + [ + [0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722], + ], + [ + [0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568], + ], + [ + [0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913], + ], + ] + ] + ), + "targets": torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), }, - 1.1850, # the result equals to -1 + np.log(1 + np.exp(1)) + 1.1820, ], - [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 3, "dim": 2}, + [ + {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, { - "inputs": torch.tensor([[[[0.0411, 0.3386, 0.8352, 0.2741], - [0.4821, 0.0519, 0.2561, 0.9391], - [0.5954, 0.4184, 0.9160, 0.7977], - [0.0588, 0.9156, 0.1307, 0.9914]], - [[0.8481, 0.0892, 0.2044, 0.8470], - [0.3558, 0.3569, 0.0979, 0.4491], - [0.0876, 0.0929, 0.4040, 0.8384], - [0.5313, 0.3927, 0.4165, 0.1107]], - [[0.7993, 0.6938, 0.3151, 0.8728], - [0.7332, 0.4111, 0.3862, 0.9988], - [0.2622, 0.5002, 0.1905, 0.1644], - [0.6354, 0.0047, 0.1649, 0.7112]]]]), - - "targets": torch.tensor([[[1, 2, 0, 1], - [0, 2, 1, 2], - [0, 0, 2, 1], - [1, 1, 1, 2]]]), + "inputs": torch.tensor( + [ + [ + [ + [0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722], + ], + [ + [0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568], + ], + [ + [0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913], + ], + ] + ] + ), + "targets": torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), }, - 1.0375, # the result equals to -1 + np.log(1 + np.exp(1)) + 1.1850, ], - [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 3, "dim": 3}, + [ + {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, { - "inputs": torch.tensor([[[[[0.5977, 0.2767, 0.0591, 0.1675], - [0.4835, 0.3778, 0.8406, 0.3065], - [0.6047, 0.2860, 0.9742, 0.2013], - [0.9128, 0.8368, 0.6711, 0.4384]], - [[0.9797, 0.1863, 0.5584, 0.6652], - [0.2272, 0.2004, 0.7914, 0.4224], - [0.5097, 0.8818, 0.2581, 0.3495], - [0.1054, 0.5483, 0.3732, 0.3587]], - [[0.3060, 0.7066, 0.7922, 0.4689], - [0.1733, 0.8902, 0.6704, 0.2037], - [0.8656, 0.5561, 0.2701, 0.0092], - [0.1866, 0.7714, 0.6424, 0.9791]], - [[0.5067, 0.3829, 0.6156, 0.8985], - [0.5192, 0.8347, 0.2098, 0.2260], - [0.8887, 0.3944, 0.6400, 0.5345], - [0.1207, 0.3763, 0.5282, 0.7741]]], - [[[0.8499, 0.4759, 0.1964, 0.5701], - [0.3190, 0.1238, 0.2368, 0.9517], - [0.0797, 0.6185, 0.0135, 0.8672], - [0.4116, 0.1683, 0.1355, 0.0545]], - [[0.7533, 0.2658, 0.5955, 0.4498], - [0.9500, 0.2317, 0.2825, 0.9763], - [0.1493, 0.1558, 0.3743, 0.8723], - [0.1723, 0.7980, 0.8816, 0.0133]], - [[0.8426, 0.2666, 0.2077, 0.3161], - [0.1725, 0.8414, 0.1515, 0.2825], - [0.4882, 0.5159, 0.4120, 0.1585], - [0.2551, 0.9073, 0.7691, 0.9898]], - [[0.4633, 0.8717, 0.8537, 0.2899], - [0.3693, 0.7953, 0.1183, 0.4596], - [0.0087, 0.7925, 0.0989, 0.8385], - [0.8261, 0.6920, 0.7069, 0.4464]]], - [[[0.0110, 0.1608, 0.4814, 0.6317], - [0.0194, 0.9669, 0.3259, 0.0028], - [0.5674, 0.8286, 0.0306, 0.5309], - [0.3973, 0.8183, 0.0238, 0.1934]], - [[0.8947, 0.6629, 0.9439, 0.8905], - [0.0072, 0.1697, 0.4634, 0.0201], - [0.7184, 0.2424, 0.0820, 0.7504], - [0.3937, 0.1424, 0.4463, 0.5779]], - [[0.4123, 0.6227, 0.0523, 0.8826], - [0.0051, 0.0353, 0.3662, 0.7697], - [0.4867, 0.8986, 0.2510, 0.5316], - [0.1856, 0.2634, 0.9140, 0.9725]], - [[0.2041, 0.4248, 0.2371, 0.7256], - [0.2168, 0.5380, 0.4538, 0.7007], - [0.9013, 0.2623, 0.0739, 0.2998], - [0.1366, 0.5590, 0.2952, 0.4592]]]]]), - - "targets": torch.tensor([[[[0, 1, 0, 1], - [1, 2, 1, 0], - [2, 1, 1, 1], - [1, 1, 0, 1]], - [[2, 1, 0, 2], - [1, 2, 0, 2], - [1, 0, 1, 1], - [1, 1, 0, 0]], - [[1, 0, 2, 1], - [0, 2, 2, 1], - [1, 0, 1, 1], - [0, 0, 2, 1]], - [[2, 1, 1, 0], - [1, 0, 0, 2], - [1, 0, 2, 1], - [2, 1, 0, 1]]]]), - }, - 1.1504, # the result equals to -1 + np.log(1 + np.exp(1)) + "inputs": torch.tensor( + [ + [ + [ + [0.0411, 0.3386, 0.8352, 0.2741], + [0.4821, 0.0519, 0.2561, 0.9391], + [0.5954, 0.4184, 0.9160, 0.7977], + [0.0588, 0.9156, 0.1307, 0.9914], + ], + [ + [0.8481, 0.0892, 0.2044, 0.8470], + [0.3558, 0.3569, 0.0979, 0.4491], + [0.0876, 0.0929, 0.4040, 0.8384], + [0.5313, 0.3927, 0.4165, 0.1107], + ], + [ + [0.7993, 0.6938, 0.3151, 0.8728], + [0.7332, 0.4111, 0.3862, 0.9988], + [0.2622, 0.5002, 0.1905, 0.1644], + [0.6354, 0.0047, 0.1649, 0.7112], + ], + ] + ] + ), + "targets": torch.tensor([[[1, 2, 0, 1], [0, 2, 1, 2], [0, 0, 2, 1], [1, 1, 1, 2]]]), + }, + 1.0375, + ], + [ + {"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, + { + "inputs": torch.tensor( + [ + [ + [ + [ + [0.5977, 0.2767, 0.0591, 0.1675], + [0.4835, 0.3778, 0.8406, 0.3065], + [0.6047, 0.2860, 0.9742, 0.2013], + [0.9128, 0.8368, 0.6711, 0.4384], + ], + [ + [0.9797, 0.1863, 0.5584, 0.6652], + [0.2272, 0.2004, 0.7914, 0.4224], + [0.5097, 0.8818, 0.2581, 0.3495], + [0.1054, 0.5483, 0.3732, 0.3587], + ], + [ + [0.3060, 0.7066, 0.7922, 0.4689], + [0.1733, 0.8902, 0.6704, 0.2037], + [0.8656, 0.5561, 0.2701, 0.0092], + [0.1866, 0.7714, 0.6424, 0.9791], + ], + [ + [0.5067, 0.3829, 0.6156, 0.8985], + [0.5192, 0.8347, 0.2098, 0.2260], + [0.8887, 0.3944, 0.6400, 0.5345], + [0.1207, 0.3763, 0.5282, 0.7741], + ], + ], + [ + [ + [0.8499, 0.4759, 0.1964, 0.5701], + [0.3190, 0.1238, 0.2368, 0.9517], + [0.0797, 0.6185, 0.0135, 0.8672], + [0.4116, 0.1683, 0.1355, 0.0545], + ], + [ + [0.7533, 0.2658, 0.5955, 0.4498], + [0.9500, 0.2317, 0.2825, 0.9763], + [0.1493, 0.1558, 0.3743, 0.8723], + [0.1723, 0.7980, 0.8816, 0.0133], + ], + [ + [0.8426, 0.2666, 0.2077, 0.3161], + [0.1725, 0.8414, 0.1515, 0.2825], + [0.4882, 0.5159, 0.4120, 0.1585], + [0.2551, 0.9073, 0.7691, 0.9898], + ], + [ + [0.4633, 0.8717, 0.8537, 0.2899], + [0.3693, 0.7953, 0.1183, 0.4596], + [0.0087, 0.7925, 0.0989, 0.8385], + [0.8261, 0.6920, 0.7069, 0.4464], + ], + ], + [ + [ + [0.0110, 0.1608, 0.4814, 0.6317], + [0.0194, 0.9669, 0.3259, 0.0028], + [0.5674, 0.8286, 0.0306, 0.5309], + [0.3973, 0.8183, 0.0238, 0.1934], + ], + [ + [0.8947, 0.6629, 0.9439, 0.8905], + [0.0072, 0.1697, 0.4634, 0.0201], + [0.7184, 0.2424, 0.0820, 0.7504], + [0.3937, 0.1424, 0.4463, 0.5779], + ], + [ + [0.4123, 0.6227, 0.0523, 0.8826], + [0.0051, 0.0353, 0.3662, 0.7697], + [0.4867, 0.8986, 0.2510, 0.5316], + [0.1856, 0.2634, 0.9140, 0.9725], + ], + [ + [0.2041, 0.4248, 0.2371, 0.7256], + [0.2168, 0.5380, 0.4538, 0.7007], + [0.9013, 0.2623, 0.0739, 0.2998], + [0.1366, 0.5590, 0.2952, 0.4592], + ], + ], + ] + ] + ), + "targets": torch.tensor( + [ + [ + [[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]], + [[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]], + [[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]], + [[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]], + ] + ] + ), + }, + 1.1504, ], ] class TestNACLLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): loss = NACLLoss(**input_param) From 44e80653b6ddf7fb5954d1013d905cbbd460b9eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Aug 2024 20:17:04 +0000 Subject: [PATCH 19/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/segcalib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 38dc97f7d6..44f0bab541 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -12,7 +12,6 @@ from __future__ import annotations import math -import warnings import torch import torch.nn as nn From 5cd9a33ceaede6eb65d98023511b89aa0f9a8b0c Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 14:18:06 -0700 Subject: [PATCH 20/50] Update segcalib.py Fixed the initialization issues with return types --- monai/losses/segcalib.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 44f0bab541..2e7eee1680 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -12,6 +12,7 @@ from __future__ import annotations import math +import warnings import torch import torch.nn as nn @@ -62,6 +63,9 @@ def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: class GaussianFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: super(GaussianFilter, self).__init__() + + self.svls_kernel: torch.Tensor + self.svls_layer: torch.Tensor if dim == 2: gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) @@ -114,6 +118,9 @@ def forward(self, x): class MeanFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor: super(MeanFilter, self).__init__() + + self.svls_kernel: torch.Tensor + self.svls_layer: torch.Tensor if dim == 2: self.svls_kernel = get_mean_kernel_2d(ksize=ksize) @@ -135,8 +142,8 @@ def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Ten if dim == 3: self.svls_kernel = get_mean_kernel_3d(ksize=ksize) - svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize) - svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1) + svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize, ksize) + svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1, 1) padding = int(ksize / 2) self.svls_layer = torch.nn.Conv3d( @@ -203,6 +210,7 @@ def __init__( self.distance_type = distance_type self.alpha = alpha self.ks = kernel_size + self.svls_layer: torch.Tensor if kernel_ops == "mean": self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes) From b547c4e990887cd34edc122197b21621517a9f1a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Aug 2024 21:18:28 +0000 Subject: [PATCH 21/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/segcalib.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 2e7eee1680..973189a6fb 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -12,7 +12,6 @@ from __future__ import annotations import math -import warnings import torch import torch.nn as nn @@ -63,9 +62,9 @@ def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: class GaussianFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: super(GaussianFilter, self).__init__() - + self.svls_kernel: torch.Tensor - self.svls_layer: torch.Tensor + self.svls_layer: torch.Tensor if dim == 2: gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) @@ -118,9 +117,9 @@ def forward(self, x): class MeanFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor: super(MeanFilter, self).__init__() - + self.svls_kernel: torch.Tensor - self.svls_layer: torch.Tensor + self.svls_layer: torch.Tensor if dim == 2: self.svls_kernel = get_mean_kernel_2d(ksize=ksize) From 42a021563605339f9d48af5e9b4633a2f72e61ee Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 14:27:04 -0700 Subject: [PATCH 22/50] Update segcalib.py Use Any instead of tensor --- monai/losses/segcalib.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 973189a6fb..d0857c4bbb 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -63,8 +63,8 @@ class GaussianFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: super(GaussianFilter, self).__init__() - self.svls_kernel: torch.Tensor - self.svls_layer: torch.Tensor + self.svls_kernel: Any + self.svls_layer: Any if dim == 2: gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) @@ -118,8 +118,8 @@ class MeanFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor: super(MeanFilter, self).__init__() - self.svls_kernel: torch.Tensor - self.svls_layer: torch.Tensor + self.svls_kernel: Any + self.svls_layer: Any if dim == 2: self.svls_kernel = get_mean_kernel_2d(ksize=ksize) @@ -209,7 +209,7 @@ def __init__( self.distance_type = distance_type self.alpha = alpha self.ks = kernel_size - self.svls_layer: torch.Tensor + self.svls_layer: Any if kernel_ops == "mean": self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes) From 7e36ca18b99b666de45318fc2c3782797141443f Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 14:45:42 -0700 Subject: [PATCH 23/50] Update segcalib.py Import typing --- monai/losses/segcalib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index d0857c4bbb..da4b731a4d 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -13,6 +13,7 @@ import math +from typing import Any import torch import torch.nn as nn import torch.nn.functional as F From 6dbd53dd63148eb69b45a4c935bf9f7e3f82407a Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 14:55:12 -0700 Subject: [PATCH 24/50] Update segcalib.py --- monai/losses/segcalib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index da4b731a4d..67d1a6e327 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -1,3 +1,4 @@ + # Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +13,8 @@ from __future__ import annotations import math - from typing import Any + import torch import torch.nn as nn import torch.nn.functional as F From 354056ce9b5308609ea6f24153938f91743e0102 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 17:51:41 -0700 Subject: [PATCH 25/50] Update segcalib.py --- monai/losses/segcalib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 67d1a6e327..b7d9848ad9 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -1,4 +1,3 @@ - # Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 7eb911f88db62b15cca5da5dc1cdc119dce2f334 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 18:03:46 -0700 Subject: [PATCH 26/50] Update segcalib.py Fix return type formatting --- monai/losses/segcalib.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index b7d9848ad9..5f65954923 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -61,7 +61,7 @@ def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: class GaussianFilter(torch.nn.Module): - def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> torch.Tensor: + def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> None: super(GaussianFilter, self).__init__() self.svls_kernel: Any @@ -111,12 +111,12 @@ def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: i self.svls_layer.weight.data = svls_kernel_3d self.svls_layer.weight.requires_grad = False - def forward(self, x): + def forward(self, x) -> torch.Tensor: return self.svls_layer(x) / self.svls_kernel.sum() class MeanFilter(torch.nn.Module): - def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor: + def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: super(MeanFilter, self).__init__() self.svls_kernel: Any @@ -158,7 +158,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Ten self.svls_layer.weight.data = svls_kernel_3d self.svls_layer.weight.requires_grad = False - def forward(self, x): + def forward(self, x) -> torch.Tensor : return self.svls_layer(x) / self.svls_kernel.sum() @@ -183,7 +183,7 @@ def __init__( distance_type: str = "l1", alpha: float = 0.1, sigma: float = 1.0, - ) -> torch.Tensor: + ) -> None: """ Args: classes: number of classes From 0b1209bdbe62a6d7bccef56eabc6a23871bba151 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 18:12:35 -0700 Subject: [PATCH 27/50] Update segcalib.py --- monai/losses/segcalib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index 5f65954923..f4d3252473 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -158,7 +158,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: self.svls_layer.weight.data = svls_kernel_3d self.svls_layer.weight.requires_grad = False - def forward(self, x) -> torch.Tensor : + def forward(self, x) -> torch.Tensor: return self.svls_layer(x) / self.svls_kernel.sum() From 035c92e1097fe453d1ec4b8b759f70a145f500f0 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Sat, 3 Aug 2024 21:05:07 -0700 Subject: [PATCH 28/50] Update segcalib.py Fix parameter datatype --- monai/losses/segcalib.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py index f4d3252473..22cbd886a9 100644 --- a/monai/losses/segcalib.py +++ b/monai/losses/segcalib.py @@ -64,7 +64,7 @@ class GaussianFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> None: super(GaussianFilter, self).__init__() - self.svls_kernel: Any + self.svls_kernel: torch.Tensor self.svls_layer: Any if dim == 2: @@ -111,7 +111,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: i self.svls_layer.weight.data = svls_kernel_3d self.svls_layer.weight.requires_grad = False - def forward(self, x) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.svls_layer(x) / self.svls_kernel.sum() @@ -119,7 +119,7 @@ class MeanFilter(torch.nn.Module): def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: super(MeanFilter, self).__init__() - self.svls_kernel: Any + self.svls_kernel: torch.Tensor self.svls_layer: Any if dim == 2: @@ -158,7 +158,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: self.svls_layer.weight.data = svls_kernel_3d self.svls_layer.weight.requires_grad = False - def forward(self, x) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.svls_layer(x) / self.svls_kernel.sum() From c1de5f11d4475040ecdc0abd9e6ee4c08609800e Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 10:05:38 -0700 Subject: [PATCH 29/50] Rename segcalib.py to nacl_loss.py --- monai/losses/{segcalib.py => nacl_loss.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename monai/losses/{segcalib.py => nacl_loss.py} (100%) diff --git a/monai/losses/segcalib.py b/monai/losses/nacl_loss.py similarity index 100% rename from monai/losses/segcalib.py rename to monai/losses/nacl_loss.py From 91dd1b9c6fa6c829570ef488ce77ca1be9e0cc94 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 10:06:37 -0700 Subject: [PATCH 30/50] Update __init__.py Renamed the loss function --- monai/losses/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 58963360f2..41935be204 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -37,8 +37,8 @@ from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss +from .nacl_loss import NACLLoss from .perceptual import PerceptualLoss -from .segcalib import NACLLoss from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss From 9702c025706289ca7ef62777e802ea58fae2254c Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 10:19:03 -0700 Subject: [PATCH 31/50] Update test_nacl_loss.py Removed redundant test cases. --- tests/test_nacl_loss.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 1a4772dcb8..e41142124b 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -82,37 +82,6 @@ }, 1.1850, ], - [ - {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, - { - "inputs": torch.tensor( - [ - [ - [ - [0.0411, 0.3386, 0.8352, 0.2741], - [0.4821, 0.0519, 0.2561, 0.9391], - [0.5954, 0.4184, 0.9160, 0.7977], - [0.0588, 0.9156, 0.1307, 0.9914], - ], - [ - [0.8481, 0.0892, 0.2044, 0.8470], - [0.3558, 0.3569, 0.0979, 0.4491], - [0.0876, 0.0929, 0.4040, 0.8384], - [0.5313, 0.3927, 0.4165, 0.1107], - ], - [ - [0.7993, 0.6938, 0.3151, 0.8728], - [0.7332, 0.4111, 0.3862, 0.9988], - [0.2622, 0.5002, 0.1905, 0.1644], - [0.6354, 0.0047, 0.1649, 0.7112], - ], - ] - ] - ), - "targets": torch.tensor([[[1, 2, 0, 1], [0, 2, 1, 2], [0, 0, 2, 1], [1, 1, 1, 2]]]), - }, - 1.0375, - ], [ {"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, { From 4462379bd04170f4bcf1f581105e59caaaf7adbc Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 14:10:48 -0700 Subject: [PATCH 32/50] Update nacl_loss.py DCO Remediation Commit for Balamurali I, Balamurali , hereby add my Signed-off-by to this commit: b2ec62b62f7c94d495639fbb126a9dc2312040ae I, Balamurali , hereby add my Signed-off-by to this commit: 93ee114a739bcc929c5a1c75afb3545ea2dc5410 I, Balamurali , hereby add my Signed-off-by to this commit: 42e732bbf88302d8efbfc78063abf7e0d09eda67 I, Balamurali , hereby add my Signed-off-by to this commit: 187053d154374ec9649ab10a990c126160c6e8c3 I, Balamurali , hereby add my Signed-off-by to this commit: 1d27ec5c5de88815d7b1b95d6eba0acd4379aa47 I, Balamurali , hereby add my Signed-off-by to this commit: d4991340657564bdee7892d8909aa14a060bd67b I, Balamurali , hereby add my Signed-off-by to this commit: 1e3f7559e20d2789828ce7415a3a14fd905ddb85 I, Balamurali , hereby add my Signed-off-by to this commit: 9dedfba17a83886a472e96f723bfa2ea4d8739e8 I, Balamurali , hereby add my Signed-off-by to this commit: 5317706d70212aa50e5b3ef50385c701668122de I, Balamurali , hereby add my Signed-off-by to this commit: 5cd9a33ceaede6eb65d98023511b89aa0f9a8b0c I, Balamurali , hereby add my Signed-off-by to this commit: 42a021563605339f9d48af5e9b4633a2f72e61ee I, Balamurali , hereby add my Signed-off-by to this commit: 7e36ca18b99b666de45318fc2c3782797141443f I, Balamurali , hereby add my Signed-off-by to this commit: 6dbd53dd63148eb69b45a4c935bf9f7e3f82407a I, Balamurali , hereby add my Signed-off-by to this commit: 354056ce9b5308609ea6f24153938f91743e0102 I, Balamurali , hereby add my Signed-off-by to this commit: 7eb911f88db62b15cca5da5dc1cdc119dce2f334 I, Balamurali , hereby add my Signed-off-by to this commit: 0b1209bdbe62a6d7bccef56eabc6a23871bba151 I, Balamurali , hereby add my Signed-off-by to this commit: 035c92e1097fe453d1ec4b8b759f70a145f500f0 I, Balamurali , hereby add my Signed-off-by to this commit: c1de5f11d4475040ecdc0abd9e6ee4c08609800e I, Balamurali , hereby add my Signed-off-by to this commit: 91dd1b9c6fa6c829570ef488ce77ca1be9e0cc94 I, Balamurali , hereby add my Signed-off-by to this commit: 9702c025706289ca7ef62777e802ea58fae2254c Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 22cbd886a9..f96e85464c 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -41,7 +41,8 @@ def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) ) - return gaussian_kernel / torch.sum(gaussian_kernel) + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + return gaussian_kernel def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: @@ -57,7 +58,8 @@ def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( -torch.sum((xyz_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) ) - return gaussian_kernel / torch.sum(gaussian_kernel) + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + return gaussian_kernel class GaussianFilter(torch.nn.Module): @@ -112,7 +114,8 @@ def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: i self.svls_layer.weight.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.svls_layer(x) / self.svls_kernel.sum() + svls_normalized = self.svls_layer(x) / self.svls_kernel.sum() + return svls_normalized class MeanFilter(torch.nn.Module): @@ -159,7 +162,8 @@ def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: self.svls_layer.weight.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.svls_layer(x) / self.svls_kernel.sum() + svls_normalized = self.svls_layer(x) / self.svls_kernel.sum() + return svls_normalized class NACLLoss(_Loss): From c4f82839ee14e29b28bb7e3b2b8847d3c9bf17d6 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 14:16:25 -0700 Subject: [PATCH 33/50] Update test_nacl_loss.py Introduce temporary value for same cases to reduce space. DCO Remediation Commit for bala93 I, bala93 , hereby add my Signed-off-by to this commit: 8fbec82b295bfaf87920770be830987177320642 I, bala93 , hereby add my Signed-off-by to this commit: 7c121a07857c4d59703a7fd37b63ac52dac67b04 I, bala93 , hereby add my Signed-off-by to this commit: dccde477906742ff237861696168ff2377ffdb20 Signed-off-by: bala93 --- tests/test_nacl_loss.py | 44 +++++++++++------------------------------ 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index e41142124b..e3e910c5b6 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -19,11 +19,7 @@ from monai.losses import NACLLoss -TEST_CASES = [ - [ - {"classes": 3, "dim": 2}, - { - "inputs": torch.tensor( +inputs = torch.tensor( [ [ [ @@ -46,39 +42,23 @@ ], ] ] - ), - "targets": torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), + ) +targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), + +TEST_CASES = [ + [ + {"classes": 3, "dim": 2}, + { + "inputs": inputs, + "targets": targets, }, 1.1820, ], [ {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, { - "inputs": torch.tensor( - [ - [ - [ - [0.1498, 0.1158, 0.3996, 0.3730], - [0.2155, 0.1585, 0.8541, 0.8579], - [0.6640, 0.2424, 0.0774, 0.0324], - [0.0580, 0.2180, 0.3447, 0.8722], - ], - [ - [0.3908, 0.9366, 0.1779, 0.1003], - [0.9630, 0.6118, 0.4405, 0.7916], - [0.5782, 0.9515, 0.4088, 0.3946], - [0.7860, 0.3910, 0.0324, 0.9568], - ], - [ - [0.0759, 0.0238, 0.5570, 0.1691], - [0.2703, 0.7722, 0.1611, 0.6431], - [0.8051, 0.6596, 0.4121, 0.1125], - [0.5283, 0.6746, 0.5528, 0.7913], - ], - ] - ] - ), - "targets": torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), + "inputs": inputs, + "targets": targets, }, 1.1850, ], From bc6b995c82004bff4585e22d1ec44ae9602518d4 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 14:25:25 -0700 Subject: [PATCH 34/50] Update test_nacl_loss.py DCO Remediation Commit for bala93 I, Balamurali , hereby add my Signed-off-by to this commit: c4f82839ee14e29b28bb7e3b2b8847d3c9bf17d6 I, bala93 , hereby add my Signed-off-by to this commit: 8fbec82b295bfaf87920770be830987177320642 I, bala93 , hereby add my Signed-off-by to this commit: 7c121a07857c4d59703a7fd37b63ac52dac67b04 I, bala93 , hereby add my Signed-off-by to this commit: dccde477906742ff237861696168ff2377ffdb20 Signed-off-by: bala93 --- tests/test_nacl_loss.py | 52 +++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index e3e910c5b6..f0853269f9 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -20,30 +20,30 @@ from monai.losses import NACLLoss inputs = torch.tensor( - [ - [ - [ - [0.1498, 0.1158, 0.3996, 0.3730], - [0.2155, 0.1585, 0.8541, 0.8579], - [0.6640, 0.2424, 0.0774, 0.0324], - [0.0580, 0.2180, 0.3447, 0.8722], - ], - [ - [0.3908, 0.9366, 0.1779, 0.1003], - [0.9630, 0.6118, 0.4405, 0.7916], - [0.5782, 0.9515, 0.4088, 0.3946], - [0.7860, 0.3910, 0.0324, 0.9568], - ], - [ - [0.0759, 0.0238, 0.5570, 0.1691], - [0.2703, 0.7722, 0.1611, 0.6431], - [0.8051, 0.6596, 0.4121, 0.1125], - [0.5283, 0.6746, 0.5528, 0.7913], - ], - ] - ] - ) -targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), + [ + [ + [ + [0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722], + ], + [ + [0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568], + ], + [ + [0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913], + ], + ] + ] +) +targets = (torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]),) TEST_CASES = [ [ @@ -170,7 +170,9 @@ class TestNACLLoss(unittest.TestCase): def test_result(self, input_param, input_data, expected_val): loss = NACLLoss(**input_param) result = loss(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose( + result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4 + ) if __name__ == "__main__": From 51e15fe6cf863f6dd734cfd70320c653762e29dd Mon Sep 17 00:00:00 2001 From: bala93 Date: Mon, 5 Aug 2024 17:49:39 -0400 Subject: [PATCH 35/50] Added missing parameters in doc DCO Remediation Commit for bala93 I, bala93 , hereby add my Signed-off-by to this commit: 8fbec82b295bfaf87920770be830987177320642 I, bala93 , hereby add my Signed-off-by to this commit: 7c121a07857c4d59703a7fd37b63ac52dac67b04 I, bala93 , hereby add my Signed-off-by to this commit: dccde477906742ff237861696168ff2377ffdb20 Signed-off-by: bala93 --- monai/losses/nacl_loss.py | 3 ++- tests/test_nacl_loss.py | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index f96e85464c..b5ffcd66cd 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -174,7 +174,7 @@ class NACLLoss(_Loss): Murugesan, Balamurali, et al. "Trust your neighbours: Penalty-based constraints for model calibration." - International Conference on Medical Image Computing and Computer-Assisted Intervention, 2023. + International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023. https://arxiv.org/abs/2303.06268 """ @@ -191,6 +191,7 @@ def __init__( """ Args: classes: number of classes + dim: dimension of data (supports 2d and 3d) kernel_size: size of the spatial kernel distance_type: l1/l2 distance between spatial kernel and predicted logits alpha: weightage between cross entropy and logit constraint diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index f0853269f9..608da95301 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -170,9 +170,7 @@ class TestNACLLoss(unittest.TestCase): def test_result(self, input_param, input_data, expected_val): loss = NACLLoss(**input_param) result = loss(**input_data) - np.testing.assert_allclose( - result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4 - ) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) if __name__ == "__main__": From 3a00aec61761109e9efe7a3d459db05177e1c251 Mon Sep 17 00:00:00 2001 From: bala93 Date: Mon, 5 Aug 2024 18:07:10 -0400 Subject: [PATCH 36/50] Formatting check with monai DCO Remediation Commit for Balamurali I, Balamurali , hereby add my Signed-off-by to this commit: bc6b995c82004bff4585e22d1ec44ae9602518d4 Signed-off-by: Balamurali --- tests/test_nacl_loss.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 608da95301..a587509051 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -46,22 +46,8 @@ targets = (torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]),) TEST_CASES = [ - [ - {"classes": 3, "dim": 2}, - { - "inputs": inputs, - "targets": targets, - }, - 1.1820, - ], - [ - {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, - { - "inputs": inputs, - "targets": targets, - }, - 1.1850, - ], + [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1820], + [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1850], [ {"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, { From 818b42b54f569d5f7b0d70dfd21172f91b33f6fa Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 15:50:54 -0700 Subject: [PATCH 37/50] Update test_nacl_loss.py DCO Remediation Commit for Balamurali I, Balamurali , hereby add my Signed-off-by to this commit: bc6b995c82004bff4585e22d1ec44ae9602518d4 Signed-off-by: Balamurali --- tests/test_nacl_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index a587509051..d3a8d7ee5f 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -43,7 +43,7 @@ ] ] ) -targets = (torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]),) +targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]) TEST_CASES = [ [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1820], From 664770809a44b4297d3b3da53dcdeef43c0462f2 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Mon, 5 Aug 2024 19:08:48 -0400 Subject: [PATCH 38/50] Added mypy fixes Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index b5ffcd66cd..2cb2f106df 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -38,7 +38,7 @@ def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() mean = (ksize - 1) / 2.0 variance = sigma**2.0 - gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( + gaussian_kernel: torch.Tensor = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) ) gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) @@ -55,7 +55,7 @@ def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float() mean = (ksize - 1) / 2.0 variance = sigma**2.0 - gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( + gaussian_kernel: torch.Tensor = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( -torch.sum((xyz_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) ) gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) @@ -114,7 +114,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: i self.svls_layer.weight.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: - svls_normalized = self.svls_layer(x) / self.svls_kernel.sum() + svls_normalized: torch.Tensor = self.svls_layer(x) / self.svls_kernel.sum() return svls_normalized @@ -162,7 +162,7 @@ def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: self.svls_layer.weight.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: - svls_normalized = self.svls_layer(x) / self.svls_kernel.sum() + svls_normalized: torch.Tensor = self.svls_layer(x) / self.svls_kernel.sum() return svls_normalized @@ -247,6 +247,9 @@ def __init__( # return self.cross_entropy(input, target) # type: ignore[no-any-return] def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: + + rmask: torch.Tensor + if self.dim == 2: oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() rmask = self.svls_layer(oh_labels) @@ -267,6 +270,6 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: elif self.distance_type == "l2": loss_conf = utargets.sub(inputs).pow_(2).abs_().mean() - loss = loss_ce + self.alpha * loss_conf + loss: torch.Tensor = loss_ce + self.alpha * loss_conf return loss From 7e579dd5ffd7ea0741c736abf5b2559a119a0b17 Mon Sep 17 00:00:00 2001 From: bala93 Date: Mon, 5 Aug 2024 19:44:04 -0400 Subject: [PATCH 39/50] DCO Remediation Commit for bala93 I, bala93 , hereby add my Signed-off-by to this commit: 3a00aec61761109e9efe7a3d459db05177e1c251 Signed-off-by: bala93 --- monai/losses/nacl_loss.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 2cb2f106df..27761bb6d0 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -176,6 +176,10 @@ class NACLLoss(_Loss): "Trust your neighbours: Penalty-based constraints for model calibration." International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023. https://arxiv.org/abs/2303.06268 + + Murugesan, Balamurali, et al. + "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints." + https://arxiv.org/abs/2401.14487 """ def __init__( From 4f8abf1efd32527a8c1faf9f58e4770305f38d09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:45:16 +0000 Subject: [PATCH 40/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/nacl_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 27761bb6d0..6875517e83 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -176,7 +176,7 @@ class NACLLoss(_Loss): "Trust your neighbours: Penalty-based constraints for model calibration." International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023. https://arxiv.org/abs/2303.06268 - + Murugesan, Balamurali, et al. "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints." https://arxiv.org/abs/2401.14487 From b72e47872c9931ada112d18849f80eb516c1d992 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Tue, 6 Aug 2024 06:16:39 -0700 Subject: [PATCH 41/50] Update docs/source/losses.rst Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Balamurali --- docs/source/losses.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index c008f77be8..528ccd1173 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -94,7 +94,7 @@ Segmentation Losses :members: `NACLLoss` -~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~ .. autoclass:: NACLLoss :members: From 747681dd84bb48b1b817d748200710020b283555 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 01:24:57 -0400 Subject: [PATCH 42/50] * Include test cases covering more cases * Modify the code with MeanFilter, and GaussianFilter of MONAI layers * Add Doc string explaining the mask preparation Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 179 ++------------------------------------ tests/test_nacl_loss.py | 11 ++- 2 files changed, 14 insertions(+), 176 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 6875517e83..2e0014af97 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -19,151 +19,7 @@ import torch.nn.functional as F from torch.nn.modules.loss import _Loss -from monai.utils import pytorch_after - - -def get_mean_kernel_2d(ksize: int = 3) -> torch.Tensor: - mean_kernel = torch.ones([ksize, ksize]) / (ksize**2) - return mean_kernel - - -def get_mean_kernel_3d(ksize: int = 3) -> torch.Tensor: - mean_kernel = torch.ones([ksize, ksize, ksize]) / (ksize**3) - return mean_kernel - - -def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: - x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize) - y_grid = x_grid.t() - xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() - mean = (ksize - 1) / 2.0 - variance = sigma**2.0 - gaussian_kernel: torch.Tensor = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( - -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) - ) - gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) - return gaussian_kernel - - -def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: - x_coord = torch.arange(ksize) - x_grid_2d = x_coord.repeat(ksize).view(ksize, ksize) - x_grid = x_coord.repeat(ksize * ksize).view(ksize, ksize, ksize) - y_grid_2d = x_grid_2d.t() - y_grid = y_grid_2d.repeat(ksize, 1).view(ksize, ksize, ksize) - z_grid = y_grid_2d.repeat(1, ksize).view(ksize, ksize, ksize) - xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float() - mean = (ksize - 1) / 2.0 - variance = sigma**2.0 - gaussian_kernel: torch.Tensor = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( - -torch.sum((xyz_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) - ) - gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) - return gaussian_kernel - - -class GaussianFilter(torch.nn.Module): - def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> None: - super(GaussianFilter, self).__init__() - - self.svls_kernel: torch.Tensor - self.svls_layer: Any - - if dim == 2: - gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) - neighbors_sum = (1 - gkernel[1, 1]) + 1e-16 - gkernel[int(ksize / 2), int(ksize / 2)] = neighbors_sum - self.svls_kernel = gkernel / neighbors_sum - - svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) - svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) - padding = int(ksize / 2) - - self.svls_layer = torch.nn.Conv2d( - in_channels=channels, - out_channels=channels, - kernel_size=ksize, - groups=channels, - bias=False, - padding=padding, - padding_mode="replicate", - ) - self.svls_layer.weight.data = svls_kernel_2d - self.svls_layer.weight.requires_grad = False - - if dim == 3: - gkernel = get_gaussian_kernel_3d(ksize=ksize, sigma=sigma) - neighbors_sum = 1 - gkernel[1, 1, 1] - gkernel[1, 1, 1] = neighbors_sum - self.svls_kernel = gkernel / neighbors_sum - - svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize, ksize) - svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1, 1) - padding = int(ksize / 2) - - self.svls_layer = torch.nn.Conv3d( - in_channels=channels, - out_channels=channels, - kernel_size=ksize, - groups=channels, - bias=False, - padding=padding, - padding_mode="replicate", - ) - self.svls_layer.weight.data = svls_kernel_3d - self.svls_layer.weight.requires_grad = False - - def forward(self, x: torch.Tensor) -> torch.Tensor: - svls_normalized: torch.Tensor = self.svls_layer(x) / self.svls_kernel.sum() - return svls_normalized - - -class MeanFilter(torch.nn.Module): - def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None: - super(MeanFilter, self).__init__() - - self.svls_kernel: torch.Tensor - self.svls_layer: Any - - if dim == 2: - self.svls_kernel = get_mean_kernel_2d(ksize=ksize) - svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) - svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) - padding = int(ksize / 2) - - self.svls_layer = torch.nn.Conv2d( - in_channels=channels, - out_channels=channels, - kernel_size=ksize, - groups=channels, - bias=False, - padding=padding, - padding_mode="replicate", - ) - self.svls_layer.weight.data = svls_kernel_2d - self.svls_layer.weight.requires_grad = False - - if dim == 3: - self.svls_kernel = get_mean_kernel_3d(ksize=ksize) - svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize, ksize) - svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1, 1) - padding = int(ksize / 2) - - self.svls_layer = torch.nn.Conv3d( - in_channels=channels, - out_channels=channels, - kernel_size=ksize, - groups=channels, - bias=False, - padding=padding, - padding_mode="replicate", - ) - self.svls_layer.weight.data = svls_kernel_3d - self.svls_layer.weight.requires_grad = False - - def forward(self, x: torch.Tensor) -> torch.Tensor: - svls_normalized: torch.Tensor = self.svls_layer(x) / self.svls_kernel.sum() - return svls_normalized +from monai.networks.layers import GaussianFilter, MeanFilter class NACLLoss(_Loss): @@ -222,36 +78,15 @@ def __init__( self.svls_layer: Any if kernel_ops == "mean": - self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes) + self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size) + self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim) if kernel_ops == "gaussian": - self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes) - - self.old_pt_ver = not pytorch_after(1, 10) - - # def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - # """ - # Compute CrossEntropy loss for the input logits and target. - # Will remove the channel dim according to PyTorch CrossEntropyLoss: - # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. - - # """ - # n_pred_ch, n_target_ch = input.shape[1], target.shape[1] - # if n_pred_ch != n_target_ch and n_target_ch == 1: - # target = torch.squeeze(target, dim=1) - # target = target.long() - # elif self.old_pt_ver: - # warnings.warn( - # f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. " - # "Using argmax (as a workaround) to convert target to a single channel." - # ) - # target = torch.argmax(target, dim=1) - # elif not torch.is_floating_point(target): - # target = target.to(dtype=input.dtype) - - # return self.cross_entropy(input, target) # type: ignore[no-any-return] + self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma) def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: - + """ + Converts the mask to one hot represenation and applies the spatial filter. + """ rmask: torch.Tensor if self.dim == 2: diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index d3a8d7ee5f..51ec275cf4 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -43,11 +43,14 @@ ] ] ) -targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]) +targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]]) TEST_CASES = [ - [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1820], - [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1850], + [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442], + [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433], + [{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469], + [{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269], + [{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790], [ {"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, { @@ -146,7 +149,7 @@ ] ), }, - 1.1504, + 1.15035, ], ] From 3b155547f0b839ce80a07ebfcc173ae0eb83760d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 05:28:24 +0000 Subject: [PATCH 43/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/nacl_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 2e0014af97..bb8fece6c1 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -11,7 +11,6 @@ from __future__ import annotations -import math from typing import Any import torch From 877139c0e07842413c6744cf7dae514a819a26e4 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 06:27:37 -0700 Subject: [PATCH 44/50] Update monai/losses/nacl_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index bb8fece6c1..9e49f432cd 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -63,7 +63,7 @@ def __init__( raise ValueError("Kernel ops must be either mean or gaussian") if dim not in [2, 3]: - raise ValueError("Supoorts 2d and 3d") + raise ValueError(f"Support 2d and 3d, got dim={dim}.") if distance_type not in ["l1", "l2"]: raise ValueError("Distance type must be either L1 or L2") From 467945668a8f58b78310810a7feb88cc31b8ae93 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 06:28:01 -0700 Subject: [PATCH 45/50] Update monai/losses/nacl_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 9e49f432cd..3b540d713d 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -66,7 +66,7 @@ def __init__( raise ValueError(f"Support 2d and 3d, got dim={dim}.") if distance_type not in ["l1", "l2"]: - raise ValueError("Distance type must be either L1 or L2") + raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}") self.nc = classes self.dim = dim From 7c5217ec623c624638833f97a499a698be888856 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 10:32:22 -0700 Subject: [PATCH 46/50] * Add docstring with better explanations Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 3b540d713d..9cfd9f303a 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -85,6 +85,12 @@ def __init__( def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: """ Converts the mask to one hot represenation and applies the spatial filter. + + Args: + mask: the shape should be BHW[D] + + Returns: + torch.Tensor: the shape would be BNHW[D], N being number of classes. """ rmask: torch.Tensor @@ -99,6 +105,23 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: return rmask def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Computes standard cross-entropy loss and constraints it neighbor aware logit penalty. + + Args: + inputs: the shape should be BNHW[D], where N is the number of classes. + targets: the shape should be BHW[D]. + + Example: + >>> import torch + >>> from monai.losses import NACLLoss + >>> B, N, H, W = 8, 3, 64, 64 + >>> input = torch.rand(B, N, H, W) + >>> target = torch.randint(0, N, (B, H, W)) + >>> criterion = NACLLoss(classes = N, dim = 2) + >>> loss = self(input, target) + """ + loss_ce = self.cross_entropy(inputs, targets) utargets = self.get_constr_target(targets) From d33f43586d8354a47caafe3f53007d40535a0243 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 10:42:17 -0700 Subject: [PATCH 47/50] * Maintain the dimension consistency. Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 9cfd9f303a..5cbd2f7f44 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -112,6 +112,9 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: inputs: the shape should be BNHW[D], where N is the number of classes. targets: the shape should be BHW[D]. + Returns: + torch.Tensor: value of the loss. + Example: >>> import torch >>> from monai.losses import NACLLoss From 7deb2ccaefbb3ae0cb90a8a1a4bd138fb0134f42 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 10:50:25 -0700 Subject: [PATCH 48/50] Update nacl_loss.py --- monai/losses/nacl_loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 5cbd2f7f44..30f722e6a4 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -87,10 +87,10 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: Converts the mask to one hot represenation and applies the spatial filter. Args: - mask: the shape should be BHW[D] + mask: the shape should be BH[WD]. Returns: - torch.Tensor: the shape would be BNHW[D], N being number of classes. + torch.Tensor: the shape would be BNH[WD], N being number of classes. """ rmask: torch.Tensor @@ -109,8 +109,8 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: Computes standard cross-entropy loss and constraints it neighbor aware logit penalty. Args: - inputs: the shape should be BNHW[D], where N is the number of classes. - targets: the shape should be BHW[D]. + inputs: the shape should be BNH[WD], where N is the number of classes. + targets: the shape should be BH[WD]. Returns: torch.Tensor: value of the loss. @@ -122,7 +122,7 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: >>> input = torch.rand(B, N, H, W) >>> target = torch.randint(0, N, (B, H, W)) >>> criterion = NACLLoss(classes = N, dim = 2) - >>> loss = self(input, target) + >>> loss = criterion(input, target) """ loss_ce = self.cross_entropy(inputs, targets) From 91ce50b75d4931a282a70011a7d6f4450aa05a8d Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 11:29:17 -0700 Subject: [PATCH 49/50] Update nacl_loss.py Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 5cbd2f7f44..30f722e6a4 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -87,10 +87,10 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: Converts the mask to one hot represenation and applies the spatial filter. Args: - mask: the shape should be BHW[D] + mask: the shape should be BH[WD]. Returns: - torch.Tensor: the shape would be BNHW[D], N being number of classes. + torch.Tensor: the shape would be BNH[WD], N being number of classes. """ rmask: torch.Tensor @@ -109,8 +109,8 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: Computes standard cross-entropy loss and constraints it neighbor aware logit penalty. Args: - inputs: the shape should be BNHW[D], where N is the number of classes. - targets: the shape should be BHW[D]. + inputs: the shape should be BNH[WD], where N is the number of classes. + targets: the shape should be BH[WD]. Returns: torch.Tensor: value of the loss. @@ -122,7 +122,7 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: >>> input = torch.rand(B, N, H, W) >>> target = torch.randint(0, N, (B, H, W)) >>> criterion = NACLLoss(classes = N, dim = 2) - >>> loss = self(input, target) + >>> loss = criterion(input, target) """ loss_ce = self.cross_entropy(inputs, targets) From 0e880a848c2fe97e0d74cf57a54dc2af009b1500 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Wed, 7 Aug 2024 11:51:09 -0700 Subject: [PATCH 50/50] Modify docstring DCO Remediation Commit for Balamurali I, Balamurali , hereby add my Signed-off-by to this commit: 7deb2ccaefbb3ae0cb90a8a1a4bd138fb0134f42 Signed-off-by: Balamurali --- monai/losses/nacl_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 30f722e6a4..3303e89bce 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -84,7 +84,7 @@ def __init__( def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: """ - Converts the mask to one hot represenation and applies the spatial filter. + Converts the mask to one hot represenation and is smoothened with the selected spatial filter. Args: mask: the shape should be BH[WD].