314 fix transformer training #318
Conversation
Warvito left a comment:
Hi Mark, thanks for working on this tutorial. During the review, I found a few things that might be wrong in the inferer (besides the ones pointed out in the review). I will try to investigate further.
```diff
  # if we have not covered the full sequence we continue with inefficient looping
- if probs.shape[1] < latent.shape[1]:
+ if logits.shape[1] < latent.shape[1]:
```
I think there might be something wrong here, because `logits.shape[1] < latent.shape[1]` will always be true: the logits have size `spatial_shape[0] * spatial_shape[1]`, while the latent is one token longer because of the BOS token.
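A minimal sketch of the shape arithmetic in question (the spatial shape and logit width below are illustrative, not taken from the inferer):

```python
import torch

spatial_shape = (8, 8)
n = spatial_shape[0] * spatial_shape[1]  # number of latent tokens

# latent sequence with the BOS token prepended -> length n + 1
latent = torch.zeros(1, n + 1, dtype=torch.long)

# if the transformer only covers max_seq_len = n positions, the logits
# come back with sequence length n, so the condition is always true
logits = torch.zeros(1, n, 16)  # stand-in for the model output
assert logits.shape[1] < latent.shape[1]

# with max_seq_len = n + 1 (room for the BOS token) the shapes match
# and the "inefficient looping" branch is never taken
logits = torch.zeros(1, n + 1, 16)
assert not logits.shape[1] < latent.shape[1]
```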
Running the tests, I find that the logits and the latent have the same shape unless `transformer_model.max_seq_len < (spatial_shape[0] * spatial_shape[1]) + 1`; that is, the logits also have shape `(spatial_shape[0] * spatial_shape[1]) + 1`.
Yes, but usually `transformer.max_seq_len = spatial_shape[0] * spatial_shape[1]`. Here, are you considering cases where `max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1` because we pad with the BOS token?
Yeah, I've always been setting `max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1` in my networks. Have you been doing it without the `+1`? In all the tests for the `VQVAETransformerInferer` it is set to `(spatial_shape[0] * spatial_shape[1]) + 1`.
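For reference, a sketch of that convention, assuming a `DecoderOnlyTransformer`-style constructor that takes `num_tokens` and `max_seq_len`; the codebook size and attention hyperparameters are made up for illustration:

```python
import torch
from generative.networks.nets import DecoderOnlyTransformer

spatial_shape = (8, 8)
num_embeddings = 16  # hypothetical VQ-VAE codebook size

transformer = DecoderOnlyTransformer(
    num_tokens=num_embeddings + 1,  # codebook tokens plus the BOS token
    max_seq_len=spatial_shape[0] * spatial_shape[1] + 1,  # +1 for the BOS token
    attn_layers_dim=32,
    attn_layers_depth=2,
    attn_layers_heads=4,
)

# BOS-prefixed latent sequence of length n + 1
tokens = torch.zeros(1, spatial_shape[0] * spatial_shape[1] + 1, dtype=torch.long)
logits = transformer(tokens)  # expected shape: (1, n + 1, num_tokens)
```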
Fixes #314