From ad010032b3bde628bc336eb28c72703ef81456fc Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 2 Jan 2024 14:48:02 +0000 Subject: [PATCH 1/3] Add warning --- generative/networks/nets/patchgan_discriminator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index bf09b743..c1b99f82 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -218,7 +218,10 @@ def __init__( ) self.apply(self.initialise_weights) - + if norm.lower() == 'batch' and torch.distributed.is_initialized(): + print("WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " + "To train with DDP, convert discriminator to SyncBatchNorm using " + "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).)") def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ From c1dba21ad88ebfa5f140e3c21ffec909aeb93963 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 8 Jan 2024 09:56:06 +0000 Subject: [PATCH 2/3] Update generative/networks/nets/patchgan_discriminator.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mark Graham --- generative/networks/nets/patchgan_discriminator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index c1b99f82..356aa8d8 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -219,7 +219,7 @@ def __init__( self.apply(self.initialise_weights) if norm.lower() == 'batch' and torch.distributed.is_initialized(): - print("WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " + warnings.warn("WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " "To train with DDP, convert discriminator to SyncBatchNorm using " "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).)") def forward(self, x: torch.Tensor) -> list[torch.Tensor]: From 61bb9f18c4b396a5302418e6770e7cfc4fab5d0e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 8 Jan 2024 09:59:44 +0000 Subject: [PATCH 3/3] Adds missing warning --- generative/networks/nets/patchgan_discriminator.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index 356aa8d8..78120e8e 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from collections.abc import Sequence import torch @@ -218,10 +219,13 @@ def __init__( ) self.apply(self.initialise_weights) - if norm.lower() == 'batch' and torch.distributed.is_initialized(): - warnings.warn("WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " - "To train with DDP, convert discriminator to SyncBatchNorm using " - "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).)") + if norm.lower() == "batch" and torch.distributed.is_initialized(): + warnings.warn( + "WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " + "To train with DDP, convert discriminator to SyncBatchNorm using " + "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).)" + ) + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """