Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ def __call__(
transformer_model: Callable[..., torch.Tensor],
ordering: Callable[..., torch.Tensor],
condition: torch.Tensor | None = None,
) -> torch.Tensor:
return_latent: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:
"""
Implements the forward pass for a supervised training iteration.

Expand All @@ -441,11 +442,13 @@ def __call__(
vqvae_model: first stage model.
transformer_model: autoregressive transformer model.
ordering: ordering of the quantised latent representation.
return_latent: also return latent sequence and spatial dim of the latent.
condition: conditioning for network input.
"""
with torch.no_grad():
latent = vqvae_model.index_quantize(inputs)

latent_spatial_dim = tuple(latent.shape[1:])
latent = latent.reshape(latent.shape[0], -1)
latent = latent[:, ordering.get_sequence_ordering()]

Expand All @@ -455,8 +458,10 @@ def __call__(
latent = latent.long()

prediction = transformer_model(x=latent, context=condition)

return prediction
if return_latent:
return prediction, latent, latent_spatial_dim
else:
return prediction

@torch.no_grad()
def sample(
Expand Down Expand Up @@ -521,3 +526,57 @@ def sample(
latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)

return vqvae_model.decode_samples(latent)

@torch.no_grad()
def get_likelihood(
self,
inputs: torch.Tensor,
vqvae_model: Callable[..., torch.Tensor],
transformer_model: Callable[..., torch.Tensor],
ordering: Callable[..., torch.Tensor],
condition: torch.Tensor | None = None,
resample_latent_likelihoods: bool = False,
resample_interpolation_mode: str = "trilinear",
) -> torch.Tensor:
"""
Computes the log-likelihoods of the latent representations of the input.

Args:
inputs: input images, NxCxHxW[xD]
vqvae_model: first stage model.
transformer_model: autoregressive transformer model.
ordering: ordering of the quantised latent representation.
condition: conditioning for network input.
resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
dimension as the input images.
resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
or 'trilinear;
"""
if resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
raise ValueError(
f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
)
logits, latent, latent_spatial_dim = self(
inputs=inputs,
vqvae_model=vqvae_model,
transformer_model=transformer_model,
ordering=ordering,
condition=condition,
return_latent=True,
)
all_probs = F.softmax(logits, dim=-1)
probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
probs = probs.squeeze(2)

# remove the probability of the starting token
probs = probs[:, 1:]
probs = torch.log(probs)

# reshape
probs = probs[:, ordering.get_revert_sequence_ordering()]
probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
if resample_latent_likelihoods:
resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
probs_reshaped = resizer(probs_reshaped[:, None, ...])

return probs_reshaped
57 changes: 55 additions & 2 deletions tests/test_vqvaetransformer_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
(2, 1, 8, 8),
(2, 5, 17),
(2, 2, 2),
],
[
{
Expand All @@ -70,13 +71,16 @@
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
(2, 1, 8, 8, 8),
(2, 9, 17),
(2, 2, 2, 2),
],
]


class TestVQVAETransformerInferer(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
def test_prediction_shape(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
):
stage_1 = VQVAE(**stage_1_params)
stage_2 = DecoderOnlyTransformer(**stage_2_params)
ordering = Ordering(**ordering_params)
Expand All @@ -91,7 +95,7 @@ def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params,

inferer = VQVAETransformerInferer()
prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
self.assertEqual(prediction.shape, latent_shape)
self.assertEqual(prediction.shape, logits_shape)

def test_sample(self):
stage_1 = VQVAE(
Expand Down Expand Up @@ -135,6 +139,55 @@ def test_sample(self):
)
self.assertEqual(sample.shape, (2, 1, 8, 8))

@parameterized.expand(TEST_CASES)
def test_get_likelihood(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
):
stage_1 = VQVAE(**stage_1_params)
stage_2 = DecoderOnlyTransformer(**stage_2_params)
ordering = Ordering(**ordering_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)

inferer = VQVAETransformerInferer()
likelihood = inferer.get_likelihood(
inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
)
self.assertEqual(likelihood.shape, latent_shape)

@parameterized.expand(TEST_CASES)
def test_get_likelihood_resampling(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
):
stage_1 = VQVAE(**stage_1_params)
stage_2 = DecoderOnlyTransformer(**stage_2_params)
ordering = Ordering(**ordering_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)

inferer = VQVAETransformerInferer()
likelihood = inferer.get_likelihood(
inputs=input,
vqvae_model=stage_1,
transformer_model=stage_2,
ordering=ordering,
resample_latent_likelihoods=True,
resample_interpolation_mode="nearest",
)
self.assertEqual(likelihood.shape, input_shape)


if __name__ == "__main__":
unittest.main()