From 3667b3b564d5b4e9f47d34efae188be2dabf28ee Mon Sep 17 00:00:00 2001 From: badayvedat Date: Sat, 1 Feb 2025 11:55:04 +0300 Subject: [PATCH 1/3] feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling --- src/diffusers/training_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 082640f37a17..1269135f1480 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -248,7 +248,13 @@ def _set_state_dict_into_text_encoder( def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, + device: torch.device = "cpu", + generator: Optional[torch.Generator] = None, ): """ Compute the density for sampling the timesteps when doing SD3 training. @@ -258,14 +264,13 @@ def compute_density_for_timestep_sampling( SD3 paper reference: https://arxiv.org/abs/2403.03206v1. """ if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) u = torch.nn.functional.sigmoid(u) elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") + u = torch.rand(size=(batch_size,), device=device, generator=generator) u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) else: - u = torch.rand(size=(batch_size,), device="cpu") + u = torch.rand(size=(batch_size,), device=device, generator=generator) return u From 8c8ada4730976002490145d111533751ba74166b Mon Sep 17 00:00:00 2001 From: badayvedat Date: Sat, 1 Feb 2025 12:04:35 +0300 Subject: [PATCH 2/3] chore: update type hint --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 1269135f1480..fcef631c12a8 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -253,7 +253,7 @@ def compute_density_for_timestep_sampling( logit_mean: float = None, logit_std: float = None, mode_scale: float = None, - device: torch.device = "cpu", + device: torch.device | str = "cpu", generator: Optional[torch.Generator] = None, ): """ From ba0395097aa8a8fe7cf31eabfc820a14763e3ab3 Mon Sep 17 00:00:00 2001 From: badayvedat Date: Sat, 1 Feb 2025 13:59:40 +0300 Subject: [PATCH 3/3] refactor: use union for type hint --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index fcef631c12a8..c570bac733db 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -253,7 +253,7 @@ def compute_density_for_timestep_sampling( logit_mean: float = None, logit_std: float = None, mode_scale: float = None, - device: torch.device | str = "cpu", + device: Union[torch.device, str] = "cpu", generator: Optional[torch.Generator] = None, ): """