diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 7af3f33bfa..eea573609c 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -916,6 +916,7 @@ def sample( verbose: bool = True, seg: torch.Tensor | None = None, cfg: float | None = None, + cfg_fill_value: float = -1.0, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -929,6 +930,7 @@ def sample( verbose: if true, prints the progression bar of the sampling process. seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning. + cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance. """ if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -961,7 +963,7 @@ def sample( model_input = torch.cat([image] * 2, dim=0) if conditioning is not None: uncondition = torch.ones_like(conditioning) - uncondition.fill_(-1) + uncondition.fill_(cfg_fill_value) conditioning_input = torch.cat([uncondition, conditioning], dim=0) else: conditioning_input = None @@ -1261,6 +1263,7 @@ def sample( # type: ignore[override] verbose: bool = True, seg: torch.Tensor | None = None, cfg: float | None = None, + cfg_fill_value: float = -1.0, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -1276,6 +1279,7 @@ def sample( # type: ignore[override] seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of SPADEAutoencoderKL, segmentation must be provided. cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning. + cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance. """ if ( @@ -1300,6 +1304,7 @@ def sample( # type: ignore[override] verbose=verbose, seg=seg, cfg=cfg, + cfg_fill_value=cfg_fill_value, ) if save_intermediates: @@ -1479,6 +1484,7 @@ def sample( # type: ignore[override] verbose: bool = True, seg: torch.Tensor | None = None, cfg: float | None = None, + cfg_fill_value: float = -1.0, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -1493,7 +1499,8 @@ def sample( # type: ignore[override] mode: Conditioning mode for the network. verbose: if true, prints the progression bar of the sampling process. seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. - cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning. + cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning. + cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance. """ if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1521,7 +1528,7 @@ def sample( # type: ignore[override] model_input = torch.cat([image] * 2, dim=0) if conditioning is not None: uncondition = torch.ones_like(conditioning) - uncondition.fill_(-1) + uncondition.fill_(cfg_fill_value) conditioning_input = torch.cat([uncondition, conditioning], dim=0) else: conditioning_input = None @@ -1839,6 +1846,7 @@ def sample( # type: ignore[override] verbose: bool = True, seg: torch.Tensor | None = None, cfg: float | None = None, + cfg_fill_value: float = -1.0, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -1856,6 +1864,7 @@ def sample( # type: ignore[override] seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of SPADEAutoencoderKL, segmentation must be provided. cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning. + cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance. """ if ( @@ -1884,6 +1893,7 @@ def sample( # type: ignore[override] verbose=verbose, seg=seg, cfg=cfg, + cfg_fill_value=cfg_fill_value, ) if save_intermediates: diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 02890a71d4..81874ed3a8 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -106,6 +106,7 @@ def test_sample_cfg(self, model_params, input_shape): save_intermediates=True, intermediate_steps=1, cfg=5, + cfg_fill_value=-1, ) self.assertEqual(sample.shape, noise.shape) diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index ed5e1a149e..ab80363cde 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -456,6 +456,7 @@ def test_sample_shape_with_cfg( scheduler=scheduler, seg=input_seg, cfg=5, + cfg_fill_value=-1, ) else: sample = inferer.sample(