Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,14 +580,16 @@ def get_likelihood(
# get the first batch, up to max_seq_length, efficiently
logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
probs = F.softmax(logits, dim=-1)
probs = torch.gather(probs, 2, latent[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
# target token for each set of logits is the next token along
target = latent[:, 1:]
probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)

# if we have not covered the full sequence we continue with inefficient looping
if probs.shape[1] < latent.shape[1]:
if logits.shape[1] < latent.shape[1]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think something might be wrong here: the condition `logits.shape[1] < latent.shape[1]` will always be true, since the logits have size spatial_shape[0] * spatial_shape[1] and the latent has that size + 1 (because of the BOS token).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running the tests, I find that the logits and the latents have the same shape unless
transformer_model.max_seq_len < (spatial_shape[0] * spatial_shape[1]) + 1;
that is, the logits also have shape (spatial_shape[0] * spatial_shape[1]) + 1.

Copy link
Collaborator

@Warvito Warvito Mar 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but usually transformer.max_seq_len = (spatial_shape[0] * spatial_shape[1]). Are you considering cases here where max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 because we pad with the BOS token?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I've always been setting max_seq_len = (spatial_shape[0] * spatial_shape[1])+1 in my networks. Have you been doing it without the +1? In all the tests for the VQVAETransformerInferer it is set to (spatial_shape[0] * spatial_shape[1])+1

if verbose and has_tqdm:
progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len + 1))
progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len))
else:
progress_bar = iter(range(transformer_model.max_seq_len, seq_len + 1))
progress_bar = iter(range(transformer_model.max_seq_len, seq_len))

for i in progress_bar:
idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
Expand All @@ -598,11 +600,11 @@ def get_likelihood(
# apply softmax to convert logits to (normalized) probabilities
p = F.softmax(logits, dim=-1)
# select correct values and append
p = torch.gather(p, 1, idx_cond[:, -1].unsqueeze(1))
p = torch.gather(p, 1, target[:, i].unsqueeze(1))

probs = torch.cat((probs, p), dim=1)

# remove starting token probability
probs = probs[:, 1:]
# convert to log-likelihood
probs = torch.log(probs)

# reshape
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.4
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -313,11 +313,11 @@

# %% [markdown]
# ### Transformer Training
# We will train the Transformer for 5 epochs.
# We will train the Transformer for 20 epochs.

# %%
n_epochs = 5
val_interval = 2
n_epochs = 20
val_interval = 5
epoch_losses = []
val_epoch_losses = []
vqvae_model.eval()
Expand All @@ -337,7 +337,8 @@
logits, quantizations_target, _ = inferer(images, vqvae_model, transformer_model, ordering, return_latent=True)
logits = logits.transpose(1, 2)

loss = ce_loss(logits, quantizations_target)
# train the transformer to predict token n+1 using tokens 0-n
loss = ce_loss(logits[:, :, :-1], quantizations_target[:, 1:])

loss.backward()
optimizer.step()
Expand All @@ -360,10 +361,22 @@
)
logits = logits.transpose(1, 2)

loss = ce_loss(logits, quantizations_target)
loss = ce_loss(logits[:, :, :-1], quantizations_target[:, 1:])

val_loss += loss.item()

# get sample
sample = inferer.sample(
vqvae_model=vqvae_model,
transformer_model=transformer_model,
ordering=ordering,
latent_spatial_dim=(spatial_shape[0], spatial_shape[1]),
starting_tokens=vqvae_model.num_embeddings * torch.ones((1, 1), device=device),
)
plt.imshow(sample[0, 0, ...].cpu().detach())
plt.title(f"Sample epoch {epoch}")
plt.show()
val_loss /= val_step
val_epoch_losses.append(val_loss)
val_loss /= val_step
val_epoch_losses.append(val_loss)

Expand Down Expand Up @@ -424,13 +437,13 @@
ood_likelihoods = np.concatenate(ood_likelihoods)

# %% [markdown]
# ## Log-likehood plot
# ## Log-likelihood plot
#
# Here, we plot the log-likelihood of the images. The lower the log-likelihood, the less likely the image is to belong to the training set.

# %%
sns.kdeplot(in_likelihoods, color="dodgerblue", bw_adjust=50, label="In-distribution")
sns.kdeplot(ood_likelihoods, color="deeppink", bw_adjust=1, label="OOD")
sns.kdeplot(in_likelihoods, color="dodgerblue", bw_adjust=1, label="In-distribution")
sns.kdeplot(ood_likelihoods, color="deeppink", bw_adjust=10, label="OOD")
plt.legend()
plt.xlabel("Log-likelihood")

Expand Down