From 484bbc81a58eafe30cf8faf7efbdf2910ea5dff5 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 6 Dec 2023 10:09:30 +0000 Subject: [PATCH 1/2] Added SPADE functionality on the decode call for the sample methods. --- generative/inferers/inferer.py | 92 ++++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 261f745b..cc5396ff 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -381,15 +381,18 @@ def __call__( if self.ldm_latent_shape is not None: latent = self.ldm_resizer(latent) - call = partial(super().__call__, seg = seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__ + call = ( + partial(super().__call__, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else super().__call__ + ) prediction = call( inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition, - mode=mode + mode=mode, ) return prediction @@ -456,19 +459,22 @@ def sample( latent = self.autoencoder_resizer(latent) latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] - image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) + decode = ( + partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + if isinstance(autoencoder_model, SPADEAutoencoderKL) + else autoencoder_model.decode_stage_2_outputs + ) + image = decode(latent / self.scale_factor) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: - if isinstance(autoencoder_model, SPADEAutoencoderKL): - intermediates.append( - autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor, seg=seg) - ) - else: - intermediates.append( - autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor) - ) + decode = ( + partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + if isinstance(autoencoder_model, SPADEAutoencoderKL) + else autoencoder_model.decode_stage_2_outputs + ) + intermediates.append(decode(latent_intermediate / self.scale_factor)) return image, intermediates else: @@ -905,8 +911,11 @@ def __call__( if cn_cond.shape[2:] != latent.shape[2:]: cn_cond = F.interpolate(cn_cond, latent.shape[2:]) - call = partial(super().__call__, seg = seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().__call__ + call = ( + partial(super().__call__, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else super().__call__ + ) prediction = call( inputs=latent, diffusion_model=diffusion_model, @@ -915,7 +924,7 @@ def __call__( timesteps=timesteps, cn_cond=cn_cond, condition=condition, - mode=mode + mode=mode, ) return prediction @@ -966,20 +975,21 @@ def sample( if cn_cond.shape[2:] != input_noise.shape[2:]: cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) - sample = partial(super().sample, seg = seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample + sample = ( + partial(super().sample, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample + ) outputs = sample( - input_noise=input_noise, - diffusion_model=diffusion_model, - controlnet=controlnet, - cn_cond=cn_cond, - scheduler=scheduler, - save_intermediates=save_intermediates, - intermediate_steps=intermediate_steps, - conditioning=conditioning, - mode=mode, - verbose=verbose, + input_noise=input_noise, + diffusion_model=diffusion_model, + 
controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, ) if save_intermediates: @@ -991,19 +1001,22 @@ def sample( latent = self.autoencoder_resizer(latent) latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] - image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor) + decode = ( + partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + if isinstance(autoencoder_model, SPADEAutoencoderKL) + else autoencoder_model.decode_stage_2_outputs + ) + image = decode(latent / self.scale_factor) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: - if isinstance(autoencoder_model, SPADEAutoencoderKL): - intermediates.append( - autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor), seg=seg - ) - else: - intermediates.append( - autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor) - ) + decode = ( + partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + if isinstance(autoencoder_model, SPADEAutoencoderKL) + else autoencoder_model.decode_stage_2_outputs + ) + intermediates.append(decode(latent_intermediate / self.scale_factor)) return image, intermediates else: @@ -1064,8 +1077,11 @@ def get_likelihood( if self.ldm_latent_shape is not None: latents = self.ldm_resizer(latents) - get_likelihood = partial(super().get_likelihood, seg = seg) if \ - isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().get_likelihood + get_likelihood = ( + partial(super().get_likelihood, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else super().get_likelihood + ) outputs = get_likelihood( inputs=latents, diffusion_model=diffusion_model, From 04f9d97e4eed0b06351562848659a94b6d2d7ebd Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 6 Dec 2023 14:00:52 +0000 Subject: [PATCH 2/2] Changed the format of the partial statements. --- generative/inferers/inferer.py | 119 ++++++++++++++------------------- 1 file changed, 51 insertions(+), 68 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index cc5396ff..7be22017 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -381,11 +381,10 @@ def __call__( if self.ldm_latent_shape is not None: latent = self.ldm_resizer(latent) - call = ( - partial(super().__call__, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else super().__call__ - ) + call = super().__call__ + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + call = partial(super().__call__, seg=seg) + prediction = call( inputs=latent, diffusion_model=diffusion_model, @@ -435,9 +434,9 @@ def sample( "labels for each must be compatible. 
" ) - sample = ( - partial(super().sample, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample - ) + sample = super().sample + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + sample = partial(super().sample, seg=seg) outputs = sample( input_noise=input_noise, @@ -459,21 +458,18 @@ def sample( latent = self.autoencoder_resizer(latent) latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] - decode = ( - partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - if isinstance(autoencoder_model, SPADEAutoencoderKL) - else autoencoder_model.decode_stage_2_outputs - ) + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: - decode = ( - partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - if isinstance(autoencoder_model, SPADEAutoencoderKL) - else autoencoder_model.decode_stage_2_outputs - ) + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) intermediates.append(decode(latent_intermediate / self.scale_factor)) return image, intermediates @@ -527,11 +523,9 @@ def get_likelihood( if self.ldm_latent_shape is not None: latents = self.ldm_resizer(latents) - get_likelihood = ( - partial(super().get_likelihood, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else super().get_likelihood - ) + get_likelihood = super().get_likelihood + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + get_likelihood = partial(super().get_likelihood, seg=seg) outputs = get_likelihood( inputs=latents, @@ -602,13 +596,11 @@ def __call__( noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg = seg) - prediction = diffusion_model( + prediction = diffuse( x=noisy_image, timesteps=timesteps, context=condition, @@ -664,14 +656,13 @@ def sample( x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond ) # 2. 
predict noise model_output - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + if mode == "concat": model_input = torch.cat([image, conditioning], dim=1) - model_output = diffusion_model( + model_output = diffuse( model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, @@ -679,7 +670,7 @@ def sample( mid_block_additional_residual=mid_block_res_sample, ) else: - model_output = diffusion_model( + model_output = diffuse( image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning, @@ -753,14 +744,13 @@ def get_likelihood( x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond ) - diffusion_model = ( - partial(diffusion_model, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else diffusion_model - ) + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg = seg) + if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - model_output = diffusion_model( + model_output = diffuse( noisy_image, timesteps=timesteps, context=None, @@ -768,7 +758,7 @@ def get_likelihood( mid_block_additional_residual=mid_block_res_sample, ) else: - model_output = diffusion_model( + model_output = diffuse( x=noisy_image, timesteps=timesteps, context=conditioning, @@ -842,7 +832,6 @@ def get_likelihood( else: return total_kl - class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): """ ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, @@ -911,11 +900,10 @@ def __call__( if cn_cond.shape[2:] != latent.shape[2:]: cn_cond = F.interpolate(cn_cond, latent.shape[2:]) - call = ( - partial(super().__call__, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else super().__call__ - ) + call = super().__call__ + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + call = partial(super().__call__, seg=seg) + prediction = call( inputs=latent, diffusion_model=diffusion_model, @@ -975,9 +963,9 @@ def sample( if cn_cond.shape[2:] != input_noise.shape[2:]: cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) - sample = ( - partial(super().sample, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else super().sample - ) + sample = super().sample + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + sample = partial(super().sample, seg=seg) outputs = sample( input_noise=input_noise, @@ -1001,21 +989,18 @@ def sample( latent = self.autoencoder_resizer(latent) latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates] - decode = ( - partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - if isinstance(autoencoder_model, SPADEAutoencoderKL) - else autoencoder_model.decode_stage_2_outputs - ) + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) if save_intermediates: intermediates = [] for latent_intermediate in latent_intermediates: - decode = ( - partial(autoencoder_model.decode_stage_2_outputs, seg=seg) - if isinstance(autoencoder_model, SPADEAutoencoderKL) - else autoencoder_model.decode_stage_2_outputs - ) + 
decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) intermediates.append(decode(latent_intermediate / self.scale_factor)) return image, intermediates @@ -1077,11 +1062,10 @@ def get_likelihood( if self.ldm_latent_shape is not None: latents = self.ldm_resizer(latents) - get_likelihood = ( - partial(super().get_likelihood, seg=seg) - if isinstance(diffusion_model, SPADEDiffusionModelUNet) - else super().get_likelihood - ) + get_likelihood = super().get_likelihood + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + get_likelihood = partial(super().get_likelihood, seg=seg) + outputs = get_likelihood( inputs=latents, diffusion_model=diffusion_model, @@ -1101,7 +1085,6 @@ def get_likelihood( outputs = (outputs[0], intermediates) return outputs - class VQVAETransformerInferer(Inferer): """ Class to perform inference with a VQVAE + Transformer model.
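
For reference, the pattern both patches converge on — binding `seg` with `functools.partial` only when the model is a SPADE variant, assigned via a plain `if` statement rather than a conditional expression — can be sketched in isolation. The snippet below is a minimal, self-contained illustration only: `PlainDecoder`, `SPADEDecoder`, and `decode_latent` are hypothetical stand-ins, not the MONAI `SPADEAutoencoderKL` / `SPADEDiffusionModelUNet` classes touched by this diff.

    # Minimal sketch of the conditional-partial pattern used throughout this patch.
    # The classes here are hypothetical stand-ins, not the MONAI Generative models.
    from functools import partial


    class PlainDecoder:
        def decode_stage_2_outputs(self, z):
            return f"decoded {z}"


    class SPADEDecoder(PlainDecoder):
        # A SPADE-style decoder additionally needs a semantic map (`seg`).
        def decode_stage_2_outputs(self, z, seg=None):
            return f"decoded {z} conditioned on seg={seg}"


    def decode_latent(model, latent, seg=None):
        # Bind `seg` only when the model expects it, so the downstream call site
        # stays uniform -- the same idea as the partial(..., seg=seg) branches above.
        decode = model.decode_stage_2_outputs
        if isinstance(model, SPADEDecoder):
            decode = partial(decode, seg=seg)
        return decode(latent)


    print(decode_latent(PlainDecoder(), 0.5))
    print(decode_latent(SPADEDecoder(), 0.5, seg="label_map"))

Keeping the branch on the callable rather than on every call site is what lets the second commit collapse the duplicated decode/append blocks into a single `decode(latent_intermediate / self.scale_factor)` line.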