diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 438465f05..1e7a26832 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -2439,7 +2439,7 @@ def get_pretrained_state_dict( hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"checkpoint-{cfg.checkpoint_value}", - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) @@ -2447,7 +2447,7 @@ def get_pretrained_state_dict( hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"step{cfg.checkpoint_value}", - torch_dtype=dtype, + dtype=dtype, token=huggingface_token, **kwargs, ) @@ -2460,28 +2460,28 @@ def get_pretrained_state_dict( elif "hubert" in official_model_name: hf_model = HubertModel.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) elif "wav2vec2" in official_model_name: hf_model = Wav2Vec2Model.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) elif "t5" in official_model_name: hf_model = T5ForConditionalGeneration.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) @@ -2491,14 +2491,14 @@ def get_pretrained_state_dict( hf_model = AutoModel.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) else: hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, )