Make gradient-checkpoint enabling tolerant of models without get_input_embeddings #42558
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Added a test as well, but I can't find a clean way around the models for which it is not relevant to have a getter method, without causing a lot of side effects. WDYT @zucchini-nlp? Kind of stumped (try/excepting at a higher level would always work but hides a lot)
zucchini-nlp
left a comment
models for which it is not relevant to have a getter method and not causing as many side-effects
Would this also mean that we can't correctly support PEFT and GC with these models, or do they have a custom way to set grad on the inputs? We could raise an error with a better message saying that the model doesn't support this unless it has a way to get its input embeddings, wdyt?
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
    base_model = getattr(self, base_model, None)
nit: the `self.base_model` property has the same functionality
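(For reference, IIRC the `base_model` property resolves to roughly the snippet below, so the main behavioral difference is that it falls back to `self` rather than `None` when there is no `base_model_prefix` submodule; paraphrased from memory, not copied from the library.)

```python
@property
def base_model(self) -> nn.Module:
    # Roughly the existing property on PreTrainedModel (approximation).
    return getattr(self, self.base_model_prefix, self)
```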
_input_embed_layer = "embed_tokens"  # default layer that holds input embeddings.

def get_input_embeddings(self) -> nn.Module:
def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:
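For context, a minimal sketch of what a non-raising getter along these lines could look like; the body below is assumed for illustration and is not the PR's actual implementation, only the names come from the diff above.

```python
from typing import Optional

from torch import nn


def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:
    # Assumed behavior: mirror `get_input_embeddings`, but return None instead of
    # raising NotImplementedError when the model does not expose input embeddings.
    try:
        return self.get_input_embeddings()
    except NotImplementedError:
        # Fall back to the default embedding attribute, if the model declares one.
        layer_name = getattr(self, "_input_embed_layer", None)
        return getattr(self, layer_name, None) if layer_name else None
```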
oh interesting, I was assuming the base `get_input_embeddings` already returns None
well, I ended up in so many little edge cases lol
Yes, it's a good idea to raise/inform downstream users. I reverted a couple of things and will update the test so it actually checks that enabling GC works (probably adding another test).
This is mostly to fix a broken env situation that can be caused around timm_wrapper (or timm_backbone?) so it protects a few imports
I reverted a few models to inner positional embedding calls as mentioned in #38913, and modified a few other models per the test I added. Hopefully that helps VLMs + GC and does not break adapters.
try:
    input_embeddings = module.get_input_embeddings()
except NotImplementedError:
    continue
no simple way around this unfortunately
oke, I think with the warning below, it is more explicit
if not found_embeddings:
    logger.warning_once(
        f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
        "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
        "support those features."
at least we can warn users!
zucchini-nlp
left a comment
It is a pity that there are so many edge cases to handle. Raising a warning seems like a good solution to not silently skip exceptions
I think there are some unrelated changes with rope, let's revert those before merging
rotary_embeddings = position_embeddings
if rotary_position_tensor is not None:
    rotary_embeddings = (rotary_position_tensor.cos(), rotary_position_tensor.sin())
the position_embeddings are already supposed to be present, so we don't need the embeddings, right?
that one is a little bit more annoying. I'm trying to revert it, but the thing is this model should work except that it's a tuple... and gradient checkpointing does not disable grads on tuples, only on tensors passed as positional args.
so this was a hacky trick (for the use_reentrant case)
weird, we use tuple cos/sin in most LLMs. I'd prefer to skip this model's test instead of fixing it by duplicating args, I think it is not used as commonly
I did skip it, yes, because it was being annoying 😁 and yes, we use them, but they don't usually carry grads; here (in that particular model) they do
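For readers following along, a minimal repro sketch of the behavior being discussed (not code from the PR; the shapes and the toy `layer` function are made up):

```python
import torch
from torch.utils.checkpoint import checkpoint


def layer(hidden_states, position_embeddings):
    # Toy layer that consumes cos/sin packed in a tuple, like rotary embeddings.
    cos, sin = position_embeddings
    return hidden_states * cos + sin


hidden = torch.randn(2, 4)  # no grad on the "hidden states" input
cos = torch.randn(2, 4, requires_grad=True)
sin = torch.randn(2, 4, requires_grad=True)

# With reentrant checkpointing, only tensors passed directly as positional args are
# tracked by autograd; cos/sin hidden inside the tuple are not, so the checkpointed
# segment produces an output without grad (PyTorch warns about it).
out = checkpoint(layer, hidden, (cos, sin), use_reentrant=True)
print(out.requires_grad)  # False
```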
needs_embedding_grads = self.main_input_name == "input_ids"
# we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
if enable_input_grads:
hmm, for my understanding, why do we always need to enable grads when doing GC training with text models?
we don't always, but we do with reentrant checkpointing. IIUC it's not to actually use these gradients; it's that torch.utils.checkpoint needs at least one input and one output to actually carry gradients, otherwise the checkpointed part will not have a gradient.
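For context, this is roughly what `enable_input_require_grads` does (sketched from memory, so treat it as an approximation rather than the exact library code): register a forward hook on the input embeddings so their output requires grad, giving the checkpointed segment at least one differentiable input.

```python
def enable_input_require_grads(model):
    # Approximation of `PreTrainedModel.enable_input_require_grads`: force the
    # embedding output to require grad so reentrant checkpointing has a
    # differentiable input even when the embedding weights are frozen.
    def make_inputs_require_grads(module, inputs, output):
        output.requires_grad_(True)

    return model.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
```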
def test_enable_input_require_grads_with_gradient_checkpointing(self):
    if not getattr(self.model_tester, "is_training", False):
        self.skipTest(reason="ModelTester is not configured to run training tests")

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    if hasattr(config, "use_cache"):
        config.use_cache = False

    has_verified_model = False

    for model_class in self.all_model_classes:
        if not getattr(model_class, "supports_gradient_checkpointing", False):
I see now what you meant earlier, this test has a lot of edge cases
yes, it's a bit clunky to have this bool flag, but I wasn't seeing a simpler option
zucchini-nlp
left a comment
Thanks again for handling all edge cases, was not an easy one
run-slow: align, altclip, chinese_clip, clap, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, markuplm, mlcd, poolformer, siglip

This comment contains models: ["models/align", "models/altclip", "models/chinese_clip", "models/clap", "models/clvp", "models/falcon_mamba", "models/fast_vlm", "models/internvl", "models/layoutlm", "models/layoutlmv3", "models/lilt", "models/mamba", "models/markuplm", "models/mlcd", "models/poolformer", "models/siglip"]

CI Results: ✅ No failing test specific to this PR 🎉!
ArthurZucker
left a comment
Kudos, very nice PR!
| "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully " | ||
| "support those features." |
either that, or sometimes just add a `_input_embedding_layer`
I mean, update the message please
if not hasattr(model, "get_input_embeddings"):
    continue
why not raise an error instead? That way, all new models will have to make sure this goes green before merging
forgot to answer but: this would currently raise for many existing models
if not getattr(self.model_tester, "is_training", False):
    self.skipTest(reason="ModelTester is not configured to run training tests")
If this one is True by default for all models sg
yes AFAIK, true for CausalLMTester
ArthurZucker
left a comment
First of all thanks!
This will also fix some TP recompile issues (cc @3outeille) on `hidden_states=hidden_states`
logger.warning_once(
    f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
    "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
    "support those features."
| "support those features." | |
| "support those features, or add the `_input_embedding_layer` attribut with the name of the embedding layer!" |
grad_after_gc = embedding_param.grad
self.assertIsNotNone(
    grad_after_gc,
    f"{model_class.__name__} should produce embedding gradients when gradient checkpointing is enabled.",
if you have an idea of what could cause this to fail, add it!
| f"{model_class.__name__} produced non-finite gradients with gradient checkpointing enabled.", | ||
| ) | ||
| self.assertGreater( | ||
| grad_after_gc.abs().sum().item(), | ||
| 0, | ||
| f"{model_class.__name__} should keep non-zero embedding gradients with gradient checkpointing enabled.", | ||
| ) | ||
| has_verified_model = True | ||
|
|
||
| if not has_verified_model: | ||
| self.skipTest( | ||
| reason="No model with a differentiable loss was available to verify enable_input_require_grads with gradient checkpointing." | ||
| ) |
same for each one of these, let's help the user to fix it!
ArthurZucker
left a comment
Very nice! small updates and let's merge
| "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully " | ||
| "support those features." |
There was a problem hiding this comment.
i mean update the message please
[For maintainers] Suggested jobs to run (before merge): run-slow: align, altclip, chinese_clip, clap, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, markuplm, mlcd, poolformer, siglip

run-slow: align, altclip, chinese_clip, clap, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, markuplm, mlcd, poolformer, siglip

This comment contains models: ["models/align", "models/altclip", "models/chinese_clip", "models/clap", "models/clvp", "models/falcon_mamba", "models/fast_vlm", "models/internvl", "models/layoutlm", "models/layoutlmv3", "models/lilt", "models/mamba", "models/markuplm", "models/mlcd", "models/poolformer", "models/siglip"]
Comments addressed. Merging and keeping an eye on this 👀 let's see if something breaks and how |
CI Results: ✅ No failing test specific to this PR 🎉!
Make gradient-checkpoint enabling tolerant of models without get_input_embeddings (huggingface#42558)

* add embedding getter
* modify your own logic
* a common test
* some adapters are not PreTrainedModel s
* few fixes
* implement correct-ish fix?
* fixup
* this is needed likely
* woops
* solving some cross-imports issues here and there
* more ximports issues
* finally
* revert changes
* fixups
* improve message
* add common tests for input_ids first
* increase test coverage
* bigger update for GC
* copies
* mlcd is getting on my nerves a bit
* ah yes
* for BC
* break a couple modelings
* simplify with base_model
* fix copies for torch checkpointing
* simplify this model
* improve messages
What does this PR do?
As the title indicates, #42542 and likely a few other models are broken by the merged #41993. This PR adds an embedding getter and attempts to test the feature with more coverage.
Basically what it does
Hopefully this helps GC for PEFT adapters for many VLMs (and plain text models too).
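For instance, the kind of downstream flow this is meant to keep working (a hedged sketch; the checkpoint name is a placeholder):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some-org/some-model")  # placeholder checkpoint

# Should no longer hard-fail on models that do not implement get_input_embeddings;
# at worst a warning is emitted (see the warning added in this PR).
model.gradient_checkpointing_enable()

# Needed for PEFT/adapters with reentrant checkpointing, so gradients can flow
# back through the (frozen) input embeddings.
model.enable_input_require_grads()
```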