Add THD support in ESM #44145
Conversation
d8cd5ba to 6c4a4e1
Hmm, this is an interesting PR! In general I agree that the ESM attention code is old and could use a refactor, but rather than doing it this way, could we refactor it along the lines of a modern masked LM? One issue though - I remember from the original ESM PR that …
Thanks for taking a look at this. Understood; I can duplicate the implementation of ModernBertRotaryEmbedding in this file and adapt it as needed. If you're okay with it, we could merge this once the tests are fixed, and then I can follow up with a proper refactor in a few weeks. Of course, happy to proceed however you prefer.
I'd prefer one single PR rather than this one plus a follow-up, if that's okay with you!
Sounds good. I’ll take care of it in a few days |
6c4a4e1 to 57e4add
Hi @Rocketknight1, following your suggestion, I refactored …
57e4add to 6baebbc
Hi @balvisio, please check the CI first! There might be issues here, but I can review after it's green.
18803d3 to 7516c81
Hi @Rocketknight1: CI is green now. Thanks!
run-slow: esm, evolla
This comment contains models: ["models/esm", "models/evolla"]
if position_ids is None:
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
One nit I'm not sure about: Does this need to be padding-aware? I think the ESM code also covers the ESM-1 models, which used absolute embeddings, but I worry padding tokens might disrupt the position IDs here. Possibly an issue in training if MLM random-masking selects padding tokens?
@Rocketknight1: You are correct. This was breaking for absolute embeddings. Fixed it in the last commit.
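For readers following along, here is a minimal sketch of the kind of padding-aware position ID construction being discussed; the helper name and the cumsum-over-mask approach are illustrative assumptions, not necessarily the exact code in the commit:

import torch

def create_padding_aware_position_ids(input_ids, padding_idx=1):
    # Illustrative only: real tokens get increasing positions starting after
    # padding_idx, while padding tokens keep padding_idx as their position,
    # so absolute position embeddings are not disturbed by padding.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

# Example: with padding_idx=1, the two trailing pad tokens keep position 1.
input_ids = torch.tensor([[5, 6, 7, 1, 1]])
print(create_padding_aware_position_ids(input_ids))  # tensor([[2, 3, 4, 1, 1]])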
8694d46 to 87f5642
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
def __post_init__(self, **kwargs):
    if self.is_folding_model:
        self.model_type = "esmfold"
I missed this one the first time - is there a reason we want this line, when it wasn't there before? Do we use that model type elsewhere in the code?
Hi @Rocketknight1, the reason is that the two models need different conversion behavior:
- I added a conversion mapping for "esm" in conversion_mapping.py to handle the inv_freq key rename.
- Since the default EsmForProteinFolding doesn't have rotary embeddings, the test test_reverse_loading_mapping breaks (see the failure below).
By setting self.model_type = "esmfold" when is_folding_model=True, get_checkpoint_conversion_mapping("esmfold") returns None.
Let me know if this should be handled in a different way.
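To make the rename more concrete, here is a rough, self-contained illustration of what the "esm" entry is meant to do; the function and constant names below are hypothetical and only mimic the effect of a renaming entry in conversion_mapping.py, they are not the transformers API:

import re

# Hypothetical helper: collapse the old per-layer inv_freq keys onto the new
# shared buffer name; all other keys pass through unchanged (so the entry is
# harmless for checkpoints that have no rotary buffers at all).
SOURCE_PATTERN = r"encoder\.layer\.\d+\.attention\.self\.rotary_embeddings\.inv_freq"
TARGET_KEY = "rotary_embeddings.inv_freq"

def rename_inv_freq_keys(state_dict):
    renamed = {}
    for key, tensor in state_dict.items():
        if re.search(SOURCE_PATTERN, key):
            renamed[TARGET_KEY] = tensor
        else:
            renamed[key] = tensor
    return renamed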
I think it's fine, because if the mapping doesn't match a weight, then nothing happens and we're okay, right? So having a mapping entry that isn't valid for our model isn't a problem, and we can just remove the self.model_type line.
The problem is with the test_reverse_loading_mapping. With EsmFold:
tests/models/esm/test_modeling_esmfold.py::EsmFoldModelTest::test_reverse_loading_mapping FAILED [100%]
============================================================================= FAILURES ==============================================================================
___________________________________________________________ EsmFoldModelTest.test_reverse_loading_mapping ___________________________________________________________
self = <tests.models.esm.test_modeling_esmfold.EsmFoldModelTest testMethod=test_reverse_loading_mapping>, check_keys_were_modified = True
def test_reverse_loading_mapping(self, check_keys_were_modified=True):
    """Make sure we can load and save correctly the models having any weight renaming mapping or weight conversion
    mapping.

    Note that this test would be better if we could start from the serialized keys, and check that the model
    keys correspond to the weight converions. However, when instantiating a model, it already has the "target"
    keys (or modified keys after mapping) of the conversion mapping, so we have to do it the other way, i.e.
    reverse the conversion and then check that those converted keys match correctly the conversions.
    However, all the checks performed here should ensure everything is going as it should.

    Args:
        check_keys_were_modified (`bool`, *optional*, defaults to `True`):
            Whether to expect keys being modified or not. In some cases, models do not change keys but
            their weights, e.g. via transpose, memory alignment, etc.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    # Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at
    # lest one MoE layer here to check the mapping
    config_to_set = config.get_text_config(decoder=True)
    config_to_set.first_k_dense_replace = 1  # means that the first layer (idx 0) will be MLP, then MoE
    config_to_set.moe_layer_start_index = 1  # same as above but for Ernie 4.5...
    config_to_set.mlp_only_layers = [0]  # same but for qwens
    config_to_set.num_dense_layers = 1  # lfm2_moe
    for model_class in self.all_model_classes:
        # Each individual model is a subtest
        with self.subTest(model_class.__name__):
            model = model_class(copy.deepcopy(config))
            # Skip if no conversions
            conversions = get_model_conversion_mapping(model, add_legacy=False)
            if len(conversions) == 0:
                self.skipTest("No conversion found for this model")
            # Find the model keys, so the targets according to the conversions
            model_keys = list(model.state_dict().keys())
            with tempfile.TemporaryDirectory() as tmpdirname:
                # Serialize with reverse mapping
                model.save_pretrained(tmpdirname)
                state_dict = load_file(os.path.join(tmpdirname, "model.safetensors"))
                # Get all the serialized keys that we just saved according to the reverse mapping
                serialized_keys = list(state_dict.keys())
                if check_keys_were_modified:
                    # They should be different, otherwise we did not perform any mapping
                    self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!")
                # Check that for each conversion entry, we at least map to one key
                for conversion in conversions:
                    for source_pattern in conversion.source_patterns:
                        # Some patterns are written for gen-model only and won't be applied on base model
                        if "lm_head" in source_pattern and model_class not in [
                            *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                            *get_values(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES),
                        ]:
                            continue
                        # Sometimes the mappings specify keys that are tied, so absent from the saved state dict
                        if isinstance(conversion, WeightRenaming):
                            # We need to revert the target pattern to make it compatible with regex search
                            target_pattern_reversed = conversion.target_patterns[0]
                            captured_group = process_target_pattern(source_pattern)[1]
                            if captured_group:
                                target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group)
                            if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()):
                                continue
                        num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys)
>                       self.assertTrue(
                            num_matches > 0,
                            f"`{source_pattern}` in `{conversion}` did not match any of the source keys. "
                            "This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly",
                        )
E AssertionError: False is not true : `encoder.layer.*.attention.self.rotary_embeddings.inv_freq` in `WeightRenaming(source_patterns=['encoder.layer.*.attention.self.rotary_embeddings.inv_freq'], target_patterns=['rotary_embeddings.inv_freq'], compiled_sources=re.compile('(?P<g0>encoder.layer\\..*\\.attention.self.rotary_embeddings.inv_freq)'), distributed_operation=None, quantization_operation=None, collected_tensors=defaultdict(<class 'list'>, {}), layer_targets=defaultdict(<class 'set'>, {}))` did not match any of the source keys. This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly
tests/test_modeling_common.py:4754: AssertionError
I think the problem is that, since EsmFold uses absolute embeddings by default, the rotary_embeddings.inv_freq key is not found.
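A small, self-contained repro of the failing check (the serialized key names below are made up for illustration): with absolute embeddings the saved state dict simply contains no inv_freq entry for the source pattern to match.

import re

source_pattern = r"encoder.layer.*.attention.self.rotary_embeddings.inv_freq"
# Hypothetical serialized keys for an absolute-embedding EsmFold checkpoint:
serialized_keys = [
    "esm.embeddings.position_embeddings.weight",
    "esm.encoder.layer.0.attention.self.query.weight",
]
num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys)
print(num_matches > 0)  # False, so the assertTrue above fails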
I could override test_reverse_loading_mapping to use a modified EsmFoldConfig with rotary position embeddings instead.
Yes, I think the test override makes sense here. The reason this is a little fiddly is that we generally prefer to have one model class correspond to one architecture, with even subtle changes being split into another class. However, we didn't do that with esm / esmfold, mostly because of a rushed release!
The result is that some test assumptions are violated for this class, which I didn't realize at first so thanks for pointing it out! The test assumes, as you said, that the mappings all match something. I think adding an override to either skip the test or modify it for ESM/ESMFold is correct!
Sounds good. Removed self.model_type = "esmfold" and changed the test to use a config with rotary embeddings.
@balvisio almost ready to go, one final thing I missed the first time around!
cecdc72 to 397c158
@Rocketknight1 The CI is failing. The failing test doesn't seem related to my change, and it passes locally.
@balvisio can you try rebasing and resolving the merge conflict? Hopefully tests pass after that; if not, just ping me and I'll see what I can do with the CI.
397c158 to e28274e
CI issues, will fix with #45123
110cb67 to 0eb7f09
4fd4a3e to 0c7fc83
@Rocketknight1: I added one more commit that adds the missing 'bos_token_id' and 'eos_token_id' fields to the ESM config. Since they are not there, I am reaching this line.
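For illustration, the effect of the added fields can be checked as below; the id values (0 for the BOS/<cls> token and 2 for <eos>) are assumptions based on the standard ESM vocabulary, not necessarily the defaults chosen in the commit:

from transformers import EsmConfig

# Assumed ids: <cls> = 0 doubles as BOS, <eos> = 2 in the ESM-2 vocabulary.
config = EsmConfig(bos_token_id=0, eos_token_id=2)
print(config.bos_token_id, config.eos_token_id)  # 0 2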
@balvisio nice, thank you! Trying to merge again, not sure why the CI is complaining.
0c7fc83 to 73078b0
[For maintainers] Suggested jobs to run (before merge): run-slow: esm, evolla
Hi @Rocketknight1, ping on this :) Is there something still going on with the CI?
@balvisio yeah, I'll investigate when I get a chance. In the meantime I'll keep merging and hope we get through!
@balvisio CI defeated at last. Thank you again for the PR!
Thank you!
* Add THD support in ESM
* Remove commented code
* Refactored Evolla model
* Fix position_ids for absolute embeddings
* Fix test_reverse_loading_mapping for EsmFold
* Add missing 'bos_token_id' and 'eos_token_id' to ESM config

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
What does this PR do?
This PR adds support for sequence packing in the ESM2 model. Currently, the RotaryEmbedding class of the ESM2 model only supports the BSHD format. This PR makes the RotaryEmbedding class aware of the position_ids and builds the cos and sin tables accordingly.

Fixes # (issue)
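As a rough illustration of the mechanism (a simplified sketch, not the EsmRotaryEmbedding code added by this PR): building the cos/sin tables from explicit position_ids rather than a 0..seq_len-1 arange lets packed (THD-style) inputs restart their positions at each sequence boundary.

import torch

def rotary_cos_sin(position_ids, dim, base=10000.0):
    # Simplified sketch: each token's rotation angle comes from its own
    # position id, so packed sequences can restart at 0 mid-row.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
    freqs = position_ids[..., None].float() * inv_freq                  # [batch, seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                             # [batch, seq, dim]
    return emb.cos(), emb.sin()

# Two packed sequences of lengths 3 and 2 sharing one row:
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
cos, sin = rotary_cos_sin(position_ids, dim=8)
print(cos.shape)  # torch.Size([1, 5, 8])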
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.