From 42aa0887aa887ced56dc4795f47e19a16c748e32 Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Tue, 24 Oct 2023 11:20:52 +0800
Subject: [PATCH 1/4] fix #7065

Signed-off-by: KumoLiu
---
 monai/losses/dice.py       | 15 ++++++++-------
 monai/losses/focal_loss.py | 14 ++++++++------
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index d74d40fe37..10404df9b8 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -111,8 +111,9 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
-        self.weight = weight
-        self.register_buffer("class_weight", torch.ones(1))
+        if weight is not None:
+            weight = torch.as_tensor(weight)
+        self.register_buffer("class_weight", weight)
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -189,13 +190,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
 
-        if self.weight is not None and target.shape[1] != 1:
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
-            num_of_classes = target.shape[1]
-            if isinstance(self.weight, (float, int)):
-                self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
+            if self.class_weight.ndim == 0:
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.weight)
+                self.class_weight = torch.as_tensor(self.class_weight)
                 if self.class_weight.shape[0] != num_of_classes:
                     raise ValueError(
                         """the length of the `weight` sequence should be the same as the number of classes.
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index fbd0e6efb8..3ea0b627dc 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -113,7 +113,9 @@ def __init__(
         self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
-        self.register_buffer("class_weight", torch.ones(1))
+        if weight is not None:
+            weight = torch.as_tensor(weight)
+        self.register_buffer("class_weight", weight)
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -162,13 +164,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         else:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
-        if self.weight is not None:
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
-            num_of_classes = target.shape[1]
-            if isinstance(self.weight, (float, int)):
-                self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
+            if self.class_weight.ndim == 0:
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.weight)
+                self.class_weight = torch.as_tensor(self.class_weight)
                 if self.class_weight.shape[0] != num_of_classes:
                     raise ValueError(
                         """the length of the `weight` sequence should be the same as the number of classes.
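Patch 1 replaces the placeholder `torch.ones(1)` buffer with the user-supplied `weight`, converted once via `torch.as_tensor` in `__init__`, and switches `forward` to read the `class_weight` buffer instead of re-wrapping the raw `self.weight` on every call. Registering the real tensor as a buffer keeps it in `state_dict` and lets `.to(device)` move it with the module. A minimal sketch of the intended usage after this patch, assuming a MONAI build that includes the fix (the `weight` argument and `class_weight` buffer come from the patched files; shapes and values are illustrative):

```python
import torch
from monai.losses import DiceLoss

# Per-class weights are converted to a tensor once in __init__ and stored as
# the `class_weight` buffer, so they follow the module across devices/dtypes.
loss_fn = DiceLoss(softmax=True, weight=[1.0, 2.0, 0.5])

pred = torch.randn(4, 3, 16, 16)  # logits: batch, class, H, W
target = torch.nn.functional.one_hot(
    torch.randint(0, 3, (4, 16, 16)), num_classes=3
).permute(0, 3, 1, 2).float()     # one-hot target, channel-first

print(loss_fn(pred, target))          # scalar; per-class terms scaled by the weights
print(dict(loss_fn.named_buffers()))  # {'class_weight': tensor([1.0000, 2.0000, 0.5000])}
```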
From 5b4fe2936ec9db1d54f5ab829650164990b767a9 Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Tue, 24 Oct 2023 11:32:16 +0800
Subject: [PATCH 2/4] minor fix

Signed-off-by: KumoLiu
---
 monai/losses/dice.py       | 4 +---
 monai/losses/focal_loss.py | 4 +---
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 10404df9b8..90bdf685c6 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -111,8 +111,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
-        if weight is not None:
-            weight = torch.as_tensor(weight)
+        weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@@ -196,7 +195,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             if self.class_weight.ndim == 0:
                 self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.class_weight)
                 if self.class_weight.shape[0] != num_of_classes:
                     raise ValueError(
                         """the length of the `weight` sequence should be the same as the number of classes.
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 3ea0b627dc..aefe050e92 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -113,8 +113,7 @@ def __init__(
         self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
-        if weight is not None:
-            weight = torch.as_tensor(weight)
+        weight = torch.as_tensor(weight) if weight is not None else None
        self.register_buffer("class_weight", weight)
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@@ -170,7 +169,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             if self.class_weight.ndim == 0:
                 self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.class_weight)
                 if self.class_weight.shape[0] != num_of_classes:
                     raise ValueError(
                         """the length of the `weight` sequence should be the same as the number of classes.
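Besides collapsing the `None` guard into a one-line conditional expression, patch 2 drops the redundant `torch.as_tensor(self.class_weight)` re-conversion: the buffer is already a tensor, so the `else` branch reduces to the sequence-length check. The surviving normalisation rule, extracted into a standalone sketch (the helper name is illustrative, not MONAI API): a 0-d weight is broadcast to one entry per class, while a 1-d weight must already match the class count.

```python
import torch

def normalize_class_weight(weight: torch.Tensor, num_of_classes: int) -> torch.Tensor:
    # Hypothetical helper mirroring the patched forward() logic.
    if weight.ndim == 0:
        # scalar: one copy of the weight per class, exactly as in the patch
        return torch.as_tensor([weight] * num_of_classes)
    if weight.shape[0] != num_of_classes:
        raise ValueError("the length of the `weight` sequence should be the same as the number of classes.")
    return weight

print(normalize_class_weight(torch.as_tensor(2.0), 3))              # tensor([2., 2., 2.])
print(normalize_class_weight(torch.as_tensor([1.0, 0.5, 2.0]), 3))  # returned unchanged
```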
From 8713554726255a1e88e2cf4950fa1f0a5a37b21d Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Tue, 24 Oct 2023 11:34:55 +0800
Subject: [PATCH 3/4] fix mypy

Signed-off-by: KumoLiu
---
 monai/losses/dice.py       | 1 +
 monai/losses/focal_loss.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 90bdf685c6..d567ec0d20 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -190,6 +190,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
 
         num_of_classes = target.shape[1]
+        self.class_weight: None | torch.Tensor
         if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
             if self.class_weight.ndim == 0:
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index aefe050e92..36add8882a 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -164,6 +164,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         num_of_classes = target.shape[1]
+        self.class_weight: None | torch.Tensor
         if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
             if self.class_weight.ndim == 0:

From 843131fd264373d1977bbffbf5c8b3eacc4ce2a6 Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Tue, 24 Oct 2023 11:52:01 +0800
Subject: [PATCH 4/4] fix ci

Signed-off-by: KumoLiu
---
 monai/losses/dice.py       | 2 +-
 monai/losses/focal_loss.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index d567ec0d20..b3c0f57c6e 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -113,6 +113,7 @@ def __init__(
         self.batch = batch
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -190,7 +191,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
 
         num_of_classes = target.shape[1]
-        self.class_weight: None | torch.Tensor
         if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
             if self.class_weight.ndim == 0:
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 36add8882a..98c1a071b6 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -115,6 +115,7 @@ def __init__(
         self.use_softmax = use_softmax
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -164,7 +165,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         num_of_classes = target.shape[1]
-        self.class_weight: None | torch.Tensor
        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
             if self.class_weight.ndim == 0:
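Patches 3 and 4 concern static typing: an attribute created through `register_buffer` is set dynamically, so mypy cannot infer a type for `self.class_weight`, and a bare variable annotation declares one. Patch 3 placed that annotation inside `forward`; patch 4 relocates it to `__init__`, next to the registration. A sketch of the final pattern, using a hypothetical module (`WeightedThing` is illustrative, not MONAI code):

```python
from __future__ import annotations  # keeps `None | torch.Tensor` unevaluated, as in the MONAI files

import torch
from torch import nn

class WeightedThing(nn.Module):  # hypothetical stand-in for the patched losses
    def __init__(self, weight=None):
        super().__init__()
        weight = torch.as_tensor(weight) if weight is not None else None
        self.register_buffer("class_weight", weight)
        # Bare annotation only: no assignment, so the registered buffer is not
        # clobbered; it just tells mypy the type of `self.class_weight`.
        self.class_weight: None | torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.class_weight is not None:
            x = x * self.class_weight
        return x

m = WeightedThing(weight=[1.0, 2.0])
print(m(torch.ones(2)))  # tensor([1., 2.])
```

Declaring the type at the registration site rather than in `forward` keeps the annotation in one place and avoids re-stating it in every method that touches the buffer.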