diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 5cce47eec..3ce5e9e25 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1089,6 +1089,7 @@ def from_pretrained(
         fold_value_biases: Optional[bool] = True,
         default_prepend_bos: Optional[bool] = True,
         default_padding_side: Optional[Literal["left", "right"]] = "right",
+        max_context_length: Optional[int] = 2048,
         dtype="float32",
         **from_pretrained_kwargs,
     ) -> "HookedTransformer":
@@ -1219,6 +1220,9 @@ def from_pretrained(
                 other HuggingFace functions when compatible. For some models or arguments
                 it doesn't work, especially for models that are not internally loaded with
                 HuggingFace's from_pretrained (e.g. SoLU models).
+            max_context_length: The maximum context length to use for the model. Defaults to 2048. Can be set to
+                None to use the full context length of the model. Unless a larger context length is needed, it is
+                recommended to use the default value, as longer context lengths are highly memory intensive.
             dtype: What data type to load the model in (also sets the dtype of the
                 HuggingFace model). Set to bfloat16 or float16 if you get out of memory
                 errors when loading the model.
@@ -1260,6 +1264,7 @@ def from_pretrained(
             device=device,
             n_devices=n_devices,
             default_prepend_bos=default_prepend_bos,
+            max_context_length=max_context_length,
             dtype=dtype,
             **from_pretrained_kwargs,
         )
@@ -1326,6 +1331,7 @@ def from_pretrained_no_processing(
         dtype=torch.float32,
         default_prepend_bos=True,
         default_padding_side="right",
+        max_context_length=None,
         **from_pretrained_kwargs,
     ):
         """Wrapper for from_pretrained.
@@ -1343,6 +1349,7 @@ def from_pretrained_no_processing(
             dtype=dtype,
             default_prepend_bos=default_prepend_bos,
             default_padding_side=default_padding_side,
+            max_context_length=max_context_length,
             **from_pretrained_kwargs,
         )

diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index cb81276c3..0da6b3b87 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -869,7 +869,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
         "n_heads": hf_config.num_attention_heads,
         "d_mlp": hf_config.intermediate_size // 2,
         "n_layers": hf_config.num_hidden_layers,
-        "n_ctx": 2048,  # Capped bc the actual ctx length is 30k and the attn mask would be too big
+        "n_ctx": hf_config.seq_length,
         "eps": hf_config.layer_norm_epsilon,
         "d_vocab": hf_config.vocab_size,
         "act_fn": "silu",
@@ -942,6 +942,7 @@ def get_pretrained_model_config(
     device: Optional[str] = None,
     n_devices: int = 1,
     default_prepend_bos: bool = True,
+    max_context_length: Optional[int] = None,
     dtype: torch.dtype = torch.float32,
     **kwargs,
 ):
@@ -974,6 +975,8 @@ def get_pretrained_model_config(
             so this empirically seems to give better results. To change the default behavior to False,
             pass in default_prepend_bos=False. Note that you can also locally override the default behavior
             by passing in prepend_bos=True/False when you call a method that processes the input string.
+        max_context_length (int, optional): The maximum value of :attr:`transformer_lens.HookedTransformer.n_ctx`
+            to use for the model. If None, `n_ctx` is left unchanged. Defaults to None.
         dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
         kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
             Also given to other HuggingFace functions when compatible.
@@ -1016,6 +1019,10 @@ def get_pretrained_model_config(
 
     cfg_dict["dtype"] = dtype
 
+    # If max_context_length is specified, use it to cap n_ctx
+    if max_context_length and "n_ctx" in cfg_dict:
+        cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], max_context_length)
+
     if fold_ln:
         if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
             cfg_dict["normalization_type"] = "LNPre"
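
A minimal usage sketch of the new argument, assuming this branch is installed; GPT-2 (native n_ctx of 1024) is used purely as an illustrative model:

    from transformer_lens import HookedTransformer

    # Cap the attention context at 512 tokens to keep the attention mask small.
    model = HookedTransformer.from_pretrained("gpt2", max_context_length=512)
    print(model.cfg.n_ctx)  # 512, i.e. min(1024, 512) applied in get_pretrained_model_config

    # max_context_length=None leaves n_ctx at the model's full context length.
    model_full = HookedTransformer.from_pretrained("gpt2", max_context_length=None)
    print(model_full.cfg.n_ctx)  # 1024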