Apply GradientCheckpointingLayer to the whole repo #38913
Cyrilvallez merged 148 commits into huggingface:main from
Conversation
```python
do_warn = False
layer_name = self.__class__.__name__
message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"

if "use_cache" in kwargs and kwargs["use_cache"]:
    kwargs["use_cache"] = False
    message += " `use_cache=False`,"
    do_warn = True

# different names for the same thing in different layers
if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
    kwargs["past_key_value"] = None
    message += " `past_key_value=None`,"
    do_warn = True

if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
    kwargs["past_key_values"] = None
    message += " `past_key_values=None`,"
    do_warn = True

if "layer_past" in kwargs and kwargs["layer_past"] is not None:
    kwargs["layer_past"] = None
    message += " `layer_past=None`,"
    do_warn = True

# warn if anything was changed
if do_warn:
    message = message.rstrip(",") + "."
    logger.warning(message)
```
update for GradientCheckpointingLayer
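For context, here is a minimal sketch of how a layer of this kind can reroute its call through `torch.utils.checkpoint`; it is an illustration only, not the actual transformers implementation, and the kwargs handling mirrors the diff above (without the warning message).

```python
import torch
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayer(torch.nn.Module):
    """Illustrative sketch only: reroute the layer call through activation
    checkpointing when enabled, recomputing activations during backward."""

    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Caching is incompatible with checkpointing, so cache-related kwargs
            # are disabled here (this is what the diff above does, plus a warning).
            if kwargs.get("use_cache"):
                kwargs["use_cache"] = False
            for name in ("past_key_value", "past_key_values", "layer_past"):
                if kwargs.get(name) is not None:
                    kwargs[name] = None
            # Non-reentrant checkpointing accepts keyword arguments directly.
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)
```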
Cyrilvallez left a comment
Big big PR, and super welcome! 🚀🤗 Can we add a common test for gradient checkpointing though? I see we don't have one yet (only in trainer) - just instantiating a small model, running a single forward with gradient checkpointing, and ensuring that it runs correctly would be super nice
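A minimal sketch of the kind of common test being asked for here (the model ID and assertion are illustrative choices, not the actual test that is pointed out below):

```python
import torch
from transformers import AutoModelForCausalLM


def test_training_gradient_checkpointing():
    # Tiny random Hub model, used only to keep the test fast (illustrative choice).
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model.gradient_checkpointing_enable()
    model.train()

    input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
    outputs = model(input_ids=input_ids, labels=input_ids)
    outputs.loss.backward()

    # Gradients should still flow even though activations are recomputed in backward.
    assert any(p.grad is not None for p in model.parameters() if p.requires_grad)
```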
Cyrilvallez left a comment
Alright, did not see the test_training_gradient_checkpointing... before, my bad! All good then! Let's merge! 🤗
ArthurZucker left a comment
Super nice! Thanks
- delete_gate_mask is a tensor with `requires_grad=True`, so it must be passed as a positional arg to work with gradient checkpointing, according to this PR - huggingface/transformers#38913
- Without this change, running with `batch_size_multiplier=8` or `gradient_checkpointing=True` would cause the following error:

```
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
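To make the constraint from the quoted commit concrete, here is a small self-contained reproduction of the failure mode (the names `layer` and `delete_gate_mask` are illustrative): with reentrant checkpointing, a grad-requiring non-leaf tensor that is bound as a keyword argument, instead of being passed positionally, gets backpropagated through during every layer's recomputation, which is what surfaces as this error.

```python
import torch
from functools import partial
from torch.utils.checkpoint import checkpoint


def layer(hidden_states, delete_gate_mask):
    return hidden_states * delete_gate_mask


def run(pass_mask_positionally: bool):
    gate_logits = torch.randn(2, 4, requires_grad=True)
    mask = torch.sigmoid(gate_logits)      # non-leaf tensor shared by both layers
    h = torch.randn(2, 4, requires_grad=True)
    for _ in range(2):                     # two checkpointed "layers"
        if pass_mask_positionally:
            h = checkpoint(layer, h, mask, use_reentrant=True)
        else:
            # mask is hidden from the checkpoint inside a partial/kwarg
            h = checkpoint(partial(layer, delete_gate_mask=mask), h, use_reentrant=True)
    h.sum().backward()


run(pass_mask_positionally=True)   # works: the mask is detached and recomputed per layer
run(pass_mask_positionally=False)  # RuntimeError: Trying to backward through the graph a second time
```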
What does this PR do?
Apply `GradientCheckpointingLayer` to the remaining models in the repository.

Most of the PR is pretty much the same set of changes for all models (see the sketch below):

1. Import `GradientCheckpointingLayer`
2. Derive the `*Layer` modules from `GradientCheckpointingLayer`
3. Remove the `if/else` path for gradient checkpointing, keeping the `else` path only
   3a) Some changes were required to make sure all tensors with gradients were passed as positional arguments.
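A sketch of the shape of the per-model change, using Llama-like names purely for illustration (the import path and signatures are assumptions, not a verbatim diff from the PR):

```python
# Before: every model carried its own gradient-checkpointing branch in forward().
#
#     if self.gradient_checkpointing and self.training:
#         layer_outputs = self._gradient_checkpointing_func(
#             decoder_layer.__call__, hidden_states, attention_mask, position_ids
#         )
#     else:
#         layer_outputs = decoder_layer(hidden_states, attention_mask, position_ids)

# After: the layer class inherits from GradientCheckpointingLayer (step 2) and the
# if/else disappears from the model's forward (step 3); only the plain call remains.
from transformers.modeling_layers import GradientCheckpointingLayer  # step 1


class MyDecoderLayer(GradientCheckpointingLayer):
    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        ...  # unchanged attention / MLP logic

# in MyModel.forward:
#     layer_outputs = decoder_layer(hidden_states, attention_mask, position_ids)
```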
Additionally, `GradientCheckpointingLayer` was modified slightly. I added handling for `use_cache` and `past_key_values` within the layer to disable them in case gradient checkpointing is enabled. We still have to keep some redundant code, though:
Case 1. The `use_cache` override in the outer module is kept because, later, most of the models rely on the `use_cache` parameter; in case it is handled only by `GradientCheckpointingLayer` and not modified in the outer module, it leads to an IndexError. A sketch of the pattern is shown below.
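The sketch below shows what this redundant code looks like in a typical model `forward` (a fragment with illustrative wording, not an exact excerpt from any one model):

```python
# Kept in the outer module even though GradientCheckpointingLayer now also disables caching:
if self.gradient_checkpointing and self.training and use_cache:
    logger.warning_once(
        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
    )
    use_cache = False

# ... per layer, later in the same forward ...
layer_outputs = decoder_layer(hidden_states, attention_mask)
if use_cache:
    # If use_cache were flipped only inside GradientCheckpointingLayer, the layer would
    # not return a cache entry here and this indexing would raise an IndexError.
    next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
```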
Case 2. In some cases the layer's parameter order doesn't allow handling `past_key_values` as a kwarg, e.g. GPT2. We have to pass all parameters up to `encoder_hidden_states` as positional args (tensors that require grads have to be passed that way), so `past_key_values` is also passed as a positional argument and resolved manually (sketched below). Alternatively, we could refactor the layers' parameter order, but that would be a breaking change; not that many models are affected, and they are mostly the old ones.
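A sketch of the Case 2 call pattern (the parameter order is illustrative of a GPT2-style block, not its exact signature):

```python
# encoder_hidden_states can require grad, so it must be passed positionally for
# checkpointing to track it; every parameter that precedes it in the signature,
# including the cache, therefore also ends up positional and is resolved manually
# inside the layer instead of through the generic kwargs handling.
outputs = block(
    hidden_states,
    past_key_values,        # positional only because it comes before encoder_hidden_states
    attention_mask,
    head_mask,
    encoder_hidden_states,  # tensor that may require grad -> must stay positional
    encoder_attention_mask,
    use_cache=use_cache,
    output_attentions=output_attentions,
)
```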
Not supported models
Also, there are a couple of exceptions where `GradientCheckpointingLayer` does not work. I tried to fix them, but I didn't go too far and just kept those models in their original state.

cc @ArthurZucker @Cyrilvallez