Fix modular for modernbert-decoder#40431
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: modernbert_decoder

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ..modernbert.modeling_modernbert import (
```
interesting I thought models.llama would work, but no
Same, I'll check and probably upstream it to the converter to avoid this in the future
vasqu left a comment
Thanks! Just nits on my side
```diff
  elif module.__class__.__name__ == "ModernBertDecoderForSequenceClassification":
      init_weight(module.classifier, stds["final_out"])
- elif isinstance(module, ModernBertDecoderForCausalLM):
+ elif module.__class__.__name__ == "ModernBertDecoderForCausalLM":
```
Might be a dumb question but why can't we check for the instance here?
This one could be, but the other one, ModernBertDecoderForSequenceClassification, would be matched by modular and thus wrongly imported. For consistency, I made the check on the class name for the two "higher-level" classes.
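To make the trade-off concrete, here is a minimal sketch with hypothetical classes (the real names, like ModernBertDecoderForCausalLM, come from the PR; everything else is illustrative): `isinstance` needs the class object imported into the modeling file, which the modular converter would then try to resolve as a cross-model reference, while the name check is just a string comparison.

```python
# Hypothetical stand-ins for the real classes in the PR.
class PreTrainedModel:
    pass

class ForCausalLM(PreTrainedModel):
    pass

def matches_by_isinstance(module):
    # Requires ForCausalLM to be imported into this file, which the
    # modular converter would pick up and wrongly resolve.
    return isinstance(module, ForCausalLM)

def matches_by_name(module):
    # Pure string comparison: no reference to the class object,
    # so the converter has nothing to match.
    return module.__class__.__name__ == "ForCausalLM"

module = ForCausalLM()
print(matches_by_isinstance(module), matches_by_name(module))  # True True
```

Both checks agree at runtime; the difference only matters to the static analysis done by the modular converter.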
```python
@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
    config: ModernBertDecoderConfig
```
No need to redefine, it's inherited
```python
config: ModernBertDecoderConfig
_skip_keys_device_placement = ["past_key_values"]
_no_split_modules = ["ModernBertDecoderLayer"]
_can_compile_fullgraph = False
```
Same here? Or change the flag maybe?
It's False by default, no need to add it
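The point about defaults can be shown with a small sketch (hypothetical class names; in the PR these would be the base PreTrainedModel and ModernBertDecoderPreTrainedModel): a class attribute set on the base class is visible on every subclass, so restating it with the same value is redundant.

```python
# Hypothetical base/subclass pair illustrating inherited class attributes.
class Base:
    _can_compile_fullgraph = False  # library-wide default

class Decoder(Base):
    pass  # no need to restate the attribute

print(Decoder._can_compile_fullgraph)  # False
```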
What does this PR do?
The modular file was ill-formed: most of the rules were skipped, and the classes were instead imported from modernbert in the generated modeling file (which is illegal, as we want 1 model -> 1 file).
This fixes it.
Note that because it was mistakenly inheriting from
ModernBertPreTrainedModel instead of using modular rules and correctly rewriting the code, the model was using FA2 by default instead of sdpa before #40350 (review). As that was a mistake, and all models should use sdpa by default barring extreme exceptions (as far as I know, ModernBert is the only one), I did not revert the new sdpa default; it follows naturally now that the modular is fixed.
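For context, the sdpa default means attention dispatches to PyTorch's built-in fused kernel rather than the flash-attention package. A minimal sketch of that call (shapes are illustrative, not taken from the model config):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Causal attention via the fused PyTorch kernel that the sdpa backend uses.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```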