This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
generative/inferers/inferer.py (50 additions, 14 deletions)
@@ -454,9 +454,16 @@ def __call__(
latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
latent = latent.long()

prediction = transformer_model(x=latent, context=condition)
# train on a random crop of the sequence if it is longer than max_seq_len
seq_len = latent.shape[1]
max_seq_len = transformer_model.max_seq_len
if max_seq_len < seq_len:
start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()
else:
start = 0
prediction = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)
if return_latent:
return prediction, latent, latent_spatial_dim
return prediction, latent[:, start : start + max_seq_len], latent_spatial_dim
else:
return prediction

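The change above trains on a random window of the token sequence whenever the flattened latent (plus the BOS token) is longer than the transformer's `max_seq_len`. A minimal, self-contained sketch of just that windowing step, assuming a `(batch, seq_len)` tensor of token ids; the helper name `crop_to_max_seq_len` is illustrative, not part of the library:

```python
import torch


def crop_to_max_seq_len(latent: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    """Return a random contiguous window of at most max_seq_len tokens from (batch, seq_len) ids."""
    seq_len = latent.shape[1]
    if max_seq_len < seq_len:
        # any start in [0, seq_len - max_seq_len] keeps the window fully inside the sequence
        start = torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()
    else:
        start = 0
    return latent[:, start : start + max_seq_len]


# two sequences of length 10 cropped for a model with max_seq_len=4
window = crop_to_max_seq_len(torch.arange(20).reshape(2, 10), max_seq_len=4)
assert window.shape == (2, 4)
```
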
@@ -534,6 +541,7 @@ def get_likelihood(
condition: torch.Tensor | None = None,
resample_latent_likelihoods: bool = False,
resample_interpolation_mode: str = "nearest",
verbose: bool = False,
) -> torch.Tensor:
"""
Computes the log-likelihoods of the latent representations of the input.
@@ -548,24 +556,52 @@
dimension as the input images.
resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
or 'trilinear'.
verbose: if true, prints the progress bar of the likelihood computation.

"""
if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
raise ValueError(
f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
)
logits, latent, latent_spatial_dim = self(
inputs=inputs,
vqvae_model=vqvae_model,
transformer_model=transformer_model,
ordering=ordering,
condition=condition,
return_latent=True,
)
all_probs = F.softmax(logits, dim=-1)
probs = torch.gather(all_probs, 2, latent.unsqueeze(2))
probs = probs.squeeze(2)

# remove the probability of the starting token
with torch.no_grad():
latent = vqvae_model.index_quantize(inputs)

latent_spatial_dim = tuple(latent.shape[1:])
latent = latent.reshape(latent.shape[0], -1)
latent = latent[:, ordering.get_sequence_ordering()]
seq_len = math.prod(latent_spatial_dim)

# Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
# Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
latent = latent.long()

# score the first chunk of the sequence, up to max_seq_len tokens, in a single forward pass
logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
probs = F.softmax(logits, dim=-1)
probs = torch.gather(probs, 2, latent[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)

# if the full sequence has not been covered yet, continue token by token with a sliding window
if probs.shape[1] < latent.shape[1]:
if verbose and has_tqdm:
progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len + 1))
else:
progress_bar = iter(range(transformer_model.max_seq_len, seq_len + 1))

for i in progress_bar:
idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
# forward the model to get the logits for the index in the sequence
logits = transformer_model(x=idx_cond, context=condition)
# pluck the logits at the final step
logits = logits[:, -1, :]
# apply softmax to convert logits to (normalized) probabilities
p = F.softmax(logits, dim=-1)
# select correct values and append
p = torch.gather(p, 1, idx_cond[:, -1].unsqueeze(1))
probs = torch.cat((probs, p), dim=1)

# remove starting token probability
probs = probs[:, 1:]
probs = torch.log(probs)

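To make the new likelihood path easier to follow outside the diff: the input is re-quantized to token ids, flattened with the chosen ordering, prefixed with the BOS token (`vqvae_model.num_embeddings`), and then scored; the first `max_seq_len` tokens are handled in one forward pass and any remaining tokens are scored one at a time from the last position of a sliding window. Below is a minimal sketch of that scoring loop, assuming a stand-in model callable; `toy_model` and `sliding_window_token_probs` are illustrative names, not library API:

```python
import torch
import torch.nn.functional as F


def sliding_window_token_probs(toy_model, latent: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    """latent: (batch, seq_len) token ids, BOS first; returns per-token probabilities without the BOS slot."""
    # score the first window in a single forward pass
    logits = toy_model(latent[:, :max_seq_len])  # (batch, window, num_tokens)
    probs = F.softmax(logits, dim=-1)
    probs = torch.gather(probs, 2, latent[:, :max_seq_len].unsqueeze(2)).squeeze(2)

    # slide a max_seq_len window over the rest, keeping only the last position's distribution
    for i in range(max_seq_len, latent.shape[1]):
        idx_cond = latent[:, i + 1 - max_seq_len : i + 1]
        p = F.softmax(toy_model(idx_cond)[:, -1, :], dim=-1)
        p = torch.gather(p, 1, idx_cond[:, -1].unsqueeze(1))
        probs = torch.cat((probs, p), dim=1)

    return probs[:, 1:]  # drop the BOS position


# tiny stand-in for the transformer: token ids -> logits over 17 tokens (16 codes + BOS)
toy_model = torch.nn.Sequential(torch.nn.Embedding(17, 8), torch.nn.Linear(8, 17))
latent = torch.randint(0, 17, (2, 9))  # batch of 2: BOS + 8 latent tokens
per_token = sliding_window_token_probs(toy_model, latent, max_seq_len=4)
assert per_token.shape == (2, 8)
```

The sliding loop is the part gated by the new `verbose` flag above, which only wraps the range in `tqdm` when it is available.
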
tests/test_vqvaetransformer_inferer.py (91 additions, 0 deletions)
@@ -97,6 +97,30 @@ def test_prediction_shape(
prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
self.assertEqual(prediction.shape, logits_shape)

@parameterized.expand(TEST_CASES)
def test_prediction_shape_shorter_sequence(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
):
stage_1 = VQVAE(**stage_1_params)
max_seq_len = 3
stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
stage_2_params_shorter["max_seq_len"] = max_seq_len
stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
ordering = Ordering(**ordering_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)

inferer = VQVAETransformerInferer()
prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2])
self.assertEqual(prediction.shape, cropped_logits_shape)

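The cropped shape asserted above follows directly from the change in `__call__`: the returned logits cover at most `max_seq_len` positions of the BOS-prefixed sequence. A tiny illustration with hypothetical numbers (a 2x2 latent grid and 16 + 1 tokens are assumptions for the example, not values taken from TEST_CASES):

```python
# hypothetical numbers, chosen only to illustrate the shape rule
batch, num_tokens, max_seq_len = 2, 16 + 1, 3
bos_prefixed_len = 2 * 2 + 1  # flattened 2x2 latent grid plus the BOS token
expected_logits_shape = (batch, min(bos_prefixed_len, max_seq_len), num_tokens)
assert expected_logits_shape == (2, 3, 17)
```
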
def test_sample(self):
stage_1 = VQVAE(
spatial_dims=2,
@@ -139,6 +163,48 @@ def test_sample(self):
)
self.assertEqual(sample.shape, (2, 1, 8, 8))

def test_sample_shorter_sequence(self):
stage_1 = VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(8, 8),
num_res_channels=(8, 8),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
num_res_layers=1,
num_embeddings=16,
embedding_dim=8,
)
stage_2 = DecoderOnlyTransformer(
num_tokens=16 + 1,
max_seq_len=3,
attn_layers_dim=4,
attn_layers_depth=2,
attn_layers_heads=1,
with_cross_attention=False,
)
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

inferer = VQVAETransformerInferer()

starting_token = 16 # from stage_1 num_embeddings

sample = inferer.sample(
latent_spatial_dim=(2, 2),
starting_tokens=starting_token * torch.ones((2, 1), device=device),
vqvae_model=stage_1,
transformer_model=stage_2,
ordering=ordering,
)
self.assertEqual(sample.shape, (2, 1, 8, 8))

@parameterized.expand(TEST_CASES)
def test_get_likelihood(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
@@ -161,6 +227,31 @@
)
self.assertEqual(likelihood.shape, latent_shape)

@parameterized.expand(TEST_CASES)
def test_get_likelihood_shorter_sequence(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
):
stage_1 = VQVAE(**stage_1_params)
max_seq_len = 3
stage_2_params_shorter = {k: v for k, v in stage_2_params.items()}
stage_2_params_shorter["max_seq_len"] = max_seq_len
stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
ordering = Ordering(**ordering_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)

inferer = VQVAETransformerInferer()
likelihood = inferer.get_likelihood(
inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
)
self.assertEqual(likelihood.shape, latent_shape)

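The tests above do not exercise the new `verbose` flag directly; a hedged usage sketch of the added argument, assuming `inferer`, `input`, `stage_1`, `stage_2`, and `ordering` are built as in the test above:

```python
# illustrative call only; objects are assumed to be constructed as in the preceding test
likelihood = inferer.get_likelihood(
    inputs=input,
    vqvae_model=stage_1,
    transformer_model=stage_2,
    ordering=ordering,
    verbose=True,  # shows a tqdm progress bar when the latent sequence exceeds max_seq_len
)
```
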
@parameterized.expand(TEST_CASES)
def test_get_likelihood_resampling(
self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape