From deab3a9859e189a6bc07c221f80d2a989fad8b7a Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 17 Sep 2025 23:07:03 +0400 Subject: [PATCH 1/3] fix(timm): Add exception handling for unknown Gemma3n model --- .../timm_wrapper/modeling_timm_wrapper.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 7839bf7813f2..b4c33b609b23 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -55,6 +55,27 @@ class TimmWrapperModelOutput(ModelOutput): attentions: Optional[tuple[torch.FloatTensor, ...]] = None +def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_kwargs): + """ + Creates a timm model and provides a clear error message if the model is not found, + suggesting a library update. + """ + try: + model = timm.create_model( + config.architecture, + pretrained=False, + **model_kwargs, + ) + return model + except RuntimeError as e: + if "Unknown model" in str(e): + raise ImportError( + f"The model architecture '{config.architecture}' is not supported in your version of timm ({timm.__version__}). " + "Please upgrade timm to a more recent version with `pip install -U timm`." + ) from e + raise e + + @auto_docstring class TimmWrapperPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" @@ -138,7 +159,7 @@ def __init__(self, config: TimmWrapperConfig): super().__init__(config) # using num_classes=0 to avoid creating classification head extra_init_kwargs = config.model_args or {} - self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **extra_init_kwargs) + self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs) self.post_init() @auto_docstring @@ -254,8 +275,8 @@ def __init__(self, config: TimmWrapperConfig): ) extra_init_kwargs = config.model_args or {} - self.timm_model = timm.create_model( - config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs + self.timm_model = _create_timm_model_with_error_handling( + config, num_classes=config.num_labels, **extra_init_kwargs ) self.num_labels = config.num_labels self.post_init() From bdfbe333d534e729997259e157b8c909f64a6d5c Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 18 Sep 2025 00:12:55 +0400 Subject: [PATCH 2/3] =?UTF-8?q?nit:=20Let=E2=80=99s=20cater=20to=20this=20?= =?UTF-8?q?specific=20issue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../models/timm_wrapper/modeling_timm_wrapper.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index b4c33b609b23..eff93e0c069a 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -69,6 +69,13 @@ def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_ return model except RuntimeError as e: if "Unknown model" in str(e): + if "mobilenetv5_300m_enc" in config.architecture: + raise ImportError( + f"You are trying to load a model that uses '{config.architecture}', the vision backbone for Gemma 3n. " + f"This architecture is not supported in your version of timm ({timm.__version__}). " + "Please upgrade to timm >= 1.0.16 with: `pip install -U timm`." + ) from e + # A good general check for other unknown models too. raise ImportError( f"The model architecture '{config.architecture}' is not supported in your version of timm ({timm.__version__}). " "Please upgrade timm to a more recent version with `pip install -U timm`." From 018c2215fd83784881294d8d6760038ce186c825 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 18 Sep 2025 16:44:16 +0400 Subject: [PATCH 3/3] nit: Simplify error handling --- .../models/timm_wrapper/modeling_timm_wrapper.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index eff93e0c069a..cfc3c1c104d3 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -69,13 +69,7 @@ def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_ return model except RuntimeError as e: if "Unknown model" in str(e): - if "mobilenetv5_300m_enc" in config.architecture: - raise ImportError( - f"You are trying to load a model that uses '{config.architecture}', the vision backbone for Gemma 3n. " - f"This architecture is not supported in your version of timm ({timm.__version__}). " - "Please upgrade to timm >= 1.0.16 with: `pip install -U timm`." - ) from e - # A good general check for other unknown models too. + # A good general check for unknown models. raise ImportError( f"The model architecture '{config.architecture}' is not supported in your version of timm ({timm.__version__}). " "Please upgrade timm to a more recent version with `pip install -U timm`."