Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions generative/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


class AdversarialCriterions(StrEnum):

BCE = "bce"
HINGE = "hinge"
LEAST_SQUARE = "least_squares"
Expand All @@ -42,12 +41,14 @@ class PatchAdversarialLoss(_Loss):
criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.
Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
through an activation layer prior to calling the loss.
no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
"""

def __init__(
self,
reduction: LossReduction | str = LossReduction.MEAN,
criterion: str = AdversarialCriterions.LEAST_SQUARE.value,
no_activation_leastsq: bool = False,
) -> None:
super().__init__(reduction=LossReduction(reduction).value)

Expand All @@ -67,7 +68,10 @@ def __init__(
self.activation = get_act_layer("TANH")
self.fake_label = -1.0
elif criterion == AdversarialCriterions.LEAST_SQUARE.value:
self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
if no_activation_leastsq:
self.activation = None
else:
self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
self.loss_fct = torch.nn.MSELoss(reduction=reduction)

self.criterion = criterion
Expand Down Expand Up @@ -104,7 +108,6 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor:
def forward(
self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool
) -> torch.Tensor | list[torch.Tensor]:

"""

Args:
Expand Down Expand Up @@ -138,7 +141,8 @@ def forward(
# Loss calculation
loss = []
for disc_ind, disc_out in enumerate(input):
disc_out = self.activation(disc_out)
if self.activation is not None:
disc_out = self.activation(disc_out)
if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real:
loss_ = self.forward_single(-disc_out, target_[disc_ind])
else:
Expand Down