
Add THD support in ESM #44145

Merged

Rocketknight1 merged 6 commits into huggingface:main from balvisio:dev/ba/support-thd-in-esm
Apr 9, 2026

Conversation

@balvisio
Contributor

What does this PR do?

This PR adds support for sequence packing (THD format) in the ESM2 model. Currently, the RotaryEmbedding class of the ESM2 model only supports the BSHD format. This PR makes the RotaryEmbedding class aware of the position_ids and builds the cos and sin tables accordingly.
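For readers unfamiliar with the change, here is a minimal sketch of the idea (illustration only, not the code in this PR): instead of building the cos/sin tables from a plain 0..seq_len arange, gather them at the provided position_ids, so packed (THD) rows whose positions restart at each sequence boundary still get the correct rotation.

import torch

def build_rope_cache(position_ids: torch.Tensor, inv_freq: torch.Tensor):
    # position_ids: (batch, seq_len) -- may restart at 0 mid-row when sequences are packed (THD)
    # inv_freq:     (dim // 2,)      -- rotary inverse frequencies
    freqs = position_ids[..., None].float() * inv_freq[None, None, :]  # (batch, seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

# Two length-3 sequences packed into one row: positions restart at the packing boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2]])
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
cos, sin = build_rope_cache(position_ids, inv_freq)
print(cos.shape)  # torch.Size([1, 6, 8])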

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch from d8cd5ba to 6c4a4e1 Compare February 19, 2026 03:18
@Rocketknight1
Member

Hmm, this is an interesting PR! In general I agree that the ESM attention code is old and could use a refactor, but rather than doing it this way, could we refactor it along the lines of a modern masked LM? modernbert in particular has a clean and well-tested implementation, and it might be possible to use most of it.

One issue though - I remember from the original ESM PR that self.inv_freq is saved in the state dict and should not be recomputed because it was (iirc?) computed in float16. Therefore, if you recompute it with float32 values, the model output will be subtly different. Even though float32 is more accurate, we need to match what the model was trained with.

@balvisio
Contributor Author

Thanks for taking a look at this. Understood; I can duplicate the implementation of ModernBertRotaryEmbedding in this file and adapt it as needed. If you’re okay with it, we could merge this once the tests are fixed, and then I can follow up with a proper refactor in a few weeks. Of course, happy to proceed however you prefer

@Rocketknight1
Member

I'd prefer one single PR rather than this one plus a follow-up, if that's okay with you!

@balvisio
Contributor Author

Sounds good. I’ll take care of it in a few days

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch from 6c4a4e1 to 57e4add Compare March 17, 2026 21:45
@balvisio
Contributor Author

Hi @Rocketknight1 , following your suggestion, I refactored ESMModel to align it more closely with the ModernBERT implementation. I have also taken care of the inv_freq buffer initialization so that it is backwards compatible; one of the tests asserts that casting the buffer to float16 gives exactly the same result.
Thank you for reviewing.

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch from 57e4add to 6baebbc Compare March 17, 2026 21:49
@Rocketknight1
Member

Hi @balvisio, please check the CI first! There might be issues here, but I can review after it's green.

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch 5 times, most recently from 18803d3 to 7516c81 Compare March 19, 2026 15:03
@balvisio
Contributor Author

Hi @Rocketknight1 : CI is green now. Thanks!

@Rocketknight1
Member

run-slow: esm, evolla

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/esm", "models/evolla"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 23f6c758 workflow commit (merge commit)
PR 7516c816 branch commit (from PR)
main 82db888e base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

Comment on lines +732 to +733
if position_ids is None:
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
Member

One nit I'm not sure about: Does this need to be padding-aware? I think the ESM code also covers the ESM-1 models, which used absolute embeddings, but I worry padding tokens might disrupt the position IDs here. Possibly an issue in training if MLM random-masking selects padding tokens?

Contributor Author

@Rocketknight1 : You are correct. This was breaking for absolute embeddings. Fixed it in last commit.
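For reference, the usual padding-aware construction looks roughly like this (a sketch in the spirit of the create_position_ids_from_input_ids helper used by RoBERTa-style models; the exact code in the commit may differ): positions are counted only over non-padding tokens, and padding positions stay pinned to padding_idx.

import torch

def create_position_ids_from_input_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Assign increasing positions to non-padding tokens; padding tokens keep padding_idx.
    mask = input_ids.ne(padding_idx).int()
    # Cumulative count of real tokens, zeroed out again at padding positions
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[5, 8, 9, 1, 1]])  # 1 is the padding id here
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])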

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch 3 times, most recently from 8694d46 to 87f5642 Compare March 20, 2026 17:38
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


def __post_init__(self, **kwargs):
    if self.is_folding_model:
        self.model_type = "esmfold"
Member

I missed this one the first time - is there a reason we want this line, when it wasn't there before? Do we use that model type elsewhere in the code?

Contributor Author

Hi @Rocketknight1 , the reason is that the two models need different conversion behavior.

  1. I added a conversion mapping for "esm" in conversion_mapping.py to handle the inv_freq key rename.
  2. Since the default EsmForProteinFolding doesn't have rotary embeddings, that mapping never matches anything and test_reverse_loading_mapping fails for the folding model.

By setting self.model_type = "esmfold" when is_folding_model=True, get_checkpoint_conversion_mapping("esmfold") returns None.

Let me know if this should be handled in a different way.

Member

@Rocketknight1 Mar 24, 2026

I think it's fine because if the mapping doesn't match a weight, then nothing happens and we're okay, right? So having a mapping entry that isn't valid for our model isn't a problem, and we can just remove the self.model_type line

Contributor Author

The problem is with the test_reverse_loading_mapping. With EsmFold:

tests/models/esm/test_modeling_esmfold.py::EsmFoldModelTest::test_reverse_loading_mapping FAILED                                                              [100%]

============================================================================= FAILURES ==============================================================================
___________________________________________________________ EsmFoldModelTest.test_reverse_loading_mapping ___________________________________________________________

self = <tests.models.esm.test_modeling_esmfold.EsmFoldModelTest testMethod=test_reverse_loading_mapping>, check_keys_were_modified = True

    def test_reverse_loading_mapping(self, check_keys_were_modified=True):
        """Make sure we can load and save correctly the models having any weight renaming mapping or weight conversion
        mapping.
        Note that this test would be better if we could start from the serialized keys, and check that the model
        keys correspond to the weight converions. However, when instantiating a model, it already has the "target"
        keys (or modified keys after mapping) of the conversion mapping, so we have to do it the other way, i.e.
        reverse the conversion and then check that those converted keys match correctly the conversions.
    
        However, all the checks performed here should ensure everything is going as it should.
    
        Args:
            check_keys_were_modified (`bool`, *optional*, defaults to `True`):
                Whether to expect keys being modified or not. In some cases, models do not change keys but
                their weights, e.g. via transpose, memory alignment, etc.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    
        #  Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at
        # lest one MoE layer here to check the mapping
        config_to_set = config.get_text_config(decoder=True)
        config_to_set.first_k_dense_replace = 1  # means that the first layer (idx 0) will be MLP, then MoE
        config_to_set.moe_layer_start_index = 1  # same as above but for Ernie 4.5...
        config_to_set.mlp_only_layers = [0]  # same but for qwens
        config_to_set.num_dense_layers = 1  # lfm2_moe
    
        for model_class in self.all_model_classes:
            # Each individual model is a subtest
            with self.subTest(model_class.__name__):
                model = model_class(copy.deepcopy(config))
                # Skip if no conversions
                conversions = get_model_conversion_mapping(model, add_legacy=False)
                if len(conversions) == 0:
                    self.skipTest("No conversion found for this model")
    
                # Find the model keys, so the targets according to the conversions
                model_keys = list(model.state_dict().keys())
    
                with tempfile.TemporaryDirectory() as tmpdirname:
                    # Serialize with reverse mapping
                    model.save_pretrained(tmpdirname)
                    state_dict = load_file(os.path.join(tmpdirname, "model.safetensors"))
                    # Get all the serialized keys that we just saved according to the reverse mapping
                    serialized_keys = list(state_dict.keys())
    
                if check_keys_were_modified:
                    # They should be different, otherwise we did not perform any mapping
                    self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!")
    
                # Check that for each conversion entry, we at least map to one key
                for conversion in conversions:
                    for source_pattern in conversion.source_patterns:
                        # Some patterns are written for gen-model only and won't be applied on base model
                        if "lm_head" in source_pattern and model_class not in [
                            *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                            *get_values(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES),
                        ]:
                            continue
    
                        # Sometimes the mappings specify keys that are tied, so absent from the saved state dict
                        if isinstance(conversion, WeightRenaming):
                            # We need to revert the target pattern to make it compatible with regex search
                            target_pattern_reversed = conversion.target_patterns[0]
                            captured_group = process_target_pattern(source_pattern)[1]
                            if captured_group:
                                target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group)
                            if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()):
                                continue
                        num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys)
>                       self.assertTrue(
                            num_matches > 0,
                            f"`{source_pattern}` in `{conversion}` did not match any of the source keys. "
                            "This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly",
                        )
E                       AssertionError: False is not true : `encoder.layer.*.attention.self.rotary_embeddings.inv_freq` in `WeightRenaming(source_patterns=['encoder.layer.*.attention.self.rotary_embeddings.inv_freq'], target_patterns=['rotary_embeddings.inv_freq'], compiled_sources=re.compile('(?P<g0>encoder.layer\\..*\\.attention.self.rotary_embeddings.inv_freq)'), distributed_operation=None, quantization_operation=None, collected_tensors=defaultdict(<class 'list'>, {}), layer_targets=defaultdict(<class 'set'>, {}))` did not match any of the source keys. This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly

tests/test_modeling_common.py:4754: AssertionError

I think the problem is that, since EsmFold uses absolute embeddings by default, the rotary_embeddings.inv_freq key is not found.

Contributor Author

I could override test_reverse_loading_mapping to use a modified EsmFoldConfig with rotary position embeddings instead.

Member

@Rocketknight1 Mar 25, 2026

Yes, I think the test override makes sense here. The reason this is a little fiddly is that we generally prefer to have one model class correspond to one architecture, with even subtle changes being split into another class. However, we didn't do that with esm / esmfold, mostly because of a rushed release!

The result is that some test assumptions are violated for this class, which I didn't realize at first so thanks for pointing it out! The test assumes, as you said, that the mappings all match something. I think adding an override to either skip the test or modify it for ESM/ESMFold is correct!

Contributor Author

Sounds good. Removed self.model_type = "esmfold" and changed the test to use a config with rotary embeddings.
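For illustration, the override could look roughly like the following (a sketch that leans on the shared test harness in tests/models/esm/test_modeling_esmfold.py, so it is not runnable on its own; the way the rotary config is injected here, via a hypothetical tester attribute, may differ from what actually landed):

def test_reverse_loading_mapping(self, check_keys_were_modified=True):
    # The default EsmFold config uses absolute position embeddings, so the
    # `rotary_embeddings.inv_freq` renaming entry would never match a serialized key.
    # Run the shared check against a config whose ESM stem uses rotary embeddings.
    self.model_tester.position_embedding_type = "rotary"  # hypothetical knob on the tester
    super().test_reverse_loading_mapping(check_keys_were_modified=check_keys_were_modified)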

@Rocketknight1
Member

@balvisio almost ready to go, one final thing I missed the first time around!

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch 2 times, most recently from cecdc72 to 397c158 Compare March 25, 2026 15:56
@balvisio
Contributor Author

@Rocketknight1 The CI is failing. The failing test doesn't seem related to my change and it passes locally

@Rocketknight1
Member

@balvisio can you try rebasing and resolving the merge conflict? Hopefully tests pass after that, if not just ping me and I'll see what I can do with the CI.

@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch from 397c158 to e28274e Compare March 26, 2026 15:49
@Rocketknight1 Rocketknight1 enabled auto-merge March 30, 2026 15:19
@Rocketknight1
Member

CI issues, will fix with #45123

@Rocketknight1 Rocketknight1 force-pushed the dev/ba/support-thd-in-esm branch from 110cb67 to 0eb7f09 Compare March 30, 2026 16:29
@Rocketknight1 Rocketknight1 added this pull request to the merge queue Mar 30, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Mar 30, 2026
@Rocketknight1 Rocketknight1 enabled auto-merge March 30, 2026 17:48
@Rocketknight1 Rocketknight1 added this pull request to the merge queue Mar 30, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Mar 30, 2026
@balvisio balvisio force-pushed the dev/ba/support-thd-in-esm branch from 4fd4a3e to 0c7fc83 Compare March 30, 2026 23:30
@balvisio
Contributor Author

@Rocketknight1 : I added one more commit that adds the missing 'bos_token_id' and 'eos_token_id' fields to the ESM config. Since they are not there, this line https://github.com/huggingface/transformers/blob/8213e0d920d52cb00dcade16b6d1f6e952ac0a8c/src/transformers/modeling_utils.py#L4372 is reached and throws the following exception:

self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 4379, in warn_if_padding_and_no_attention_mask
    (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/configuration_utils.py", line 164, in __getattribute__
    return super().__getattribute__(key)
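(A minimal sketch of what adding the fields can look like; the token ids below are placeholders for illustration and the actual values are in the diff. PretrainedConfig already accepts these kwargs and exposes them as config attributes, which is what warn_if_padding_and_no_attention_mask reads.)

from transformers.configuration_utils import PretrainedConfig

class EsmConfig(PretrainedConfig):
    model_type = "esm"

    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
        # Forwarding the ids to PretrainedConfig exposes them as config.bos_token_id /
        # config.eos_token_id, which warn_if_padding_and_no_attention_mask reads
        # unconditionally when no attention_mask is passed.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )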

@Rocketknight1
Member

@balvisio nice, thank you! Trying to merge again, not sure why the CI is complaining.

balvisio added 6 commits April 2, 2026 13:00
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
@Rocketknight1 Rocketknight1 force-pushed the dev/ba/support-thd-in-esm branch from 0c7fc83 to 73078b0 Compare April 2, 2026 12:00
@Rocketknight1 Rocketknight1 enabled auto-merge April 2, 2026 12:00
@github-actions
Contributor

github-actions Bot commented Apr 2, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: esm, evolla

@Rocketknight1 Rocketknight1 added this pull request to the merge queue Apr 2, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks Apr 2, 2026
@balvisio
Contributor Author

balvisio commented Apr 7, 2026

Hi @Rocketknight1 ping on this :) Is there something still going on with the CI?

@Rocketknight1 Rocketknight1 added this pull request to the merge queue Apr 9, 2026
@Rocketknight1
Member

@balvisio yeah, I'll investigate when I get a chance. In the meantime I'll keep merging and hope we get through!

Merged via the queue into huggingface:main with commit 3170e36 Apr 9, 2026
29 checks passed
@Rocketknight1
Member

@balvisio CI defeated at last. Thank you again for the PR!

@balvisio
Contributor Author

balvisio commented Apr 9, 2026

Thank you!

sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
* Add THD support in ESM

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>

* Remove commented code

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>

* Refactored Evolla model

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>

* Fix position_ids for absolute embeddings

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>

* Fix test_reverse_loading_mapping for EsmFold

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>

* Add missing 'bos_token_id' and 'eos_token_id' to ESM config

---------

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
