Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import get_sigmas_karras

from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
Expand Down Expand Up @@ -400,6 +401,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
use_karras_sigmas: Optional[bool] = False,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -456,7 +458,10 @@ def __call__(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Use Karras sigmas, i.e. the noise schedule given by Eq. 5 of Karras et al. (2022),
https://arxiv.org/abs/2206.00364. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be
equivalent to `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it
`DPM++2M Karras`.
Comment on lines +461 to +464
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how docstrings should be I think.

Could you maybe just add a link to the paper that introduced Karras sigmas?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul Thanks for merging! The so-called Karras sigmas correspond to Eq. 5 in the paper. Since I missed the timing, could you please include a citation as a comment when you're working on #2905?

Karras, T. (2022, June 1). Elucidating the Design Space of Diffusion-Based Generative Models. arXiv.org. https://arxiv.org/abs/2206.00364

Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
Expand Down Expand Up @@ -494,10 +499,18 @@ def __call__(

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
sigmas = self.scheduler.sigmas

# 5. Prepare sigmas
if use_karras_sigmas:
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
sigmas = sigmas.to(device)
else:
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(prompt_embeds.dtype)

# 5. Prepare latent variables
# 6. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
Expand All @@ -513,7 +526,7 @@ def __call__(
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

# 6. Define model function
# 7. Define model function
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2)
Expand All @@ -524,16 +537,16 @@ def model_fn(x, t):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred

# 7. Run k-diffusion solver
# 8. Run k-diffusion solver
latents = self.sampler(model_fn, latents, sigmas)

# 8. Post-processing
# 9. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

# 10. Convert to PIL
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,32 @@ def test_stable_diffusion_2(self):
expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])

assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1

def test_stable_diffusion_karras_sigmas(self):
    """Regression-check the k-diffusion pipeline with Karras sigmas enabled.

    Loads the `stabilityai/stable-diffusion-2-1-base` checkpoint, selects the
    `sample_dpmpp_2m` sampler, generates a single image with
    `use_karras_sigmas=True`, and compares the last 3x3 patch of the final
    channel against stored reference values.
    """
    pipeline = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
    pipeline = pipeline.to(torch_device)
    pipeline.set_progress_bar_config(disable=None)
    pipeline.set_scheduler("sample_dpmpp_2m")

    # Seed right before sampling so the generated image is reproducible.
    sampling_kwargs = {
        "generator": torch.manual_seed(0),
        "guidance_scale": 7.5,
        "num_inference_steps": 15,
        "output_type": "np",
        "use_karras_sigmas": True,
    }
    images = pipeline(["A painting of a squirrel eating a burger"], **sampling_kwargs).images

    # Last 3x3 pixels of the final channel of the first (only) image.
    corner_patch = images[0, -3:, -3:, -1]

    assert images.shape == (1, 512, 512, 3)

    reference_patch = np.array(
        [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
    )
    max_deviation = np.abs(corner_patch.flatten() - reference_patch).max()

    assert max_deviation < 1e-2