From b19939df0faf2a1f7c101bec6ab5d63900be5550 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 15 Feb 2023 11:45:25 +0000 Subject: [PATCH 1/2] Changed PatchAdversarialLoss to allow for least-squares criterion to not have a leaky RELU activation layer. --- generative/losses/adversarial_loss.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/generative/losses/adversarial_loss.py b/generative/losses/adversarial_loss.py index 9189d89d..c99370b5 100644 --- a/generative/losses/adversarial_loss.py +++ b/generative/losses/adversarial_loss.py @@ -42,12 +42,14 @@ class PatchAdversarialLoss(_Loss): criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs. Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs through an activation layer prior to calling the loss. + no_activation_leastsq: if True, the activation layer in the case of least-squares is removed. """ def __init__( self, reduction: LossReduction | str = LossReduction.MEAN, criterion: str = AdversarialCriterions.LEAST_SQUARE.value, + no_activation_leastsq: bool = False ) -> None: super().__init__(reduction=LossReduction(reduction).value) @@ -67,7 +69,10 @@ def __init__( self.activation = get_act_layer("TANH") self.fake_label = -1.0 elif criterion == AdversarialCriterions.LEAST_SQUARE.value: - self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05})) + if no_activation_leastsq: + self.activation = None + else: + self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05})) self.loss_fct = torch.nn.MSELoss(reduction=reduction) self.criterion = criterion @@ -138,7 +143,8 @@ def forward( # Loss calculation loss = [] for disc_ind, disc_out in enumerate(input): - disc_out = self.activation(disc_out) + if self.activation is not None: + disc_out = self.activation(disc_out) if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real: loss_ = self.forward_single(-disc_out, target_[disc_ind]) else: From 5cf63dc304dd0fcb8f2560ad5ea9357e859284cc Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 16 Feb 2023 13:53:25 +0000 Subject: [PATCH 2/2] Reformatting. --- generative/losses/adversarial_loss.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/generative/losses/adversarial_loss.py b/generative/losses/adversarial_loss.py index c99370b5..c586093e 100644 --- a/generative/losses/adversarial_loss.py +++ b/generative/losses/adversarial_loss.py @@ -21,7 +21,6 @@ class AdversarialCriterions(StrEnum): - BCE = "bce" HINGE = "hinge" LEAST_SQUARE = "least_squares" @@ -49,7 +48,7 @@ def __init__( self, reduction: LossReduction | str = LossReduction.MEAN, criterion: str = AdversarialCriterions.LEAST_SQUARE.value, - no_activation_leastsq: bool = False + no_activation_leastsq: bool = False, ) -> None: super().__init__(reduction=LossReduction(reduction).value) @@ -109,7 +108,6 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: def forward( self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool ) -> torch.Tensor | list[torch.Tensor]: - """ Args: