
Apply GradientCheckpointingLayer to the whole repo #38913

Merged
Cyrilvallez merged 148 commits into huggingface:main from qubvel:gradient-checkpointing-layer-propagation
Jun 23, 2025

Conversation

@qubvel
Contributor

@qubvel qubvel commented Jun 19, 2025

What does this PR do?

Apply GradientCheckpointingLayer to the remaining models in the repository.

Most of the PR applies the same set of changes to every model (see the sketch after this list):

  1. Add an import for GradientCheckpointingLayer
  2. Inherit the *Layer modules from GradientCheckpointingLayer
  3. Remove the if/else path for gradient checkpointing, keeping only the else branch
    3a) Some changes were required to make sure all tensors with gradients are passed as positional arguments.
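
For reference, a minimal sketch of what the per-model edit looks like after the change. This is a toy layer, not any real model, and the import path is an assumption based on where GradientCheckpointingLayer currently lives:

```python
import torch
from torch import nn
from transformers.modeling_layers import GradientCheckpointingLayer  # step 1: assumed import path


# Step 2: the decoder layer inherits from GradientCheckpointingLayer instead of nn.Module.
class ToyDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        return (self.proj(hidden_states),)


# Step 3: the caller keeps only the plain call; the old
# `if self.gradient_checkpointing and self.training:` branch is gone, because the layer
# checkpoints itself when the flag is set (normally via model.gradient_checkpointing_enable()).
layer = ToyDecoderLayer()
hidden_states = torch.randn(2, 4, 32, requires_grad=True)
(out,) = layer(hidden_states)  # step 3a: grad-carrying tensors are passed positionally
out.sum().backward()
```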

Additionally, GradientCheckpointingLayer itself was modified slightly: I added handling for use_cache and past_key_values inside the layer so they are disabled when gradient checkpointing is enabled.

We still have to keep some redundant code, though:

Case 1.

```python
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
```

because, later, most models rely on the use_cache flag as follows:

```python
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
```

If use_cache is handled only inside GradientCheckpointingLayer and not also reset in the outer module, the indexing above raises an IndexError.

Case 2.

In some cases the layer's parameter order does not allow past_key_values to be handled as a kwarg, e.g. GPT2:

```python
            outputs = block(
                hidden_states,
                past_key_values if not (self.gradient_checkpointing and self.training) else None,
                cache_position,
                causal_mask,
                head_mask[i],
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )
```

We have to pass all parameters up to encoder_hidden_states as positional args (tensors that require grads have to be passed that way), so past_key_values is also passed as a positional argument and resolved manually.
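
To make the positional-argument constraint concrete, here is a minimal sketch with a toy class that approximates the idea (it is not the actual GradientCheckpointingLayer implementation): keyword arguments are frozen into a functools.partial, so only the positional arguments become inputs of torch.utils.checkpoint, which is why grad-carrying tensors such as encoder_hidden_states have to be passed positionally.

```python
import torch
from functools import partial
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLayer(nn.Module):
    # Toy approximation: kwargs are bound into a partial, so only *args become
    # inputs of the checkpoint function. Tensors hidden in kwargs are not
    # checkpoint inputs, hence the positional-argument requirement in this PR.
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=True)
        return super().__call__(*args, **kwargs)


class ToyBlock(CheckpointedLayer):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, encoder_hidden_states=None, use_cache=False):
        out = self.proj(hidden_states)
        if encoder_hidden_states is not None:
            out = out + encoder_hidden_states
        return out


block = ToyBlock()
block.gradient_checkpointing = True
block.train()

hidden_states = torch.randn(2, 8, requires_grad=True)
encoder_hidden_states = torch.randn(2, 8, requires_grad=True)

# encoder_hidden_states is passed positionally, so it is an explicit checkpoint
# input and reliably receives a gradient during the recomputation.
out = block(hidden_states, encoder_hidden_states, use_cache=False)
out.sum().backward()
print(hidden_states.grad is not None, encoder_hidden_states.grad is not None)  # True True
```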

Alternatively, we could refactor the layers' parameter order, but that would be a breaking change. It affects only a few models, mostly older ones.

Unsupported models

There are also a couple of exceptions where GradientCheckpointingLayer does not work. I tried to fix them, but I didn't go too far and just kept the original code:

  • zamba / zamba2
  • mllama

cc @ArthurZucker @Cyrilvallez

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qubvel qubvel marked this pull request as ready for review June 23, 2025 10:26
Comment on lines +52 to +81
```python
do_warn = False
layer_name = self.__class__.__name__
message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"

if "use_cache" in kwargs and kwargs["use_cache"]:
    kwargs["use_cache"] = False
    message += " `use_cache=False`,"
    do_warn = True

# different names for the same thing in different layers
if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
    kwargs["past_key_value"] = None
    message += " `past_key_value=None`,"
    do_warn = True

if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
    kwargs["past_key_values"] = None
    message += " `past_key_values=None`,"
    do_warn = True

if "layer_past" in kwargs and kwargs["layer_past"] is not None:
    kwargs["layer_past"] = None
    message += " `layer_past=None`,"
    do_warn = True

# warn if anything was changed
if do_warn:
    message = message.rstrip(",") + "."
    logger.warning(message)
```

Contributor Author


update for GradientCheckpointingLayer

@qubvel qubvel requested a review from Cyrilvallez June 23, 2025 10:37
Member

@Cyrilvallez Cyrilvallez left a comment


Big big PR, and super welcome! 🚀🤗 Can we add a common test for gradient checkpointing though? I see we don't have one yet (only in trainer) - just instantiating a small model and running a single forward pass with gradient checkpointing to make sure it runs correctly would be super nice
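
A minimal version of such a common test might look like the sketch below (Llama is picked only as an example and the tiny sizes are arbitrary; a real common test would be parametrized over all model classes):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM


def test_forward_backward_with_gradient_checkpointing():
    # Tiny config so the test stays fast; the sizes are arbitrary.
    config = LlamaConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaForCausalLM(config)
    model.gradient_checkpointing_enable()
    model.train()

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()  # should recompute activations without raising

    assert torch.isfinite(loss)
```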

Member

@Cyrilvallez Cyrilvallez left a comment


Alright, I did not see the test_training_gradient_checkpointing... before, my bad! All good then! Let's merge! 🤗

@Cyrilvallez Cyrilvallez merged commit 84d19be into huggingface:main Jun 23, 2025
18 of 20 checks passed
Collaborator

@ArthurZucker ArthurZucker left a comment


Super nice! Thanks

High5Apps added a commit to High5Apps/mrt5 that referenced this pull request Dec 4, 2025
- delete_gate_mask is a tensor with `requires_grad=True`, so it must be passed as a positional arg to work with gradient checkpointing, according to this PR
  - huggingface/transformers#38913
- Without this change, running with `batch_size_multiplier=8,` or `gradient_checkpointing=True` would cause the following error:
```
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```