diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index bf09b743..78120e8e 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from collections.abc import Sequence import torch @@ -218,6 +219,12 @@ def __init__( ) self.apply(self.initialise_weights) + if norm.lower() == "batch" and torch.distributed.is_initialized(): + warnings.warn( + "WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " + "To train with DDP, convert discriminator to SyncBatchNorm using " + "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).)" + ) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """