This repository was archived by the owner on Feb 7, 2025. It is now read-only.

BatchNorm for PatchDiscriminator running in DDP #451

@sRassmann

Description


Thanks for this amazing work, it helps a lot in accelerating experiments!

I tried training an AE using PatchDiscriminator and ran into the following issue when switching to DistributedDataParallel (DDP):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 4; expected version 3 instead.

Running it with torch.autograd.set_detect_anomaly(True) gives:

UserWarning: Error detected in CudnnBatchNormBackward0

With some troubleshooting, I found that the issue is caused by the BatchNorm layers. So running

discriminator = PatchDiscriminator(**kwargs)
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

solves it.
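
For context, here is roughly how the workaround fits into my DDP training script. This is a minimal sketch: the import path, the local_rank / disc_kwargs names, and the device handling are placeholders from my setup, and process-group initialization is omitted.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from generative.networks.nets import PatchDiscriminator  # import path assumed


def build_discriminator(local_rank, **disc_kwargs):
    # disc_kwargs are placeholder constructor arguments for PatchDiscriminator
    discriminator = PatchDiscriminator(**disc_kwargs)

    # replace every BatchNorm layer with SyncBatchNorm before wrapping in DDP;
    # the documented usage re-assigns the returned module
    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

    discriminator = discriminator.to(local_rank)
    return DDP(discriminator, device_ids=[local_rank])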

Might it be worthwhile to put this into the constructor of the PatchDiscriminator to avoid similar issues in the future? e.g.

class PatchDiscriminator(nn.Sequential):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        [...]
        self.apply(self.initialise_weights)

        torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
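
With the conversion inside the constructor, the explicit convert_sync_batchnorm call in the sketch above would no longer be needed, e.g. (same placeholder names as above):

discriminator = PatchDiscriminator(**kwargs)  # SyncBatchNorm conversion happens internally
discriminator = DDP(discriminator.to(local_rank), device_ids=[local_rank])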
