diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 3524b221a0ec..64821a84c78d 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -589,6 +589,7 @@ class EsmPreTrainedModel(PreTrainedModel): config: EsmConfig base_model_prefix = "esm" supports_gradient_checkpointing = True + accepts_loss_kwargs = False _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"] _supports_flash_attn = True