Make gradient-checkpoint enabling tolerant of models without get_input_embeddings#42558

Merged
molbap merged 31 commits into main from fix_enable_grads_again
Dec 17, 2025

Conversation

@molbap
Contributor

@molbap molbap commented Dec 2, 2025

What does this PR do?

As the title indicates, #42542 and likely a few other models were broken by the merged #41993. This PR adds an embedding getter and attempts to test the feature with more coverage.

Basically what it does

  • Stop hard-failing gradient_checkpointing_enable when a model lacks get_input_embeddings. We now just call enable_input_require_grads, let it attach hooks where it can, and issue a single warning if no embedding module is found.
  • Simplify enable_input_require_grads (and the InternVL/MLCD overrides, plus a couple of other model adjustments) by making them responsible for the warning.
  • Add a big test to make sure all of that works (please take a look).

Hopefully this helps GC with PEFT adapters for many VLMs (and plain text models too).
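Rough sketch of the new behaviour (simplified and illustrative: the standalone function name, exact control flow, and warning text below are not the actual diff):

```python
import logging
from typing import Optional

import torch.nn as nn

logger = logging.getLogger(__name__)


def make_inputs_require_grads(module, inputs, output):
    # Hook attached to the embedding layer so that its output requires grad,
    # which reentrant gradient checkpointing needs.
    output.requires_grad_(True)


def enable_input_require_grads_tolerant(model: nn.Module) -> None:
    """Attach the require-grads hook wherever an embedding module is found;
    warn once instead of hard-failing when none is exposed."""
    found_embeddings = False
    for module in model.modules():
        get_embeds = getattr(module, "get_input_embeddings", None)
        if get_embeds is None:
            continue
        try:
            input_embeddings: Optional[nn.Module] = get_embeds()
        except NotImplementedError:
            continue
        if input_embeddings is not None:
            input_embeddings.register_forward_hook(make_inputs_require_grads)
            found_embeddings = True
    if not found_embeddings:
        logger.warning(
            f"{model.__class__.__name__} does not expose input embeddings; "
            "gradients cannot flow back to the token embeddings when using "
            "adapters or gradient checkpointing."
        )
```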

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@molbap
Contributor Author

molbap commented Dec 2, 2025

Added a test as well, but I can't find a clean way to handle the models for which a getter method isn't relevant without causing a lot of side effects. WDYT @zucchini-nlp? Kind of stumped (try/excepting at a higher level would always work, but it hides a lot).

Member

@zucchini-nlp zucchini-nlp left a comment


the models for which a getter method isn't relevant without causing a lot of side effects

Would this also mean that we can't correctly support PEFT and GC with these models, or do they have a custom way to set grads on the inputs? We could raise an error with a better message saying that the model doesn't support these features unless it has a way to get its input embeddings, wdyt?

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +985 to +987
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)
Member


nit: self.base_model property has the same functionality

Contributor Author


true!

Comment thread src/transformers/modeling_utils.py Outdated
_input_embed_layer = "embed_tokens" # default layer that holds input embeddings.

def get_input_embeddings(self) -> nn.Module:
def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:
Member


oh interesting, I was assuming the base get_input_embeddings already returned None

Contributor Author


well, I ended up in so many little edge cases lol

@molbap
Contributor Author

molbap commented Dec 3, 2025

Yes, it's a good idea to raise/inform downstream users. I reverted a couple of things and will update the test so it actually checks that enabling GC works (probably by adding another test).

Contributor Author


This is mostly to fix a broken env situation that can arise around timm_wrapper (or timm_backbone?), so it protects a few imports.

@molbap
Contributor Author

molbap commented Dec 4, 2025

I reverted a few models to inner positional embedding calls as mentioned in #38913.

Modified a few other models, as the test I added (test_enable_input_require_grads_with_gradient_checkpointing) was a bit naive and I was just continue-ing; now it's a proper skip if the loss is undefined.

Hopefully that helps VLMs + GC and does not break adapters.
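For reference, the flow this is meant to keep working looks roughly like this (the model name and kwargs are only illustrative):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# No longer hard-fails on models whose class does not expose get_input_embeddings:
# hooks are attached wherever an embedding module is found, otherwise a single
# warning is emitted.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})

# Still available as an explicit call, e.g. for PEFT adapter training.
model.enable_input_require_grads()
```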

Comment on lines +1987 to +1990
try:
input_embeddings = module.get_input_embeddings()
except NotImplementedError:
continue
Contributor Author


no simple way around this unfortunately

Member


oke, I think with the warning below, it is more explicit

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +2007 to +2011
if not found_embeddings:
logger.warning_once(
f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
"embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
"support those features."
Contributor Author


at least we can warn users!

Member


nice!

@molbap molbap changed the title from "Add embedding getter + test" to "Make gradient-checkpoint enabling tolerant of models without get_input_embeddings" on Dec 4, 2025
Member

@zucchini-nlp zucchini-nlp left a comment


It is a pity that there are so many edge cases to handle. Raising a warning seems like a good solution to not silently skip exceptions

I think there are some unrelated changes with rope, let's revert those before merging

Comment thread src/transformers/models/internvl/modeling_internvl.py
Comment thread src/transformers/models/altclip/modeling_altclip.py
Comment on lines +337 to +340
rotary_embeddings = position_embeddings
if rotary_position_tensor is not None:
rotary_embeddings = (rotary_position_tensor.cos(), rotary_position_tensor.sin())

Member


the position_embeddings are already supposed to be present, so we don't need the embeddings here, right?

Contributor Author


That one is a little bit more annoying. I'm trying to revert it, but the thing is this model should work except that it's a tuple... and gradient checkpointing does not disable grads on tuples, only on tensors passed as positional args.

So this was a hacky trick (for the use_reentrant case).

Member


Weird, we use tuple cos/sin in most LLMs. I'd prefer to skip this model's test instead of fixing it by duplicating args; I think this model is not used as commonly.

Contributor Author


I did skip it, yes, because it was being annoying 😁. And yes, we use them, but they don't usually carry grads; here (in that particular model) they do.


Comment on lines +2952 to +2955
needs_embedding_grads = self.main_input_name == "input_ids"
# we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
if enable_input_grads:
Member


hmm, for my understanding, why do we always need to enable grads when doing GC training with text models?

Contributor Author


We don't always, but we do with reentrant checkpointing. IIUC it's not to actually use these gradients; it's that torch.utils.checkpoint needs at least one input and one output that actually require gradients, else the checkpointed part will not have a gradient.
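A minimal, self-contained illustration of that point in plain PyTorch (not transformers code; the frozen embedding stands in for a PEFT-style setup where the base weights are frozen):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

embed = nn.Embedding(10, 8)
embed.weight.requires_grad_(False)   # base weights frozen, as with adapters
layer = nn.Linear(8, 8)              # stands in for the trainable block
ids = torch.tensor([[1, 2, 3]])

# Without the hook: the checkpointed segment gets no tensor input that requires
# grad, so its output does not require grad and the graph is cut (PyTorch also
# warns "None of the inputs have requires_grad=True").
hidden = embed(ids)
out = checkpoint(layer, hidden, use_reentrant=True)
print(out.requires_grad)             # False

# enable_input_require_grads works around this with a forward hook on the
# embedding that flips requires_grad on its output.
def make_inputs_require_grads(module, inputs, output):
    output.requires_grad_(True)

embed.register_forward_hook(make_inputs_require_grads)
hidden = embed(ids)
out = checkpoint(layer, hidden, use_reentrant=True)
out.sum().backward()
print(layer.weight.grad is not None)  # True, gradients flow again
```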

Comment on lines +922 to +933
def test_enable_input_require_grads_with_gradient_checkpointing(self):
if not getattr(self.model_tester, "is_training", False):
self.skipTest(reason="ModelTester is not configured to run training tests")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if hasattr(config, "use_cache"):
config.use_cache = False

has_verified_model = False

for model_class in self.all_model_classes:
if not getattr(model_class, "supports_gradient_checkpointing", False):
Member


I see now what you meant earlier, this test has a lot of edge cases

Contributor Author


yes, it's a bit clunky to have this bool flag, but I wasn't seeing a simpler option

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks again for handling all the edge cases, this was not an easy one.

@molbap
Contributor Author

molbap commented Dec 5, 2025

run-slow: align, altclip, chinese_clip, clap, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, markuplm, mlcd, poolformer, siglip

@github-actions
Contributor

github-actions Bot commented Dec 5, 2025

This comment contains run-slow, running the specified jobs:

models: ["models/align", "models/altclip", "models/chinese_clip", "models/clap", "models/clvp", "models/falcon_mamba", "models/fast_vlm", "models/internvl", "models/layoutlm", "models/layoutlmv3", "models/lilt", "models/mamba", "models/markuplm", "models/mlcd", "models/poolformer", "models/siglip"]
quantizations: []

@github-actions
Contributor

github-actions Bot commented Dec 5, 2025

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap molbap requested a review from ArthurZucker December 5, 2025 16:34
Collaborator

@ArthurZucker ArthurZucker left a comment


Kudos, very nice PR!

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +2008 to +2009
"embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
"support those features."
Collaborator


either that or sometimes just add a _input_embedding_layer

Collaborator


i mean update the message please

Comment on lines +913 to +914
if not hasattr(model, "get_input_embeddings"):
continue
Collaborator


Why not raise an error instead? This way, all new models will make sure they have this go green before merging.

Contributor Author


forgot to answer but: this would currently raise for many existing models

Comment on lines +923 to +924
if not getattr(self.model_tester, "is_training", False):
self.skipTest(reason="ModelTester is not configured to run training tests")
Collaborator


If this one is True by default for all models, sounds good.

Contributor Author


Yes, AFAIK it's true for CausalLMTester.

Collaborator

@ArthurZucker ArthurZucker left a comment


First of all, thanks!
This will also fix some TP recompile issues on hidden_states=hidden_states, cc @3outeille.

Comment thread src/transformers/modeling_utils.py Outdated
logger.warning_once(
f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
"embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
"support those features."
Collaborator


Suggested change
"support those features."
"support those features, or add the `_input_embedding_layer` attribut with the name of the embedding layer!"

Comment thread tests/test_modeling_common.py Outdated
grad_after_gc = embedding_param.grad
self.assertIsNotNone(
grad_after_gc,
f"{model_class.__name__} should produce embedding gradients when gradient checkpointing is enabled.",
Collaborator


if you have an idea of what could cause this to fail, add it!

Comment on lines +1000 to +1012
f"{model_class.__name__} produced non-finite gradients with gradient checkpointing enabled.",
)
self.assertGreater(
grad_after_gc.abs().sum().item(),
0,
f"{model_class.__name__} should keep non-zero embedding gradients with gradient checkpointing enabled.",
)
has_verified_model = True

if not has_verified_model:
self.skipTest(
reason="No model with a differentiable loss was available to verify enable_input_require_grads with gradient checkpointing."
)
Collaborator


Same for each one of these, let's help the user fix it!

Collaborator

@ArthurZucker ArthurZucker left a comment


Very nice! Small updates and let's merge.


@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: align, altclip, chinese_clip, clap, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, markuplm, mlcd, poolformer, siglip

@molbap
Contributor Author

molbap commented Dec 17, 2025

run-slow: align, altclip, chinese_clip, clap, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, markuplm, mlcd, poolformer, siglip

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/align", "models/altclip", "models/chinese_clip", "models/clap", "models/clvp", "models/falcon_mamba", "models/fast_vlm", "models/internvl", "models/layoutlm", "models/layoutlmv3", "models/lilt", "models/mamba", "models/markuplm", "models/mlcd", "models/poolformer", "models/siglip"]
quantizations: []

@molbap
Contributor Author

molbap commented Dec 17, 2025

Comments addressed. Merging and keeping an eye on this 👀; let's see if something breaks and how.

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap molbap merged commit b712a97 into main Dec 17, 2025
27 checks passed
@molbap molbap deleted the fix_enable_grads_again branch December 17, 2025 16:30
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
…t_embeddings (huggingface#42558)

* add embedding getter

* modify your own logic

* a common test

* some adapters are not PreTrainedModel s

* few fixes

* implement correct-ish fix?

* fixup

* this is needed likely

* woops

* solving some cross-imports issues here and there

* more ximports issues

* finally

* revert changes

* fixups

* improve message

* add common tests for input_ids first

* increase test coverage

* bigger update for GC

* copies

* mlcd is getting on my nerves a bit

* ah yes

* for BC

* break a couple modelings

* simplify with base_model

* fix copies for torch checkpointing

* simplify this model

* improve messages