
XGLM - Fix Softmax NaNs when using FP16#18057

Merged
ydshieh merged 4 commits into huggingface:main from gsarti:fix-xglm-fp16-nans
Sep 29, 2022
Conversation

Contributor

@gsarti gsarti commented Jul 7, 2022

What does this PR do?

Fixes #18049 following the exact same procedure used in #17437. Besides the added test, I also evaluated the fix on my personal use case and found the behavior of the fixed model to be consistent when performing single or batched generation.

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?

Who can review?

@patil-suraj @ydshieh @patrickvonplaten


HuggingFaceDocBuilderDev commented Jul 7, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@ydshieh ydshieh left a comment


Thank you, @gsarti, LGTM!

Comment thread tests/models/xglm/test_modeling_xglm.py Outdated
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
Collaborator


We (HF team) have to remember to add this back once Bart takes the same fix.

Contributor

@younesbelkada younesbelkada left a comment


LGTM thanks a lot for the fix 🚀 !

f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
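For context, the clamp matters because in fp16 the additive attention mask can push the summed weights past the representable range to -inf, and a row of all -inf makes softmax return NaN. A minimal standalone sketch of the failure mode and the fix (an illustration, not the model code itself):

```python
import torch

fp16_min = torch.finfo(torch.float16).min  # -65504.0

# Masked positions carry a large negative bias; in fp16 the addition
# overflows to -inf, and softmax over a row of all -inf yields NaN.
scores = torch.full((1, 4), fp16_min, dtype=torch.float16)
mask = torch.full((1, 4), fp16_min, dtype=torch.float16)
summed = scores + mask                        # -inf in every position
bad = torch.softmax(summed.float(), dim=-1)   # NaN row

# The fix clamps the summed weights back to the dtype minimum first,
# so softmax sees finite values and returns a valid distribution.
clamped = torch.max(summed, torch.tensor(fp16_min, dtype=torch.float16))
good = torch.softmax(clamped.float(), dim=-1)

print(torch.isnan(bad).all().item())      # True
print(torch.isfinite(good).all().item())  # True
```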
Contributor


@stas00 is this operation costly? I'm wondering how expensive such a max operation is.

Contributor

@stas00 stas00 Jul 12, 2022


What are the contenders? At least max, clamp, and where, but probably others as well.

In [13]: a = torch.tensor([5,-1e20])

In [14]: b = torch.tensor(torch.finfo(torch.float16).min)

In [16]: torch.clamp(a, min=b)
Out[16]: tensor([ 5.0000e+00, -6.5504e+04])

In [21]: torch.where(a > b, a, b)
Out[21]: tensor([ 5.0000e+00, -6.5504e+04])

In [22]: torch.max(a, b)
Out[22]: tensor([ 5.0000e+00, -6.5504e+04])

Benchmark:

$ cat clamp-where-max.py
import torch.utils.benchmark as benchmark
import torch

a = torch.empty(512)
b = torch.tensor(torch.finfo(torch.float16).min)

t0 = benchmark.Timer(
    stmt='torch.clamp(a, b)',
    setup='',
    globals=dict(a=a, b=b),
)

t1 = benchmark.Timer(
    stmt='torch.max(a, b)',
    setup='',
    globals=dict(a=a, b=b),
)

t2 = benchmark.Timer(
    stmt='torch.where(a > b, a, b)',
    setup='',
    globals=dict(a=a, b=b),
)
print(t0.timeit(1000))
print(t1.timeit(1000))
print(t2.timeit(1000))

$ python clamp-where-max.py
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2c77739040>
torch.clamp(a, b)
  1.60 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2c77739d60>
torch.max(a, b)
  1.60 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2c77739040>
torch.where(a > b, a, b)
  4.36 us
  1 measurement, 1000 runs , 1 thread

so max is tied with clamp, and where is slow.

But make sure to benchmark with the actual dimensions, though it shouldn't make much of a difference, I think.

(edited: I got `a` wrong initially)

Contributor Author

gsarti commented Aug 3, 2022

@patil-suraj I think only your check is missing!

Contributor

patrickvonplaten commented Aug 23, 2022

Sorry for being so late here, @gsarti! Merged master into it to ping CircleCI here.

@patrickvonplaten
Contributor

Hey @gsarti - it seems like a test is failing now:

tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_xglm_model_past

with

 UnboundLocalError: local variable 'dtype_attn_weights' referenced before assignment

Contributor Author

gsarti commented Aug 24, 2022

Hey @gsarti - it seems like a test is failing now:

tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_xglm_model_past

with

 UnboundLocalError: local variable 'dtype_attn_weights' referenced before assignment

I noticed this when running the code. My understanding is that setting dtype_attn_weights to torch.float32 by default beforehand would fix the issue and maintain the expected behavior; could you double-check?

@github-actions github-actions Bot closed this Sep 25, 2022
Collaborator


I think we don't need this line.

Collaborator

@ydshieh ydshieh Sep 27, 2022


Instead, we can change this part to

if attn_weights.dtype == torch.float16:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
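As a quick sanity check of this pattern outside the model (a hypothetical standalone snippet, not the actual modeling code), upcasting the softmax to float32 for fp16 inputs keeps the result finite while preserving the fp16 output dtype:

```python
import torch
from torch import nn

# Row where all positions sit at the fp16 minimum, as after additive masking.
attn_weights = torch.full((1, 3), torch.finfo(torch.float16).min, dtype=torch.float16)

if attn_weights.dtype == torch.float16:
    # Compute the softmax in float32 for numerical stability, then cast back.
    probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
else:
    probs = nn.functional.softmax(attn_weights, dim=-1)

print(probs.dtype)                         # torch.float16
print(torch.isfinite(probs).all().item())  # True
```

The `dtype=` argument makes `softmax` perform the computation in float32 before the final cast, so no fp16 intermediate can overflow.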

Contributor


I think that we should fix the line on OPT too:

dtype_attn_weights = attn_weights.dtype

In case attention_mask is set to None, the forward pass will fail as described in #18057 (comment).
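The failure mode described here is plain Python scoping, sketched with a hypothetical helper (not the actual OPT code): if dtype_attn_weights is only assigned inside the attention_mask branch, any call without a mask raises UnboundLocalError at the final cast.

```python
import torch

def buggy_attention_softmax(attn_weights, attention_mask=None):
    # dtype_attn_weights is only bound inside this branch, so calling
    # with attention_mask=None leaves it undefined at the final cast.
    if attention_mask is not None:
        dtype_attn_weights = attn_weights.dtype
        attn_weights = (attn_weights + attention_mask).float()
    probs = torch.softmax(attn_weights.float(), dim=-1)
    return probs.to(dtype_attn_weights)  # UnboundLocalError when mask is None

try:
    buggy_attention_softmax(torch.randn(2, 3, dtype=torch.float16))
except UnboundLocalError:
    print("UnboundLocalError, as in the failing test")
```

Assigning `dtype_attn_weights = attn_weights.dtype` before the branch (as suggested above in the thread for the XGLM fix) removes the error while keeping the masked path unchanged.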

Contributor


But I think that the issue has never been reported since attention_mask is never None:

attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)

Collaborator


Good catch! Surprisingly, we don't have test failure for OPT due to this.

Collaborator


You answered my question before I asked it 😆

Contributor


ahahaha yes :D

@ydshieh ydshieh reopened this Sep 27, 2022
Collaborator

ydshieh commented Sep 27, 2022

Hi @gsarti, sorry for being late on this PR. I re-opened it and gave some suggestions for a fix to the failing test. Would you like to update this PR after rebasing your working branch on an updated main branch?

@huggingface huggingface deleted a comment from github-actions Bot Sep 27, 2022
Collaborator

ydshieh commented Sep 28, 2022

Hi @gsarti, I made the necessary change to pass the tests and pushed to your branch directly. The remaining failing test is unrelated to this PR, but I will wait until tomorrow to check again, then I will merge.

cc @patrickvonplaten and @younesbelkada

@younesbelkada
Contributor

Thanks a lot for the fix @ydshieh !!
I think for consistency we should apply the same changes to OPT too; I will take care of that first thing in the morning tomorrow 💪

Comment thread src/transformers/models/xglm/modeling_xglm.py Outdated
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
@ydshieh ydshieh merged commit 9d732fd into huggingface:main Sep 29, 2022


Development

Successfully merging this pull request may close these issues.

NaN in XGLM Softmax with FP16

6 participants