Paligemma: fix static cache test #33941
Conversation
```diff
      if is_training:
          causal_mask = torch.triu(causal_mask, diagonal=1)
      else:
-         causal_mask = torch.zeros_like(causal_mask)
+         causal_mask[:, :sequence_length] = 0.0
```
This was the cause: the old line did not mask the dummy token positions from the static cache, so we always ended up with no mask on those positions.
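For context, a minimal standalone sketch of the difference (the names mirror the diff above, but this is illustrative rather than the actual PaliGemma modeling code): filling the mask with the dtype minimum and then zeroing only the first `sequence_length` key columns keeps the unused static-cache slots masked, whereas zeroing the whole tensor unmasks them.

```python
import torch

# Minimal sketch of the masking behaviour discussed above (illustrative shapes,
# not the library code).
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length = 4   # real prompt tokens
target_length = 8     # static cache length, includes unused/dummy slots

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)

# Old behaviour: zeroing the whole mask also unmasks the dummy static-cache slots.
old_mask = torch.zeros_like(causal_mask)
assert (old_mask[:, sequence_length:] == 0).all()  # dummy positions attendable (the bug)

# Fixed behaviour: only the first `sequence_length` key positions become attendable;
# the remaining static-cache slots stay at min_dtype, i.e. masked out.
causal_mask[:, :sequence_length] = 0.0
assert (causal_mask[:, sequence_length:] == min_dtype).all()  # dummy positions masked
```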
```python
    min_dtype=min_dtype,
    cache_position=cache_position,
    batch_size=batch_size,
    is_training=is_training,
```
If we come to prepare the static cache from here, then we cannot be in training mode. I don't think it is common to pass labels through generation, right?
I'm not seeing many use-cases indeed, except for maybe constrained generation and RL?
guess so, let's see what generation master (gante) thinks 😄
If labels in paligemma has the usual meaning (= the tensor with which we compute the loss, with no further uses), then generate will never use labels :D
nice, yes those are normal labels :)
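To make the thread above concrete, here is a hedged sketch of the "usual meaning" of labels (assumed signature, not the exact PaliGemma forward): labels only feed the loss computation, so generate() has no reason to forward them, and the static-cache mask prepared during generation can safely assume inference mode.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def compute_loss_from_labels(logits: torch.Tensor, labels: Optional[torch.Tensor] = None):
    """Sketch of the standard causal-LM use of `labels`: loss only, nothing else."""
    if labels is None:  # generate() never passes labels, so this is the path taken
        return None
    # Shift so that tokens < n predict token n, then compute cross-entropy.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
```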
molbap left a comment
LGTM, added comment on training case for generation :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
gante left a comment
LGTM, thank you for fixing 🤗
* fix
* not flaky anymore + style
What does this PR do?
Fixes the flaky test on paligemma from #33630
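For reference, a rough sketch of the kind of equivalence the static cache test checks (the checkpoint and prompt below are placeholders, not the exact test from the repo): greedy generation with `cache_implementation="static"` should match the default dynamic-cache generation.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"  # assumed checkpoint, for illustration only
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()

image = Image.new("RGB", (224, 224), color="white")
inputs = processor(text="describe the image", images=image, return_tensors="pt")

with torch.no_grad():
    dynamic_out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    static_out = model.generate(
        **inputs, max_new_tokens=20, do_sample=False, cache_implementation="static"
    )

# With the mask fix, the dummy static-cache positions stay masked and both cache
# implementations produce the same greedy continuation.
assert torch.equal(dynamic_out, static_out)
```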