diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 769b40c0..68b0bdd6 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -44,6 +44,7 @@ def __call__(
         noise: torch.Tensor,
         timesteps: torch.Tensor,
         condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
     ) -> torch.Tensor:
         """
         Implements the forward pass for a supervised training iteration.
@@ -54,8 +55,15 @@ def __call__(
             noise: random noise, of the same shape as the input.
             timesteps: random timesteps.
             condition: Conditioning for network input.
+            mode: Conditioning mode for the network.
         """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+        if mode == "concat":
+            noisy_image = torch.cat([noisy_image, condition], dim=1)
+            condition = None
         prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)

         return prediction
@@ -69,6 +77,7 @@ def sample(
         save_intermediates: bool | None = False,
         intermediate_steps: int | None = 100,
         conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
         verbose: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
@@ -79,8 +88,12 @@ def sample(
             save_intermediates: whether to return intermediates along the sampling change
             intermediate_steps: if save_intermediates is True, saves every n steps
             conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
         """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
         if not scheduler:
             scheduler = self.scheduler
         image = input_noise
@@ -91,9 +104,15 @@ def sample(
             intermediates = []
         for t in progress_bar:
             # 1. predict noise model_output
-            model_output = diffusion_model(
-                image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
-            )
+            if mode == "concat":
+                model_input = torch.cat([image, conditioning], dim=1)
+                model_output = diffusion_model(
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
+                )
+            else:
+                model_output = diffusion_model(
+                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                )

             # 2. compute previous image: x_t -> x_t-1
             image, _ = scheduler.step(model_output, t, image)
@@ -112,6 +131,7 @@ def get_likelihood(
         scheduler: Callable[..., torch.Tensor] | None = None,
         save_intermediates: bool | None = False,
         conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
         original_input_range: tuple | None = (0, 255),
         scaled_input_range: tuple | None = (0, 1),
         verbose: bool = True,
@@ -125,6 +145,7 @@ def get_likelihood(
             scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
             save_intermediates: save the intermediate spatial KL maps
             conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
             original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
             scaled_input_range: the [min,max] intensity range of the input data after scaling.
             verbose: if true, prints the progression bar of the sampling process.
@@ -137,6 +158,8 @@ def get_likelihood(
                 f"Likelihood computation is only compatible with DDPMScheduler,"
                 f" you are using {scheduler._get_name()}"
             )
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
         if verbose and has_tqdm:
             progress_bar = tqdm(scheduler.timesteps)
         else:
@@ -147,7 +170,11 @@ def get_likelihood(
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-            model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
+            if mode == "concat":
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)
+            else:
+                model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
             # get the model's predicted mean, and variance if it is predicted
             if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
                 model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
@@ -290,6 +317,7 @@ def __call__(
         noise: torch.Tensor,
         timesteps: torch.Tensor,
         condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
     ) -> torch.Tensor:
         """
         Implements the forward pass for a supervised training iteration.
@@ -301,12 +329,18 @@ def __call__(
             noise: random noise, of the same shape as the latent representation.
             timesteps: random timesteps.
             condition: conditioning for network input.
+            mode: Conditioning mode for the network.
         """
         with torch.no_grad():
             latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

         prediction = super().__call__(
-            inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition
+            inputs=latent,
+            diffusion_model=diffusion_model,
+            noise=noise,
+            timesteps=timesteps,
+            condition=condition,
+            mode=mode,
         )

         return prediction
@@ -321,6 +355,7 @@ def sample(
         save_intermediates: bool | None = False,
         intermediate_steps: int | None = 100,
         conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
         verbose: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
@@ -332,6 +367,7 @@ def sample(
             save_intermediates: whether to return intermediates along the sampling change
             intermediate_steps: if save_intermediates is True, saves every n steps
             conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
         """
         outputs = super().sample(
@@ -341,6 +377,7 @@ def sample(
             save_intermediates=save_intermediates,
             intermediate_steps=intermediate_steps,
             conditioning=conditioning,
+            mode=mode,
             verbose=verbose,
         )

@@ -369,6 +406,7 @@ def get_likelihood(
         scheduler: Callable[..., torch.Tensor] | None = None,
         save_intermediates: bool | None = False,
         conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
         original_input_range: tuple | None = (0, 255),
         scaled_input_range: tuple | None = (0, 1),
         verbose: bool = True,
@@ -385,6 +423,7 @@ def get_likelihood(
             scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
             save_intermediates: save the intermediate spatial KL maps
             conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
             original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
             scaled_input_range: the [min,max] intensity range of the input data after scaling.
             verbose: if true, prints the progression bar of the sampling process.
@@ -404,6 +443,7 @@ def get_likelihood(
             scheduler=scheduler,
             save_intermediates=save_intermediates,
             conditioning=conditioning,
+            mode=mode,
             verbose=verbose,
         )
         if save_intermediates and resample_latent_likelihoods:
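Usage note: in `mode="concat"` the inferer channel-concatenates `condition` onto the noisy input and calls the network with `context=None`, so the UNet must be built without cross-attention (`with_conditioning=False`, `cross_attention_dim=None`) and with `in_channels` enlarged by the number of conditioning channels. A minimal sketch of the training-style call follows; the UNet configuration values are illustrative, not taken from this PR:

```python
import torch

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

# Illustrative config: 1 image channel plus 2 conditioning channels are
# concatenated, so the UNet takes 3 input channels. Cross-attention is off.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1 + 2,  # image channels + concatenated conditioning channels
    out_channels=1,
    num_channels=(8, 8),
    norm_num_groups=8,
    attention_levels=(False, False),
    num_res_blocks=1,
    with_conditioning=False,
    cross_attention_dim=None,
)
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler=scheduler)

images = torch.randn(2, 1, 16, 16)
conditioning = torch.randn(2, 2, 16, 16)  # same spatial size as the images
noise = torch.randn_like(images)
timesteps = torch.randint(0, 1000, (2,)).long()

# The inferer noises the images and concatenates the conditioning internally.
prediction = inferer(
    inputs=images,
    diffusion_model=model,
    noise=noise,
    timesteps=timesteps,
    condition=conditioning,
    mode="concat",
)
print(prediction.shape)  # torch.Size([2, 1, 16, 16])
```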
diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py
index 6faf0e68..f3f9aa78 100644
--- a/tests/test_diffusion_inferer.py
+++ b/tests/test_diffusion_inferer.py
@@ -161,6 +161,62 @@ def test_normal_cdf(self):
         cdf_true = norm.cdf(x)
         torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)

+    @parameterized.expand(TEST_CASES)
+    def test_sampler_conditioned_concat(self, model_params, input_shape):
+        # copy the model_params dict to prevent from modifying test cases
+        model_params = model_params.copy()
+        n_concat_channel = 2
+        model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = False
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        conditioning_shape = list(input_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            conditioning=conditioning,
+            mode="concat",
+        )
+        self.assertEqual(len(intermediates), 10)
+
+    @parameterized.expand(TEST_CASES)
+    def test_call_conditioned_concat(self, model_params, input_shape):
+        # copy the model_params dict to prevent from modifying test cases
+        model_params = model_params.copy()
+        n_concat_channel = 2
+        model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = False
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        input = torch.randn(input_shape).to(device)
+        noise = torch.randn(input_shape).to(device)
+        conditioning_shape = list(input_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+        sample = inferer(
+            inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat"
+        )
+        self.assertEqual(sample.shape, input_shape)
+

 if __name__ == "__main__":
     unittest.main()
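The new tests above cover `__call__` and `sample`; `get_likelihood` takes the same `mode` flag but is not exercised with `mode="concat"` here. A sketch, continuing the example after the inferer diff (shapes illustrative):

```python
# Continuing the sketch above: likelihood estimation with concat conditioning.
# Note get_likelihood is only compatible with DDPMScheduler (existing check).
likelihood = inferer.get_likelihood(
    inputs=images,
    diffusion_model=model,
    scheduler=scheduler,
    conditioning=conditioning,
    mode="concat",
)
print(likelihood.shape)  # one log-likelihood value per batch element
```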
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index a049e1d5..296e9266 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -245,6 +245,86 @@ def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_para
         self.assertEqual(len(intermediates), 10)
         self.assertEqual(intermediates[0].shape[2:], input_shape[2:])

+    @parameterized.expand(TEST_CASES)
+    def test_prediction_shape_conditioned_concat(
+        self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
+    ):
+        if model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+
+        stage_2_params = stage_2_params.copy()
+        n_concat_channel = 3
+        stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        input = torch.randn(input_shape).to(device)
+        noise = torch.randn(latent_shape).to(device)
+        conditioning_shape = list(latent_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+        prediction = inferer(
+            inputs=input,
+            autoencoder_model=stage_1,
+            diffusion_model=stage_2,
+            noise=noise,
+            timesteps=timesteps,
+            condition=conditioning,
+            mode="concat",
+        )
+        self.assertEqual(prediction.shape, latent_shape)
+
+    @parameterized.expand(TEST_CASES)
+    def test_sample_shape_conditioned_concat(
+        self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
+    ):
+        if model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        stage_2_params = stage_2_params.copy()
+        n_concat_channel = 3
+        stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        conditioning_shape = list(latent_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        sample = inferer.sample(
+            input_noise=noise,
+            autoencoder_model=stage_1,
+            diffusion_model=stage_2,
+            scheduler=scheduler,
+            conditioning=conditioning,
+            mode="concat",
+        )
+        self.assertEqual(sample.shape, input_shape)
+

 if __name__ == "__main__":
     unittest.main()
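For the latent inferer, note that the tests build `conditioning` from `latent_shape`: the concatenation happens in latent space, so the conditioning must match the latent spatial size, and the stage-2 UNet's `in_channels` must equal the latent channels plus the conditioning channels. A minimal sampling sketch; the stage-1/stage-2 configurations below are illustrative stand-ins, not the TEST_CASES values from the test module:

```python
import torch

from generative.inferers import LatentDiffusionInferer
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

latent_channels = 3
n_concat_channel = 3

# Illustrative two-level autoencoder: 16x16 images map to 8x8 latents.
stage_1 = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(8, 8),
    latent_channels=latent_channels,
    num_res_blocks=1,
    norm_num_groups=8,
    attention_levels=(False, False),
)
stage_2 = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=latent_channels + n_concat_channel,  # latent + concat channels
    out_channels=latent_channels,
    num_channels=(8, 8),
    norm_num_groups=8,
    attention_levels=(False, False),
    num_res_blocks=1,
    with_conditioning=False,
    cross_attention_dim=None,
)
stage_1.eval()
stage_2.eval()

# Noise and conditioning are latent-shaped: the concat happens on the latent.
noise = torch.randn(1, latent_channels, 8, 8)
conditioning = torch.randn(1, n_concat_channel, 8, 8)

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

sample = inferer.sample(
    input_noise=noise,
    autoencoder_model=stage_1,
    diffusion_model=stage_2,
    scheduler=scheduler,
    conditioning=conditioning,
    mode="concat",
)
print(sample.shape)  # decoded back to image space, e.g. torch.Size([1, 1, 16, 16])
```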