diff --git a/generative/losses/adversarial_loss.py b/generative/losses/adversarial_loss.py index 9189d89d..c586093e 100644 --- a/generative/losses/adversarial_loss.py +++ b/generative/losses/adversarial_loss.py @@ -21,7 +21,6 @@ class AdversarialCriterions(StrEnum): - BCE = "bce" HINGE = "hinge" LEAST_SQUARE = "least_squares" @@ -42,12 +41,14 @@ class PatchAdversarialLoss(_Loss): criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs. Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs through an activation layer prior to calling the loss. + no_activation_leastsq: if True, the activation layer in the case of least-squares is removed. """ def __init__( self, reduction: LossReduction | str = LossReduction.MEAN, criterion: str = AdversarialCriterions.LEAST_SQUARE.value, + no_activation_leastsq: bool = False, ) -> None: super().__init__(reduction=LossReduction(reduction).value) @@ -67,7 +68,10 @@ def __init__( self.activation = get_act_layer("TANH") self.fake_label = -1.0 elif criterion == AdversarialCriterions.LEAST_SQUARE.value: - self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05})) + if no_activation_leastsq: + self.activation = None + else: + self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05})) self.loss_fct = torch.nn.MSELoss(reduction=reduction) self.criterion = criterion @@ -104,7 +108,6 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: def forward( self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool ) -> torch.Tensor | list[torch.Tensor]: - """ Args: @@ -138,7 +141,8 @@ def forward( # Loss calculation loss = [] for disc_ind, disc_out in enumerate(input): - disc_out = self.activation(disc_out) + if self.activation is not None: + disc_out = self.activation(disc_out) if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real: loss_ = self.forward_single(-disc_out, target_[disc_ind]) else: