From df88819b2d5c100feb5bf897f15e85ae4e0121e6 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 21:46:17 -0400 Subject: [PATCH 1/5] add concat condition support --- generative/inferers/inferer.py | 222 ++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 99 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 769b40c0..ad897107 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -38,12 +38,13 @@ def __init__(self, scheduler: nn.Module) -> None: self.scheduler = scheduler def __call__( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - condition: torch.Tensor | None = None, + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn" ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -55,21 +56,28 @@ def __call__( timesteps: random timesteps. condition: Conditioning for network input. """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + if mode == "concat": + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) return prediction @torch.no_grad() def sample( - self, - input_noise: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - verbose: bool = True, + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -79,8 +87,12 @@ def sample( save_intermediates: whether to return intermediates along the sampling change intermediate_steps: if save_intermediates is True, saves every n steps conditioning: Conditioning for network input. + mode: Conditioning mode for the network. verbose: if true, prints the progression bar of the sampling process. """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if not scheduler: scheduler = self.scheduler image = input_noise @@ -91,9 +103,15 @@ def sample( intermediates = [] for t in progress_bar: # 1. predict noise model_output - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning - ) + if mode == "concat": + image = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) # 2. 
compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) @@ -106,15 +124,15 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods for an input. @@ -167,7 +185,7 @@ def get_likelihood( elif scheduler.prediction_type == "sample": pred_original_sample = model_output elif scheduler.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + pred_original_sample = (alpha_prod_t ** 0.5) * noisy_image - (beta_prod_t ** 0.5) * model_output # 3. Clip "predicted x_0" if scheduler.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) @@ -200,11 +218,11 @@ def get_likelihood( else: # compute kl between two normals kl = 0.5 * ( - -1.0 - + log_predicted_variance - - log_posterior_variance - + torch.exp(log_posterior_variance - log_predicted_variance) - + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) total_kl += kl.view(kl.shape[0], -1).mean(axis=1) if save_intermediates: @@ -222,16 +240,17 @@ def _approx_standard_normal_cdf(self, x): """ return 0.5 * ( - 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + 1.0 + torch.tanh( + torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) ) def _get_decoder_log_likelihood( - self, - inputs: torch.Tensor, - means: torch.Tensor, - log_scales: torch.Tensor, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), ) -> torch.Tensor: """ Compute the log-likelihood of a Gaussian distribution discretizing to a @@ -247,7 +266,7 @@ def _get_decoder_log_likelihood( """ assert inputs.shape == means.shape bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( - original_input_range[1] - original_input_range[0] + original_input_range[1] - original_input_range[0] ) centered_x = inputs - means inv_stdv = torch.exp(-log_scales) @@ -283,13 +302,14 @@ def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None: self.scale_factor = scale_factor def __call__( - self, - inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - condition: torch.Tensor | None = None, + self, + inputs: torch.Tensor, + autoencoder_model: 
Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -306,22 +326,24 @@ def __call__( latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor prediction = super().__call__( - inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition + inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition, + mode=mode ) return prediction @torch.no_grad() def sample( - self, - input_noise: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - verbose: bool = True, + self, + input_noise: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -332,6 +354,7 @@ def sample( save_intermediates: whether to return intermediates along the sampling change intermediate_steps: if save_intermediates is True, saves every n steps conditioning: Conditioning for network input. + mode: Conditioning mode for the network. verbose: if true, prints the progression bar of the sampling process. """ outputs = super().sample( @@ -341,6 +364,7 @@ def sample( save_intermediates=save_intermediates, intermediate_steps=intermediate_steps, conditioning=conditioning, + mode=mode, verbose=verbose, ) @@ -362,18 +386,18 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, - resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "nearest", + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. 
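Patch 1's concat path swaps cross-attention for channel-wise stacking: the condition tensor is concatenated onto the noisy image along dim=1 and context is set to None, so the conditioning must be spatially aligned with the image, and the UNet's in_channels must cover the image channels plus the condition channels. A minimal training-step sketch of the new API follows; the network hyperparameters and import layout are illustrative assumptions, not taken from the patch:

import torch

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

# in_channels = 1 image channel + 2 condition channels (illustrative sizes)
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1 + 2,
    out_channels=1,
    num_channels=(32, 64),
    attention_levels=(False, False),
    num_res_blocks=1,
)

images = torch.randn(4, 1, 64, 64)
condition = torch.randn(4, 2, 64, 64)  # same spatial size as the image
noise = torch.randn_like(images)

scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler=scheduler)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()

prediction = inferer(
    inputs=images,
    diffusion_model=model,
    noise=noise,
    timesteps=timesteps,
    condition=condition,
    mode="concat",  # the UNet sees torch.cat([noisy_image, condition], dim=1) and context=None
)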
@@ -423,13 +447,13 @@ def __init__(self) -> None: Inferer.__init__(self) def __call__( - self, - inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - condition: torch.Tensor | None = None, - return_latent: bool = False, + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + return_latent: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: """ Implements the forward pass for a supervised training iteration. @@ -465,24 +489,24 @@ def __call__( start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item() else: start = 0 - prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + prediction = transformer_model(x=latent[:, start: start + max_seq_len], context=condition) if return_latent: - return prediction, target[:, start : start + max_seq_len], latent_spatial_dim + return prediction, target[:, start: start + max_seq_len], latent_spatial_dim else: return prediction @torch.no_grad() def sample( - self, - latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], - starting_tokens: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - conditioning: torch.Tensor | None = None, - temperature: float = 1.0, - top_k: int | None = None, - verbose: bool = True, + self, + latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], + starting_tokens: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, ) -> torch.Tensor: """ Sampling function for the VQVAE + Transformer model. @@ -510,7 +534,7 @@ def sample( if latent_seq.size(1) <= transformer_model.max_seq_len: idx_cond = latent_seq else: - idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + idx_cond = latent_seq[:, -transformer_model.max_seq_len:] # forward the model to get the logits for the index in the sequence logits = transformer_model(x=idx_cond, context=conditioning) @@ -537,15 +561,15 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - condition: torch.Tensor | None = None, - resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "nearest", - verbose: bool = False, + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, ) -> torch.Tensor: """ Computes the log-likelihoods of the latent representations of the input. 
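The reformatted slices above also make the transformer's context window explicit: during sampling, once the generated token sequence exceeds transformer_model.max_seq_len, only the most recent max_seq_len tokens are fed back as input. A toy illustration of the cropping rule, with made-up sizes:

import torch

max_seq_len = 4
latent_seq = torch.arange(7).unsqueeze(0)  # tokens generated so far: [[0, 1, 2, 3, 4, 5, 6]]
idx_cond = latent_seq if latent_seq.size(1) <= max_seq_len else latent_seq[:, -max_seq_len:]
print(idx_cond)  # tensor([[3, 4, 5, 6]]): only the last 4 tokens condition the next step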
@@ -596,7 +620,7 @@ def get_likelihood( progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) for i in progress_bar: - idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len: i + 1] # forward the model to get the logits for the index in the sequence logits = transformer_model(x=idx_cond, context=condition) # pluck the logits at the final step From dc2dd2cb65bd7e22081cd61ed84ccccf61f72ef4 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 22:09:04 -0400 Subject: [PATCH 2/5] add concat condition support in get_likelihood --- generative/inferers/inferer.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index ad897107..5603fe60 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -44,7 +44,7 @@ def __call__( noise: torch.Tensor, timesteps: torch.Tensor, condition: torch.Tensor | None = None, - mode: str = "crossattn" + mode: str = "crossattn", ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -55,6 +55,7 @@ def __call__( noise: random noise, of the same shape as the input. timesteps: random timesteps. condition: Conditioning for network input. + mode: Conditioning mode for the network. """ if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -130,6 +131,7 @@ def get_likelihood( scheduler: Callable[..., torch.Tensor] | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, + mode: str = "crossattn", original_input_range: tuple | None = (0, 255), scaled_input_range: tuple | None = (0, 1), verbose: bool = True, @@ -143,6 +145,7 @@ def get_likelihood( scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. save_intermediates: save the intermediate spatial KL maps conditioning: Conditioning for network input. + mode: Conditioning mode for the network. original_input_range: the [min,max] intensity range of the input data before any scaling was applied. scaled_input_range: the [min,max] intensity range of the input data after scaling. verbose: if true, prints the progression bar of the sampling process. 
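Patch 2 extends the same mode switch to likelihood estimation, so a spatial condition can be concatenated onto the noisy image at every timestep of the likelihood computation as well. A hedged usage sketch (hyperparameters and shapes are assumptions; as the next hunk shows, get_likelihood requires a DDPMScheduler):

import torch

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,  # 1 image channel + 2 condition channels
    out_channels=1,
    num_channels=(32, 64),
    attention_levels=(False, False),
    num_res_blocks=1,
)
images = torch.randn(2, 1, 64, 64)
condition = torch.randn(2, 2, 64, 64)

scheduler = DDPMScheduler(num_train_timesteps=100)
inferer = DiffusionInferer(scheduler=scheduler)
likelihoods = inferer.get_likelihood(
    inputs=images,
    diffusion_model=model,
    scheduler=scheduler,
    conditioning=condition,
    mode="concat",
)  # one log-likelihood estimate per batch element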
@@ -155,6 +158,8 @@ def get_likelihood( f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler._get_name()}" ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") if verbose and has_tqdm: progress_bar = tqdm(scheduler.timesteps) else: @@ -165,7 +170,13 @@ def get_likelihood( for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) - model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + if mode == "concat": + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model( + noisy_image, timesteps=timesteps, context=None + ) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) @@ -321,6 +332,7 @@ def __call__( noise: random noise, of the same shape as the latent representation. timesteps: random timesteps. condition: conditioning for network input. + mode: Conditioning mode for the network. """ with torch.no_grad(): latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor @@ -393,6 +405,7 @@ def get_likelihood( scheduler: Callable[..., torch.Tensor] | None = None, save_intermediates: bool | None = False, conditioning: torch.Tensor | None = None, + mode: str = "crossattn", original_input_range: tuple | None = (0, 255), scaled_input_range: tuple | None = (0, 1), verbose: bool = True, @@ -409,6 +422,7 @@ def get_likelihood( scheduler: diffusion scheduler. If none provided will use the class attribute scheduler save_intermediates: save the intermediate spatial KL maps conditioning: Conditioning for network input. + mode: Conditioning mode for the network. original_input_range: the [min,max] intensity range of the input data before any scaling was applied. scaled_input_range: the [min,max] intensity range of the input data after scaling. verbose: if true, prints the progression bar of the sampling process. 
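For LatentDiffusionInferer the concatenation happens after the autoencoder has encoded the inputs, so the condition lives in latent space: its spatial dimensions must match the latent representation, not the image, and the stage-2 UNet's in_channels is the latent channel count plus the condition channels. Shape bookkeeping with made-up sizes, mirroring the unit tests added in patch 5:

import torch

latent_shape = (2, 3, 16, 16)  # (batch, latent channels, latent spatial dims)
n_concat_channel = 3

conditioning_shape = list(latent_shape)
conditioning_shape[1] = n_concat_channel
conditioning = torch.randn(conditioning_shape)  # (2, 3, 16, 16), aligned with the latent

# the stage-2 DiffusionModelUNet must then be built with
# in_channels = latent_channels + n_concat_channel = 3 + 3 = 6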
@@ -428,6 +442,7 @@ def get_likelihood( scheduler=scheduler, save_intermediates=save_intermediates, conditioning=conditioning, + mode=mode, verbose=verbose, ) if save_intermediates and resample_latent_likelihoods: From f67642dc68cd73dfac791d017f6fd14b4df867e8 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 22:17:38 -0400 Subject: [PATCH 3/5] fix format --- generative/inferers/inferer.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 5603fe60..3b5f2459 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -229,11 +229,11 @@ def get_likelihood( else: # compute kl between two normals kl = 0.5 * ( - -1.0 - + log_predicted_variance - - log_posterior_variance - + torch.exp(log_posterior_variance - log_predicted_variance) - + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) ) total_kl += kl.view(kl.shape[0], -1).mean(axis=1) if save_intermediates: @@ -251,8 +251,7 @@ def _approx_standard_normal_cdf(self, x): """ return 0.5 * ( - 1.0 + torch.tanh( - torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) ) def _get_decoder_log_likelihood( @@ -277,7 +276,7 @@ def _get_decoder_log_likelihood( """ assert inputs.shape == means.shape bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( - original_input_range[1] - original_input_range[0] + original_input_range[1] - original_input_range[0] ) centered_x = inputs - means inv_stdv = torch.exp(-log_scales) From 3487770d78c98072febe4c7501d9456d8c0a41c9 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 22:42:38 -0400 Subject: [PATCH 4/5] fix format --- generative/inferers/inferer.py | 204 ++++++++++++++++----------------- 1 file changed, 102 insertions(+), 102 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 3b5f2459..4b85a398 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -38,13 +38,13 @@ def __init__(self, scheduler: nn.Module) -> None: self.scheduler = scheduler def __call__( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - condition: torch.Tensor | None = None, - mode: str = "crossattn", + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. 
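The KL term reformatted in the patches above is the standard closed form for the divergence between two Gaussians, here KL(posterior || predicted) with the forward-process posterior as the first argument. A quick numerical self-check against torch.distributions, purely illustrative:

import torch
from torch.distributions import Normal, kl_divergence

posterior = Normal(loc=torch.tensor(0.3), scale=torch.tensor(0.7))
predicted = Normal(loc=torch.tensor(-0.1), scale=torch.tensor(0.9))

log_posterior_variance = (posterior.scale ** 2).log()
log_predicted_variance = (predicted.scale ** 2).log()

# same expression as in the inferer, written for scalars
kl = 0.5 * (
    -1.0
    + log_predicted_variance
    - log_posterior_variance
    + torch.exp(log_posterior_variance - log_predicted_variance)
    + (posterior.loc - predicted.loc) ** 2 * torch.exp(-log_predicted_variance)
)
assert torch.allclose(kl, kl_divergence(posterior, predicted))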
@@ -70,15 +70,15 @@ def __call__( @torch.no_grad() def sample( - self, - input_noise: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - verbose: bool = True, + self, + input_noise: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -106,9 +106,7 @@ def sample( # 1. predict noise model_output if mode == "concat": image = torch.cat([image, conditioning], dim=1) - model_output = diffusion_model( - image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None - ) + model_output = diffusion_model(image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None) else: model_output = diffusion_model( image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning @@ -125,16 +123,16 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, + self, + inputs: torch.Tensor, + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods for an input. @@ -172,9 +170,7 @@ def get_likelihood( noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) if mode == "concat": noisy_image = torch.cat([noisy_image, conditioning], dim=1) - model_output = diffusion_model( - noisy_image, timesteps=timesteps, context=None - ) + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) else: model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted @@ -196,7 +192,7 @@ def get_likelihood( elif scheduler.prediction_type == "sample": pred_original_sample = model_output elif scheduler.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t ** 0.5) * noisy_image - (beta_prod_t ** 0.5) * model_output + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. 
Clip "predicted x_0" if scheduler.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) @@ -255,12 +251,12 @@ def _approx_standard_normal_cdf(self, x): ) def _get_decoder_log_likelihood( - self, - inputs: torch.Tensor, - means: torch.Tensor, - log_scales: torch.Tensor, - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), ) -> torch.Tensor: """ Compute the log-likelihood of a Gaussian distribution discretizing to a @@ -312,14 +308,14 @@ def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None: self.scale_factor = scale_factor def __call__( - self, - inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - noise: torch.Tensor, - timesteps: torch.Tensor, - condition: torch.Tensor | None = None, - mode: str = "crossattn", + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -337,24 +333,28 @@ def __call__( latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor prediction = super().__call__( - inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition, - mode=mode + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, ) return prediction @torch.no_grad() def sample( - self, - input_noise: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - intermediate_steps: int | None = 100, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - verbose: bool = True, + self, + input_noise: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -397,19 +397,19 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - autoencoder_model: Callable[..., torch.Tensor], - diffusion_model: Callable[..., torch.Tensor], - scheduler: Callable[..., torch.Tensor] | None = None, - save_intermediates: bool | None = False, - conditioning: torch.Tensor | None = None, - mode: str = "crossattn", - original_input_range: tuple | None = (0, 255), - scaled_input_range: tuple | None = (0, 1), - verbose: bool = True, - resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "nearest", + self, + inputs: torch.Tensor, + autoencoder_model: Callable[..., torch.Tensor], + diffusion_model: Callable[..., torch.Tensor], + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = 
(0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. @@ -461,13 +461,13 @@ def __init__(self) -> None: Inferer.__init__(self) def __call__( - self, - inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - condition: torch.Tensor | None = None, - return_latent: bool = False, + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + return_latent: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: """ Implements the forward pass for a supervised training iteration. @@ -503,24 +503,24 @@ def __call__( start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item() else: start = 0 - prediction = transformer_model(x=latent[:, start: start + max_seq_len], context=condition) + prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) if return_latent: - return prediction, target[:, start: start + max_seq_len], latent_spatial_dim + return prediction, target[:, start : start + max_seq_len], latent_spatial_dim else: return prediction @torch.no_grad() def sample( - self, - latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], - starting_tokens: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - conditioning: torch.Tensor | None = None, - temperature: float = 1.0, - top_k: int | None = None, - verbose: bool = True, + self, + latent_spatial_dim: Sequence[int, int, int] | Sequence[int, int], + starting_tokens: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, ) -> torch.Tensor: """ Sampling function for the VQVAE + Transformer model. @@ -548,7 +548,7 @@ def sample( if latent_seq.size(1) <= transformer_model.max_seq_len: idx_cond = latent_seq else: - idx_cond = latent_seq[:, -transformer_model.max_seq_len:] + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] # forward the model to get the logits for the index in the sequence logits = transformer_model(x=idx_cond, context=conditioning) @@ -575,15 +575,15 @@ def sample( @torch.no_grad() def get_likelihood( - self, - inputs: torch.Tensor, - vqvae_model: Callable[..., torch.Tensor], - transformer_model: Callable[..., torch.Tensor], - ordering: Callable[..., torch.Tensor], - condition: torch.Tensor | None = None, - resample_latent_likelihoods: bool = False, - resample_interpolation_mode: str = "nearest", - verbose: bool = False, + self, + inputs: torch.Tensor, + vqvae_model: Callable[..., torch.Tensor], + transformer_model: Callable[..., torch.Tensor], + ordering: Callable[..., torch.Tensor], + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, ) -> torch.Tensor: """ Computes the log-likelihoods of the latent representations of the input. 
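The hunks above only show where the transformer logits are produced; the temperature and top_k arguments in the sample signature are applied to those logits before the next token is drawn. The filtering body is not part of these hunks, so the following is a conventional sketch of how such parameters are typically used, not a quote of the repository code:

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0, top_k: int | None = None) -> torch.Tensor:
    # logits: (batch, vocabulary size) for the final position in the sequence
    logits = logits / temperature  # <1.0 sharpens, >1.0 flattens the distribution
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float("inf")  # mask everything outside the top k
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # one sampled token id per batch element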
@@ -634,7 +634,7 @@ def get_likelihood( progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) for i in progress_bar: - idx_cond = latent[:, i + 1 - transformer_model.max_seq_len: i + 1] + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] # forward the model to get the logits for the index in the sequence logits = transformer_model(x=idx_cond, context=condition) # pluck the logits at the final step From 7d7c066aeec93444cd38ef85398f7a58172c11f6 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Wed, 2 Aug 2023 23:16:59 -0400 Subject: [PATCH 5/5] add unittests and fix bug --- generative/inferers/inferer.py | 6 +- tests/test_diffusion_inferer.py | 56 ++++++++++++++++++ tests/test_latent_diffusion_inferer.py | 80 ++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 2 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 4b85a398..68b0bdd6 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -105,8 +105,10 @@ def sample( for t in progress_bar: # 1. predict noise model_output if mode == "concat": - image = torch.cat([image, conditioning], dim=1) - model_output = diffusion_model(image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None) + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) else: model_output = diffusion_model( image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 6faf0e68..f3f9aa78 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -161,6 +161,62 @@ def test_normal_cdf(self): cdf_true = norm.cdf(x) torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_call_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + 
model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index a049e1d5..296e9266 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -245,6 +245,86 @@ def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_para self.assertEqual(len(intermediates), 10) self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + @parameterized.expand(TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape + ): + if model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape_conditioned_concat( + self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape + ): + if model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, 
scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + if __name__ == "__main__": unittest.main()
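Patch 5's rename of the concatenated tensor to model_input in the sampling loop is more than cosmetic. In patch 1 the loop reassigned image itself before calling the scheduler, so the condition channels leaked into the running sample and were stacked on again at every denoising step. A toy reproduction of the accumulation, with made-up shapes:

import torch

image = torch.randn(1, 1, 8, 8)
condition = torch.randn(1, 2, 8, 8)

for _ in range(3):
    image = torch.cat([image, condition], dim=1)  # buggy: channels grow to 3, then 5, then 7
print(image.shape[1])  # 7, so the UNet input size is wrong after the first step

# fixed pattern: keep `image` intact and build a fresh model input each step
image = torch.randn(1, 1, 8, 8)
for _ in range(3):
    model_input = torch.cat([image, condition], dim=1)  # always 3 channels
    # in the real loop: model_output = diffusion_model(model_input, ...) then
    # image, _ = scheduler.step(model_output, t, image)

The unit tests added in this patch exercise exactly this path, sampling with mode="concat" through both the plain and the latent inferer.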