This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Conversation

@marksgraham (Collaborator)

Fixes #314

@Warvito self-requested a review on March 17, 2023 21:04

@Warvito (Collaborator) left a comment


Hi Mark, thanks for working on this tutorial. During the review, I found a few things that might be wrong in the inferer (besides the ones pointed out in the review). I will try to investigate further.


 # if we have not covered the full sequence we continue with inefficient looping
-if probs.shape[1] < latent.shape[1]:
+if logits.shape[1] < latent.shape[1]:
@Warvito (Collaborator)


I think there might be something wrong here, because this logits.shape[1] < latent.shape[1] check will always be true, since the logits have size spatial_shape[0] * spatial_shape[1] and the latent will be that + 1 (the BOS token).
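
For illustration, a minimal sketch of the shape arithmetic behind this point (not the library's inferer code; the grid size and the assumption that the transformer output length is capped at max_seq_len are hypothetical). With max_seq_len = spatial_shape[0] * spatial_shape[1], the logits cover H*W positions while the BOS-prefixed latent has H*W + 1, so the guard is always true:

```python
# Illustrative shape arithmetic only -- not the VQVAETransformerInferer code.
spatial_shape = (8, 8)                         # hypothetical latent grid size
seq_len = spatial_shape[0] * spatial_shape[1]  # 64 tokens from the quantized latent

max_seq_len = seq_len      # convention with no room for the BOS token
latent_len = seq_len + 1   # latent sequence after the BOS token is prepended
logits_len = min(latent_len, max_seq_len)  # assumed: transformer output capped at max_seq_len

print(logits_len < latent_len)  # True (64 < 65) -- the guard always fires
```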

@marksgraham (Collaborator, Author)


Running the tests, I find the logits and the latents have the same shape unless transformer_model.max_seq_len < (spatial_shape[0] * spatial_shape[1]) + 1; that is, the logits also have shape (spatial_shape[0] * spatial_shape[1]) + 1.

@Warvito (Collaborator) Mar 17, 2023


Yes, but usually transformer.max_seq_len = spatial_shape[0] * spatial_shape[1]. Here, are you considering cases where max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 because we pad with the BOS token?

@marksgraham (Collaborator, Author)


Yeah, I've always been setting max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 in my networks. Have you been doing it without the +1? In all the tests for the VQVAETransformerInferer it is set to (spatial_shape[0] * spatial_shape[1]) + 1.
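
For comparison, a sketch of the convention described in this thread (again hypothetical, not the library code), where max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 makes the logits and the latent the same length, so the guard only fires when max_seq_len genuinely truncates the sequence:

```python
# Illustrative shape arithmetic only -- not the VQVAETransformerInferer code.
spatial_shape = (8, 8)                         # hypothetical latent grid size
seq_len = spatial_shape[0] * spatial_shape[1]  # 64 latent tokens
latent_len = seq_len + 1                       # BOS token prepended

for max_seq_len in (seq_len + 1, seq_len // 2):
    logits_len = min(latent_len, max_seq_len)  # assumed: output capped at max_seq_len
    print(max_seq_len, logits_len < latent_len)
# 65 False -> shapes match, no extra looping needed
# 32 True  -> sequence truncated, fall back to the slower loop
```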

@marksgraham merged commit 78fde33 into main on Mar 20, 2023
@Warvito deleted the 314_fix_transformer_training branch on March 20, 2023 22:43