50 changes: 45 additions & 5 deletions generative/inferers/inferer.py
@@ -44,6 +44,7 @@ def __call__(
noise: torch.Tensor,
timesteps: torch.Tensor,
condition: torch.Tensor | None = None,
mode: str = "crossattn",
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.
@@ -54,8 +55,15 @@ def __call__(
noise: random noise, of the same shape as the input.
timesteps: random timesteps.
condition: Conditioning for network input.
mode: Conditioning mode for the network.
"""
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")

noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
if mode == "concat":
noisy_image = torch.cat([noisy_image, condition], dim=1)
condition = None
prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)

return prediction
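
For intuition, a minimal pure-PyTorch sketch of what concat-mode conditioning does at the tensor level (all shapes below are illustrative assumptions, not values taken from this PR):

import torch

# Illustrative shapes: batch 2, 1 image channel, 2 conditioning channels, 64x64.
noisy_image = torch.randn(2, 1, 64, 64)
condition = torch.randn(2, 2, 64, 64)

# "concat" mode stacks the conditioning onto the channel axis, so the
# diffusion model must be built with in_channels = 1 + 2 = 3, and the
# context argument is left as None (cross-attention is bypassed).
model_input = torch.cat([noisy_image, condition], dim=1)
print(model_input.shape)  # torch.Size([2, 3, 64, 64])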
@@ -69,6 +77,7 @@ def sample(
save_intermediates: bool | None = False,
intermediate_steps: int | None = 100,
conditioning: torch.Tensor | None = None,
mode: str = "crossattn",
verbose: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
@@ -79,8 +88,12 @@
save_intermediates: whether to return intermediates along the sampling chain
intermediate_steps: if save_intermediates is True, saves every n steps
conditioning: Conditioning for network input.
mode: Conditioning mode for the network.
verbose: if true, prints the progress bar of the sampling process.
"""
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")

if not scheduler:
scheduler = self.scheduler
image = input_noise
@@ -91,9 +104,15 @@
intermediates = []
for t in progress_bar:
# 1. predict noise model_output
model_output = diffusion_model(
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
)
if mode == "concat":
model_input = torch.cat([image, conditioning], dim=1)
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
)
else:
model_output = diffusion_model(
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
)

# 2. compute previous image: x_t -> x_t-1
image, _ = scheduler.step(model_output, t, image)
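
A usage sketch of the new sampling path. The TinyModel stand-in and every shape are assumptions for illustration; in practice a DiffusionModelUNet with widened in_channels, cross_attention_dim=None and with_conditioning=False takes its place, exactly as in the tests added below. Import paths follow this repository's layout:

import torch
from generative.inferers import DiffusionInferer
from generative.networks.schedulers import DDIMScheduler

class TinyModel(torch.nn.Module):
    # Stand-in for a diffusion UNet: maps 3 input channels
    # (1 image + 2 conditioning) back to 1 output channel.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x, timesteps=None, context=None):
        return self.conv(x)

noise = torch.randn(2, 1, 32, 32)
conditioning = torch.randn(2, 2, 32, 32)  # must match the noise spatially

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=10)
inferer = DiffusionInferer(scheduler=scheduler)

sample = inferer.sample(
    input_noise=noise,
    diffusion_model=TinyModel(),
    scheduler=scheduler,
    conditioning=conditioning,
    mode="concat",
)
print(sample.shape)  # torch.Size([2, 1, 32, 32])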
@@ -112,6 +131,7 @@ def get_likelihood(
scheduler: Callable[..., torch.Tensor] | None = None,
save_intermediates: bool | None = False,
conditioning: torch.Tensor | None = None,
mode: str = "crossattn",
original_input_range: tuple | None = (0, 255),
scaled_input_range: tuple | None = (0, 1),
verbose: bool = True,
@@ -125,6 +145,7 @@
scheduler: diffusion scheduler. If none is provided, the class attribute scheduler will be used.
save_intermediates: save the intermediate spatial KL maps
conditioning: Conditioning for network input.
mode: Conditioning mode for the network.
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
scaled_input_range: the [min,max] intensity range of the input data after scaling.
verbose: if true, prints the progress bar of the sampling process.
@@ -137,6 +158,8 @@
f"Likelihood computation is only compatible with DDPMScheduler,"
f" you are using {scheduler._get_name()}"
)
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")
if verbose and has_tqdm:
progress_bar = tqdm(scheduler.timesteps)
else:
@@ -147,7 +170,11 @@
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
if mode == "concat":
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)
else:
model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
# get the model's predicted mean, and variance if it is predicted
if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
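
The torch.split on the line above serves networks that also predict a variance; a self-contained sketch of the mechanics (shapes are illustrative assumptions):

import torch

# With variance_type "learned" or "learned_range", the network emits twice
# the input channel count; splitting along dim=1 recovers mean and variance.
in_channels = 1
model_output = torch.randn(2, in_channels * 2, 64, 64)
mean, predicted_variance = torch.split(model_output, in_channels, dim=1)
print(mean.shape, predicted_variance.shape)  # both torch.Size([2, 1, 64, 64])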
@@ -290,6 +317,7 @@ def __call__(
noise: torch.Tensor,
timesteps: torch.Tensor,
condition: torch.Tensor | None = None,
mode: str = "crossattn",
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.
@@ -301,12 +329,18 @@
noise: random noise, of the same shape as the latent representation.
timesteps: random timesteps.
condition: conditioning for network input.
mode: Conditioning mode for the network.
"""
with torch.no_grad():
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

prediction = super().__call__(
inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition
inputs=latent,
diffusion_model=diffusion_model,
noise=noise,
timesteps=timesteps,
condition=condition,
mode=mode,
)

return prediction
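
A sketch of the latent counterpart; the point to note, mirrored by the tests below, is that concat conditioning must match the latent spatial size rather than the image size. The stand-in modules and all shapes are assumptions for illustration:

import torch
from generative.inferers import LatentDiffusionInferer
from generative.networks.schedulers import DDPMScheduler

class TinyAutoencoder(torch.nn.Module):
    # Stand-in exposing the encode_stage_2_inputs hook the inferer calls;
    # the "latent" here is simply the image downsampled by a factor of 2.
    def encode_stage_2_inputs(self, x):
        return torch.nn.functional.avg_pool2d(x, 2)

class TinyLatentModel(torch.nn.Module):
    # Stand-in diffusion model: 1 latent channel + 3 conditioning channels
    # in, 1 channel out.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 1, kernel_size=3, padding=1)

    def forward(self, x, timesteps=None, context=None):
        return self.conv(x)

image = torch.randn(2, 1, 32, 32)
noise = torch.randn(2, 1, 16, 16)         # latent-shaped noise
conditioning = torch.randn(2, 3, 16, 16)  # latent-shaped conditioning
timesteps = torch.randint(0, 10, (2,)).long()

scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
prediction = inferer(
    inputs=image,
    autoencoder_model=TinyAutoencoder(),
    diffusion_model=TinyLatentModel(),
    noise=noise,
    timesteps=timesteps,
    condition=conditioning,
    mode="concat",
)
print(prediction.shape)  # torch.Size([2, 1, 16, 16])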
@@ -321,6 +355,7 @@ def sample(
save_intermediates: bool | None = False,
intermediate_steps: int | None = 100,
conditioning: torch.Tensor | None = None,
mode: str = "crossattn",
verbose: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
@@ -332,6 +367,7 @@
save_intermediates: whether to return intermediates along the sampling chain
intermediate_steps: if save_intermediates is True, saves every n steps
conditioning: Conditioning for network input.
mode: Conditioning mode for the network.
verbose: if true, prints the progress bar of the sampling process.
"""
outputs = super().sample(
Expand All @@ -341,6 +377,7 @@ def sample(
save_intermediates=save_intermediates,
intermediate_steps=intermediate_steps,
conditioning=conditioning,
mode=mode,
verbose=verbose,
)

@@ -369,6 +406,7 @@ def get_likelihood(
scheduler: Callable[..., torch.Tensor] | None = None,
save_intermediates: bool | None = False,
conditioning: torch.Tensor | None = None,
mode: str = "crossattn",
original_input_range: tuple | None = (0, 255),
scaled_input_range: tuple | None = (0, 1),
verbose: bool = True,
@@ -385,6 +423,7 @@
scheduler: diffusion scheduler. If none is provided, the class attribute scheduler will be used.
save_intermediates: save the intermediate spatial KL maps
conditioning: Conditioning for network input.
mode: Conditioning mode for the network.
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
scaled_input_range: the [min,max] intensity range of the input data after scaling.
verbose: if true, prints the progress bar of the sampling process.
@@ -404,6 +443,7 @@
scheduler=scheduler,
save_intermediates=save_intermediates,
conditioning=conditioning,
mode=mode,
verbose=verbose,
)
if save_intermediates and resample_latent_likelihoods:
56 changes: 56 additions & 0 deletions tests/test_diffusion_inferer.py
@@ -161,6 +161,62 @@ def test_normal_cdf(self):
cdf_true = norm.cdf(x)
torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)

@parameterized.expand(TEST_CASES)
def test_sampler_conditioned_concat(self, model_params, input_shape):
# copy the model_params dict to prevent from modifying test cases
model_params = model_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
noise = torch.randn(input_shape).to(device)
conditioning_shape = list(input_shape)
conditioning_shape[1] = n_concat_channel
conditioning = torch.randn(conditioning_shape).to(device)
scheduler = DDIMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
sample, intermediates = inferer.sample(
input_noise=noise,
diffusion_model=model,
scheduler=scheduler,
save_intermediates=True,
intermediate_steps=1,
conditioning=conditioning,
mode="concat",
)
self.assertEqual(len(intermediates), 10)

@parameterized.expand(TEST_CASES)
def test_call_conditioned_concat(self, model_params, input_shape):
# copy the model_params dict to prevent from modifying test cases
model_params = model_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
input = torch.randn(input_shape).to(device)
noise = torch.randn(input_shape).to(device)
conditioning_shape = list(input_shape)
conditioning_shape[1] = n_concat_channel
conditioning = torch.randn(conditioning_shape).to(device)
scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
sample = inferer(
inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat"
)
self.assertEqual(sample.shape, input_shape)


if __name__ == "__main__":
unittest.main()
80 changes: 80 additions & 0 deletions tests/test_latent_diffusion_inferer.py
@@ -245,6 +245,86 @@ def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_para
self.assertEqual(len(intermediates), 10)
self.assertEqual(intermediates[0].shape[2:], input_shape[2:])

@parameterized.expand(TEST_CASES)
def test_prediction_shape_conditioned_concat(
self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
):
if model_type == "AutoencoderKL":
stage_1 = AutoencoderKL(**autoencoder_params)
if model_type == "VQVAE":
stage_1 = VQVAE(**autoencoder_params)

stage_2_params = stage_2_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
stage_2 = DiffusionModelUNet(**stage_2_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)
noise = torch.randn(latent_shape).to(device)
conditioning_shape = list(latent_shape)
conditioning_shape[1] = n_concat_channel
conditioning = torch.randn(conditioning_shape).to(device)

scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)

timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
prediction = inferer(
inputs=input,
autoencoder_model=stage_1,
diffusion_model=stage_2,
noise=noise,
timesteps=timesteps,
condition=conditioning,
mode="concat",
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(TEST_CASES)
def test_sample_shape_conditioned_concat(
self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
):
if model_type == "AutoencoderKL":
stage_1 = AutoencoderKL(**autoencoder_params)
if model_type == "VQVAE":
stage_1 = VQVAE(**autoencoder_params)
stage_2_params = stage_2_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
stage_2 = DiffusionModelUNet(**stage_2_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

noise = torch.randn(latent_shape).to(device)
conditioning_shape = list(latent_shape)
conditioning_shape[1] = n_concat_channel
conditioning = torch.randn(conditioning_shape).to(device)

scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)

sample = inferer.sample(
input_noise=noise,
autoencoder_model=stage_1,
diffusion_model=stage_2,
scheduler=scheduler,
conditioning=conditioning,
mode="concat",
)
self.assertEqual(sample.shape, input_shape)


if __name__ == "__main__":
unittest.main()