diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 2c3d5c8e15e8..6c64f45c9856 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -5,12 +5,11 @@ import torch import PIL -from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from .safety_checker import StableDiffusionSafetyChecker @@ -31,7 +30,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -93,12 +92,17 @@ def __call__( # get the original timestep using init_timestep init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) # get prompt text embeddings text_input = self.tokenizer( @@ -137,11 +141,22 @@ def __call__( extra_step_kwargs["eta"] = eta latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): + t_index = t_start + i + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = latent_model_input.to(self.unet.dtype) + t = t.to(self.unet.dtype) + # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] @@ -151,11 +166,14 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"] + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents) + image = self.vae.decode(latents.to(self.vae.dtype)) image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 04ad59992677..257517a12c65 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -445,6 +445,49 @@ def test_stable_diffusion_img2img(self): expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_img2img_k_lms(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=self.dummy_safety_checker, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_inpaint(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet @@ -892,7 +935,7 @@ def test_lms_stable_diffusion_pipeline(self): def test_stable_diffusion_img2img_pipeline(self): ds = load_dataset("hf-internal-testing/diffusers-images", split="train") - init_image = ds[1]["image"].resize((768, 512)) + init_image = ds[2]["image"].resize((768, 512)) output_image = ds[0]["image"].resize((768, 512)) model_id = "CompVis/stable-diffusion-v1-4" @@ -915,12 +958,40 @@ def test_stable_diffusion_img2img_pipeline(self): @slow @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_in_paint_pipeline(self): + def test_stable_diffusion_img2img_pipeline_k_lms(self): ds = load_dataset("hf-internal-testing/diffusers-images", split="train") init_image = ds[2]["image"].resize((768, 512)) - mask_image = ds[3]["image"].resize((768, 512)) - output_image = ds[4]["image"].resize((768, 512)) + output_image = ds[1]["image"].resize((768, 512)) + + lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, scheduler=lms, use_auth_token=True) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[ + "sample" + ][0] + + expected_array = np.array(output_image) + sampled_array = np.array(image) + + assert sampled_array.shape == (512, 768, 3) + assert np.max(np.abs(sampled_array - expected_array)) < 1e-4 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_in_paint_pipeline(self): + ds = load_dataset("hf-internal-testing/diffusers-images", split="train") + + init_image = ds[3]["image"].resize((768, 512)) + mask_image = ds[4]["image"].resize((768, 512)) + output_image = ds[5]["image"].resize((768, 512)) model_id = "CompVis/stable-diffusion-v1-4" pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)