XGLM - Fix Softmax NaNs when using FP16 (#18057)
Conversation
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
We (HF team) have to remember to add this back once Bart takes the same fix.
younesbelkada left a comment:
LGTM, thanks a lot for the fix 🚀!
    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
@stas00 is this operation costly? I'm wondering how expensive such a max operation is.
What are the contenders? At least max, clamp, and where, but probably others as well.
In [13]: a = torch.tensor([5,-1e20])
In [14]: b = torch.tensor(torch.finfo(torch.float16).min)
In [16]: torch.clamp(a, min=b)
Out[16]: tensor([ 5.0000e+00, -6.5504e+04])
In [21]: torch.where(a > b, a, b)
Out[21]: tensor([ 5.0000e+00, -6.5504e+04])
In [22]: torch.max(a, b)
Out[22]: tensor([ 5.0000e+00, -6.5504e+04])
Benchmark:
$ cat clamp-where-max.py
import torch.utils.benchmark as benchmark
import torch
a = torch.empty(512)
b = torch.tensor(torch.finfo(torch.float16).min)
t0 = benchmark.Timer(
stmt='torch.clamp(a, b)',
setup='',
globals=dict(a=a, b=b),
)
t1 = benchmark.Timer(
stmt='torch.max(a, b)',
setup='',
globals=dict(a=a, b=b),
)
t2 = benchmark.Timer(
stmt='torch.where(a > b, a, b)',
setup='',
globals=dict(a=a, b=b),
)
print(t0.timeit(1000))
print(t1.timeit(1000))
print(t2.timeit(1000))
$ python clamp-where-max.py
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2c77739040>
torch.clamp(a, b)
1.60 us
1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2c77739d60>
torch.max(a, b)
1.60 us
1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2c77739040>
torch.where(a > b, a, b)
4.36 us
1 measurement, 1000 runs , 1 thread
So max is tied with clamp, and where is slow.
But make sure to benchmark with the actual dimensions, though it shouldn't make much of a difference, I think.
(edited: I got the a wrong initially)
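Following up on that caveat, here is a sketch of the same comparison rerun on attention-shaped tensors (the batch/head/sequence sizes below are assumed for illustration, not taken from the thread):

```python
import torch
import torch.utils.benchmark as benchmark

# Assumed attention-weight shape: (batch, num_heads, tgt_len, src_len).
a = torch.randn(2, 8, 128, 128)
b = torch.tensor(torch.finfo(torch.float16).min)

# Note: torch is imported automatically inside benchmark.Timer's namespace,
# so only a and b need to be passed via globals.
for stmt in ("torch.clamp(a, b)", "torch.max(a, b)", "torch.where(a > b, a, b)"):
    t = benchmark.Timer(stmt=stmt, globals=dict(a=a, b=b))
    print(t.timeit(100))
```

The three expressions are elementwise-equivalent here (clamp's second positional argument is the min bound), so only speed differs; that can be sanity-checked with `torch.equal(torch.max(a, b), torch.clamp(a, b))`.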
@patil-suraj I think only your check is missing!
Sorry for being so late here @gsarti! Merged master into it to ping CircleCI here.
Hey @gsarti - it seems like a test is failing now: with
I noticed this when running the code. My understanding is that setting
I think we don't need this line.
Instead, we can change this part to:

if attn_weights.dtype == torch.float16:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
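For context, the suggested pattern runs the softmax in fp32 (so the exponentials and the row sum are accumulated at full precision) and rounds back to fp16 only once at the end. A standalone sketch of the pattern, with random weights and an assumed shape for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumed shape (bsz, num_heads, tgt_len, src_len); values are random for illustration.
attn_weights = torch.randn(1, 4, 8, 8).half()

# Suggested pattern: compute softmax in fp32, then cast back to fp16 a single time.
if attn_weights.dtype == torch.float16:
    attn_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
else:
    attn_probs = nn.functional.softmax(attn_weights, dim=-1)

print(attn_probs.dtype)  # torch.float16
```

The `dtype=torch.float32` argument makes `softmax` upcast its input before exponentiating, which avoids intermediate fp16 rounding; the final `.to(...)` restores the original dtype for the rest of the forward pass.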
I think that we should fix the line on OPT too:
in case attention_mask is set to None, the forward pass will fail as described in #18057 (comment)
But I think that the issue has never been reported since attention_mask is never None:
Good catch! Surprisingly, we don't have test failure for OPT due to this.
You answered my question before I asked it 😆
Hi @gsarti, sorry for being late on this PR. I re-opened it and gave some suggestions for a fix to the failing test. Would you like to update this PR after rebasing your working branch on an updated
Force-pushed 84af0ee to 87ef76e
Hi @gsarti, I made the necessary change to pass the tests and pushed to your branch directly. The remaining failing test is unrelated to this PR, but I will wait until tomorrow to check again, then I will merge. cc @patrickvonplaten and @younesbelkada
Thanks a lot for the fix @ydshieh!!
Force-pushed 87ef76e to 1eb0953
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
What does this PR do?
Fixes #18049 following the exact same procedure used in #17437. Besides the added test, I also evaluated the fix on my personal use case and found the behavior of the fixed model to be consistent when performing single or batched generation.
Who can review?
@patil-suraj @ydshieh @patrickvonplaten