Skip to content

replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm#9386

Merged
patrickvonplaten merged 1 commit intohuggingface:masterfrom
stas00:native-layer-norm
Jan 4, 2021
Merged

replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm#9386
patrickvonplaten merged 1 commit intohuggingface:masterfrom
stas00:native-layer-norm

Conversation

@stas00
Copy link
Copy Markdown
Contributor

@stas00 stas00 commented Jan 2, 2021

This PR proposes to drop apex.normalization.FusedLayerNorm in favor of faster torch.nn.LayerNorm.

  1. For performance and background details please see the discussions in replacing apex.normalization.FusedLayerNorm with torch.nn.LayerNorm #9377
  2. It's also needed for [model parallelism] Bart goes parallel #9384 since apex.normalization.FusedLayerNorm corrupts data under model parallel FusedLayerNorm corrupts data when switching devices NVIDIA/apex#1022

Fixes: #9377

@LysandreJik, @sgugger, @patrickvonplaten

Copy link
Copy Markdown
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me!

Copy link
Copy Markdown
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM too!

@LysandreJik LysandreJik mentioned this pull request Jan 4, 2021
15 tasks
Copy link
Copy Markdown
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Way easier this way, LGTM!

@patrickvonplaten
Copy link
Copy Markdown
Contributor

Merging since it's blocking #9343 .

@patrickvonplaten patrickvonplaten merged commit 47ca0ea into huggingface:master Jan 4, 2021
@stas00 stas00 deleted the native-layer-norm branch January 4, 2021 19:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

replacing apex.normalization.FusedLayerNorm with torch.nn.LayerNorm

4 participants