
🚨 Fix gradient checkpointing for several models and improve test robustness #41818

Merged
vasqu merged 10 commits into huggingface:main from githubnemo:issue/gpt-bigcode-gradient-checkpointing
Nov 11, 2025

Conversation

@githubnemo
Contributor

@githubnemo githubnemo commented Oct 23, 2025

Support for gradient checkpointing was lost in the major refactoring in PR #38635, and this PR attempts to re-add it.

I extended the tests to

  • test use_reentrant=True and False
  • make sure model.train is called so that gradient checkpointing is active; this is a limitation of the tests currently used by GPTBigCode
  • make sure that one (the first) gradient checkpointing layer is actually called
  • make sure that the same non-zero grads are present for normal and checkpointing runs (see the sketch below) - this is something we tripped over before in PEFT due to the possibly incompletely stored runtime environment in the checkpointed forward step, see also peft#2826
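Roughly, that last check can be sketched like this (a simplified illustration, not the test's exact code; it assumes the model output exposes `.logits`, that dropout is disabled so both runs are deterministic, and uses the public `gradient_checkpointing_enable`/`_disable` API):

```python
import torch

def collect_grads(model, inputs):
    # One forward/backward pass; keep every parameter that received a gradient.
    model.zero_grad()
    model(**inputs).logits.sum().backward()
    return {n: p.grad.clone() for n, p in model.named_parameters() if p.grad is not None}

def check_gradient_checkpointing(model, inputs):
    model.train()  # gradient checkpointing is a no-op in eval mode
    reference = collect_grads(model, inputs)
    for use_reentrant in (True, False):
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": use_reentrant}
        )
        checkpointed = collect_grads(model, inputs)
        model.gradient_checkpointing_disable()
        # Every gradient seen in the reference run must reappear, unchanged,
        # in the checkpointed run; iterating only over reference params keeps
        # the check robust to models that legitimately skip parameters.
        for name, grad in reference.items():
            assert name in checkpointed, f"missing grad for {name}"
            torch.testing.assert_close(grad, checkpointed[name])
```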

Note that the invocation of GPTBigCodeBlock.forward has changed:

  • layer_past is now passed as a keyword argument so that GradientCheckpointingLayer.__call__ can see and filter this parameter (use_reentrant=False fails otherwise)
  • {encoder_}hidden_states are still passed as positional arguments so that torch.utils.checkpoint.checkpoint receives them as positional args and computes gradients for them (kwargs would be filtered by GradientCheckpointingLayer).

🚨 Note that this is breaking compatibility by changing the forward signature in GPTBigCodeBlock.forward!
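To illustrate why the positional/keyword split matters, here is a minimal toy version of the mechanism (a made-up class, not the actual GradientCheckpointingLayer implementation): torch.utils.checkpoint only builds gradients for tensors it receives positionally, while cache-style kwargs have to be visible by name so they can be dropped before recomputation.

```python
import torch
import torch.nn as nn

class CheckpointedBlock(nn.Module):
    """Toy stand-in for a GradientCheckpointingLayer-style block."""

    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)
        self.gradient_checkpointing = False

    def forward(self, hidden_states, layer_past=None, use_cache=False):
        # layer_past must arrive as a *keyword* so __call__ below can see it
        # by name and drop it before checkpointing.
        return hidden_states + self.mlp(hidden_states)

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Cache-related kwargs are filtered out: recomputing the forward
            # during the backward pass must not touch the cache again.
            kwargs.pop("layer_past", None)
            kwargs.pop("use_cache", None)
            # hidden_states stays positional so checkpoint() registers it as
            # an input tensor and computes gradients for it.
            return torch.utils.checkpoint.checkpoint(
                super().__call__, *args, use_reentrant=False, **kwargs
            )
        return super().__call__(*args, **kwargs)

block = CheckpointedBlock(8)
block.gradient_checkpointing = True
block.train()
x = torch.randn(2, 4, 8, requires_grad=True)
block(x, layer_past=None, use_cache=True).sum().backward()
assert x.grad is not None  # gradients flowed through the checkpointed step
```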

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@vasqu vasqu left a comment


The tests are neat, I think we should move them to common tests tho. Not exactly sure why it was specially treated here.

And ig there will be a need for another round to check similar models that may have been accidentally overridden with the ckpting layer 😓 not necessarily this PR tho

Comment thread src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Comment thread tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
Comment thread tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py Outdated
Comment thread tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py Outdated
@githubnemo githubnemo changed the title from Implement gradient checkpointing in GPTBigCode to 🚨 Implement gradient checkpointing in GPTBigCode on Oct 27, 2025
@vasqu
Contributor

vasqu commented Oct 27, 2025

cc @ArthurZucker since this might become a bit more breaking than initially thought, and it likely affects more models

@githubnemo
Contributor Author

I've updated the general tests. From the commit message:

- Compare that the non-zero gradients in a reference run are present in the checkpointing run
- Make sure that the forward of at least one gradient checkpointing layer is actually called
  more than once (as expected during gradient checkpointing backward)

Currently there are some problems with Bert-derived MultipleChoice models: when dropout is
enabled, there are scenarios during gradient checkpointing where `classifier.bias.grad` is None.
I don't yet have a good explanation for this; disabling dropout resolves it. I would have
understood it if it were dropout on the classification layer, but enabling attention dropout
also leads to this behavior.

MoE models have selective sparsity depending on the selected experts; for this reason we
only compare gradients on parameters collected in the reference backward run.
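The "called more than once" check could be sketched like this (hypothetical helper, not the test's exact code; `layer` is any checkpointed sub-module and the model output is assumed to expose `.logits`):

```python
def assert_layer_recomputed(model, inputs, layer):
    # Count invocations of one checkpointed layer's forward: under gradient
    # checkpointing it must run at least twice, once in the forward pass and
    # once more when the backward pass recomputes the activations.
    calls = 0
    original_forward = layer.forward

    def counting_forward(*args, **kwargs):
        nonlocal calls
        calls += 1
        return original_forward(*args, **kwargs)

    layer.forward = counting_forward
    try:
        model.train()
        model.gradient_checkpointing_enable()
        model(**inputs).logits.sum().backward()
    finally:
        layer.forward = original_forward
        model.gradient_checkpointing_disable()
    assert calls > 1, "forward ran only once; checkpointing did not recompute"
```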

Currently these models are expected to fail since they don't implement GradientCheckpointingLayer (see the adoption sketch at the end of this comment):

  • swiftformer
  • xlstm
  • zamba
  • zamba2

most likely these as well:

  • janus (no training testing?)
  • clvp

As I explained in the commit message, there's a strange bug with Bert-derived models when testing the BertForMultipleChoice case. When attention(!) dropout is active, the classification bias sometimes receives a None gradient. I don't have a good explanation for this right now, but it seems fishy. Happy about any input.

I didn't revert the GPTBigCode test changes yet since I first wanted to get an opinion if we want to proceed with these more general tests or not.
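For reference, adopting the layer for the models listed above boils down to subclassing it instead of nn.Module (a hedged sketch with a made-up block; the import path matches recent transformers, and exactly which kwargs get filtered during recomputation is version-dependent):

```python
import torch.nn as nn
from transformers.modeling_layers import GradientCheckpointingLayer

class ExampleBlock(GradientCheckpointingLayer):
    # Subclassing GradientCheckpointingLayer instead of nn.Module is what lets
    # a layer participate in gradient_checkpointing_enable(): its __call__
    # routes the forward through torch.utils.checkpoint when active.
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)

    # Tensors needing gradients (hidden_states) stay positional; cache-like
    # state is passed by keyword so the base class can filter it out before
    # recomputation -- the same constraint fixed here for GPTBigCodeBlock.
    def forward(self, hidden_states, past_key_values=None, use_cache=False):
        return hidden_states + self.mlp(self.norm(hidden_states))
```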

Contributor

@vasqu vasqu left a comment


I think this is fine; we will likely need to break other models' signatures as well, i.e. not only GPT BigCode. This PR will get bigger than initially thought, but let's fix these models.

We can allow this for v5 but let's also mention this PR in the v5 thread (#40822) when we merge this.

Comment thread tests/test_modeling_common.py
Comment thread tests/test_modeling_common.py Outdated
Comment thread src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Comment thread tests/test_modeling_common.py
nemo added 2 commits October 30, 2025 13:10
also drop janus from the ignore list - only the VQVAE case is without
gradient checkpointing and it is doubtful that it is useful in that
case. Training with gradient checkpointing is not tested anyway.
Contributor

@vasqu vasqu left a comment


Just noticed this small thing in xlstm

Re: Clvp, let's isolate it for now. We can come back later unless you have a good idea how to refactor/handle this properly

Comment thread src/transformers/models/xlstm/modeling_xlstm.py
The implementation of GradientCheckpointingLayers is not trivial and may break behavior
that was previously expected. Therefore we keep it as-is for now.
Contributor

@vasqu vasqu left a comment


Let's also change the title here a bit since the scope changed

cc @ArthurZucker @Cyrilvallez if you can take a last look --> fixes the last few gradient ckpting models and makes the test more robust about actually using checkpointing properly

Comment thread src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Comment thread tests/test_modeling_common.py Outdated
Comment on lines +837 to +839
"clvp",
"clvp_encoder",
"clvp_decoder",
Contributor


Would it be possible to override this test in test_modeling_clvp instead? No biggie if not

Contributor Author

@githubnemo githubnemo Nov 5, 2025


I've removed some exceptions (janus, clvp in training test) but for clvp/clvp_decoder I think it is better to have the one single exception visible instead of duplicating the test code without the single assertion. It also didn't make sense to me to refactor the assertion into an abstract base method for the tests since it is only CLVP and nothing else.

Contributor


Gotcha, thanks for iterating here and bearing with me.

@githubnemo githubnemo requested a review from vasqu November 5, 2025 11:36
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: clvp, gpt_bigcode, janus, swiftformer, xlstm, zamba, zamba2

Contributor

@vasqu vasqu left a comment


LGTM, cc @Cyrilvallez for core maintainer

@vasqu
Contributor

vasqu commented Nov 10, 2025

Can we also change the title of the PR tho? The PR also makes the test more robust and properly checks for gradient ckpting capabilities.

Member

@Cyrilvallez Cyrilvallez left a comment


Not much to complain about, thanks a lot for this!

@vasqu vasqu changed the title from 🚨 Implement gradient checkpointing in GPTBigCode to 🚨 Fix gradient checkpointing for several models and improve test robustness on Nov 11, 2025
@vasqu vasqu merged commit fa22b56 into huggingface:main Nov 11, 2025
23 checks passed
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
…stness (huggingface#41818)

* Implement gradient checkpointing in GPTBigCode

* Improve gradient checkpointing tests

* Remove duplicated gradient checkpointing code

* Address review comments

* Make test output consistent

* GradientCheckpointingLayer for xlstm, zamba, zamba2

* GradientCheckpointingLayer for swiftformer

* Make an exception for CLVP

* Remove unneeded exceptions

---------

Co-authored-by: nemo <git@ningu.net>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>