accelerate support for RoBERTa family #19906
Conversation
Added `accelerate` support for:
- `RoBERTa`
- `data2vec_text`
- `Lilt`
- `Luke`
- `XLM-RoBERTa`

Fixes:
- small bug in `test_modeling_common`
sgugger left a comment:
Thanks, this looks like a way better fix!
```python
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device == torch.device("meta"):
```
The test will probably break if PyTorch is < 1.9, so we need a safer way to check whether the device is meta (can be in a util if the check ends up being long).
I propose a fix here, 05da693
I am not sure whether `device.type` can be retrieved on PyTorch < 1.9, but I think it is something I have seen used in `accelerate`.
Here is a quick try on torch == 1.7.1!

```python
>>> torch.__version__
'1.7.1'
>>> vec = torch.randn(1, 1)
>>> vec.device
device(type='cpu')
>>> vec.device.type
'cpu'
```
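A minimal sketch of the kind of helper suggested above (the function name is hypothetical, not the one from commit 05da693): comparing the `device.type` string avoids constructing `torch.device("meta")`, which can raise on PyTorch versions that predate the meta device.

```python
import torch

def is_on_meta_device(tensor: torch.Tensor) -> bool:
    # Compare the device type string instead of building
    # torch.device("meta"), which can fail on PyTorch < 1.9
    # where the meta device does not exist yet.
    return tensor.device.type == "meta"
```

With such a helper, the guard above would read `if is_on_meta_device(self.decoder.bias):`.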
```python
config_class = RobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = []
```
We don't even need the base block?
Yes, for some models (`roberta`, `lilt`), passing an empty list was sufficient. I guess the `accelerate` tests are still run, since the condition only checks whether the list is `None`:
```python
if model_class._no_split_modules is None:
    continue
```
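A minimal illustration of that gating, with hypothetical class names: `None` (the default) skips the common `accelerate` tests, while an empty list opts the model in without naming any module to keep unsplit.

```python
class ModelWithoutSupport:
    _no_split_modules = None  # default: common accelerate tests are skipped

class ModelWithSupport:
    _no_split_modules = []  # empty list: tests run, no module is kept unsplit

for model_class in (ModelWithoutSupport, ModelWithSupport):
    if model_class._no_split_modules is None:
        print(f"{model_class.__name__}: accelerate tests skipped")
    else:
        print(f"{model_class.__name__}: accelerate tests run")
```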
What does this PR do?

This PR adds `accelerate` support for:

- `RoBERTa`
- `data2vec_text`
- `Lilt`
- `Luke`
- `XLM-RoBERTa`
- `CamemBERT`
- `LongFormer`

This way, any of the models above can be loaded in 8-bit using `load_in_8bit=True` (see the sketch below). Since these models copy the same `xxxLMHead` from `RoBERTa`, I had to change the copied modules too; happy to break this PR down into several smaller PRs as well.

This PR also fixes a small bug in the `accelerate` tests where the variable `input_dict` is overridden by `xxForMultipleChoice` models.

Can also confirm all slow tests pass (single + multiple GPUs).
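A minimal usage sketch of what this enables, assuming `bitsandbytes` and `accelerate` are installed (the checkpoint and task head here are illustrative, not taken from the PR):

```python
from transformers import AutoModelForMaskedLM

# Load a RoBERTa-family checkpoint with its linear layers quantized to 8-bit;
# load_in_8bit requires a device_map so accelerate can place the modules.
model = AutoModelForMaskedLM.from_pretrained(
    "roberta-base",
    device_map="auto",
    load_in_8bit=True,
)
```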
cc @sgugger @ydshieh