
accelerate support for RoBERTa family #19906

Merged

younesbelkada merged 2 commits into huggingface:main from younesbelkada:add_accelerate_roberta on Oct 26, 2022

Conversation

Contributor

@younesbelkada commented Oct 26, 2022

What does this PR do?

This PR adds accelerate support for:

  • RoBERTa
  • data2vec_text
  • Lilt
  • Luke
  • XLM-RoBERTa
  • CamemBERT
  • LongFormer
    This way, any of the models above can be loaded in 8bit using load_in_8bit=True.
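A minimal usage sketch of what this enables (hedged: `roberta-base` is just an illustrative checkpoint; actually running this requires `accelerate`, `bitsandbytes`, and a CUDA GPU):

```python
# Sketch: load a RoBERTa-family checkpoint in 8-bit via accelerate.
# Assumes accelerate + bitsandbytes are installed and a CUDA GPU is present.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "roberta-base",      # illustrative checkpoint
    device_map="auto",   # let accelerate place the weights
    load_in_8bit=True,   # quantize linear layers to int8
)
```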

Since these models copy the same xxxLMHead from RoBERTa, I had to change the copied modules too. Happy to break this PR down into several smaller PRs if preferred.

This PR also fixes a small bug in the accelerate tests, where the variable input_dict was overridden by xxForMultipleChoice models.

Can also confirm all slow tests pass (single + multiple GPUs)

cc @sgugger @ydshieh

Added `accelerate` support for
- `RoBERTa`
- `data2vec_text`
- `Lilt`
- `Luke`
- `XLM-RoBERTa`

fixes
- small bug in `test_modeling_common`

HuggingFaceDocBuilderDev commented Oct 26, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger left a comment


Thanks, this looks like a way better fix!

# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device == torch.device("meta"):
Collaborator


The test will probably break if PyTorch is < 1.9, so we need a safer way to check whether the device is meta (this can go in a util if the check ends up being long).

Contributor Author


I propose a fix here: 05da693
I am not sure whether device.type can be retrieved on PT < 1.9, but it is something I have seen in accelerate, I think.

Contributor Author


Here is a quick try on torch == 1.7.1:

>>> torch.__version__
'1.7.1'
>>> vec = torch.randn(1, 1)
>>> vec.device
device(type='cpu')
>>> vec.device.type
'cpu'
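Based on the session above, a version-safe variant of the check compares the device type string instead of constructing a meta device. A minimal sketch (the helper name is hypothetical, not the code in the PR):

```python
import torch

def is_on_meta_device(tensor: torch.Tensor) -> bool:
    # Comparing the string `device.type` avoids constructing
    # torch.device("meta"), which breaks on PyTorch < 1.9.
    return tensor.device.type == "meta"

print(is_on_meta_device(torch.randn(1, 1)))  # a fresh tensor lives on CPU, so this is False
```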

Collaborator


Nice!

config_class = RobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = []
Collaborator


We don't even need the base block?

Contributor Author


Yes, for some models (roberta, lilt) passing an empty list was sufficient. The accelerate tests still run, since the condition only skips when the list is None:

if model_class._no_split_modules is None:
    continue
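To illustrate the distinction this condition relies on (a standalone sketch, not the actual test harness; the module name below is hypothetical):

```python
# None disables the accelerate tests entirely; an empty list keeps them
# running while declaring no modules as unsplittable.
for no_split_modules in (None, [], ["RobertaEmbeddings"]):
    if no_split_modules is None:
        print(no_split_modules, "-> accelerate tests skipped")
    else:
        print(no_split_modules, "-> accelerate tests run")
```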
