From a11f0f81cfbba57952855e7d677c34900f7b862f Mon Sep 17 00:00:00 2001
From: Zachary Goldfine
Date: Mon, 30 Mar 2026 12:47:26 -0400
Subject: [PATCH] Fix deprecated `torch_dtype` argument in HuggingFace
 `from_pretrained` calls

Replace `torch_dtype=dtype` with `dtype=dtype` in all internal calls to
HuggingFace's `from_pretrained()` methods in loading_from_pretrained.py.
The `torch_dtype` parameter is deprecated in recent versions of the
transformers library in favor of `dtype`. The backwards-compatible
acceptance of `torch_dtype` from users (lines 2389-2391) is preserved so
existing user code continues to work.

Fixes #1093

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 transformer_lens/loading_from_pretrained.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 438465f05..1e7a26832 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -2439,7 +2439,7 @@ def get_pretrained_state_dict(
             hf_model = AutoModelForCausalLM.from_pretrained(
                 official_model_name,
                 revision=f"checkpoint-{cfg.checkpoint_value}",
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
@@ -2447,7 +2447,7 @@
             hf_model = AutoModelForCausalLM.from_pretrained(
                 official_model_name,
                 revision=f"step{cfg.checkpoint_value}",
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token,
                 **kwargs,
             )
@@ -2460,28 +2460,28 @@
         elif "hubert" in official_model_name:
             hf_model = HubertModel.from_pretrained(
                 official_model_name,
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
         elif "wav2vec2" in official_model_name:
             hf_model = Wav2Vec2Model.from_pretrained(
                 official_model_name,
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
         elif "bert" in official_model_name:
             hf_model = BertForPreTraining.from_pretrained(
                 official_model_name,
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
         elif "t5" in official_model_name:
             hf_model = T5ForConditionalGeneration.from_pretrained(
                 official_model_name,
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
@@ -2491,14 +2491,14 @@
             hf_model = AutoModel.from_pretrained(
                 official_model_name,
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
         else:
             hf_model = AutoModelForCausalLM.from_pretrained(
                 official_model_name,
-                torch_dtype=dtype,
+                dtype=dtype,
                 token=huggingface_token if len(huggingface_token) > 0 else None,
                 **kwargs,
             )
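
Note for reviewers: the backwards-compatible handling of a user-supplied
`torch_dtype` mentioned above (lines 2389-2391) is outside the hunks shown
here. As a rough illustration only -- not the actual code in the file, and
using a hypothetical helper name `_resolve_dtype` -- the kind of mapping this
patch relies on looks like the following sketch:

    import torch

    def _resolve_dtype(dtype=torch.float32, **kwargs):
        # Hypothetical sketch: fold a deprecated, user-supplied `torch_dtype`
        # kwarg into the new `dtype` argument before any HuggingFace
        # `from_pretrained()` call is made with `dtype=dtype`.
        if kwargs.get("torch_dtype") is not None:
            dtype = kwargs.pop("torch_dtype")
        return dtype, kwargs

    # Old user code keeps working: `torch_dtype` is absorbed here and never
    # forwarded to transformers, which now expects `dtype`.
    dtype, kwargs = _resolve_dtype(torch_dtype=torch.float16)
    assert dtype == torch.float16 and "torch_dtype" not in kwargs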