From c10e196f338e0268c68ec5720734979f40cfb3bd Mon Sep 17 00:00:00 2001 From: collin Date: Tue, 23 Jan 2024 16:13:50 -0800 Subject: [PATCH 1/3] add max_context_length --- transformer_lens/HookedTransformer.py | 5 +++++ transformer_lens/loading_from_pretrained.py | 9 ++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 5cce47eec..a78c5ea44 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1089,6 +1089,7 @@ def from_pretrained( fold_value_biases: Optional[bool] = True, default_prepend_bos: Optional[bool] = True, default_padding_side: Optional[Literal["left", "right"]] = "right", + max_context_length: Optional[int] = 2048, dtype="float32", **from_pretrained_kwargs, ) -> "HookedTransformer": @@ -1219,6 +1220,9 @@ def from_pretrained( other HuggingFace functions when compatible. For some models or arguments it doesn't work, especially for models that are not internally loaded with HuggingFace's from_pretrained (e.g. SoLU models). + max_context_length: The maximum context length to use for the model. Defaults to 2048. Can be set to + None to use the full context length of the model. Unless a larger context length is needed, it is + recommended to use the default value, as longer context lengths are highly memory intensive. dtype: What data type to load the model in (also sets the dtype of the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading the model. @@ -1260,6 +1264,7 @@ def from_pretrained( device=device, n_devices=n_devices, default_prepend_bos=default_prepend_bos, + max_context_length=max_context_length, dtype=dtype, **from_pretrained_kwargs, ) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index cb81276c3..a9897e77b 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -869,7 +869,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_heads": hf_config.num_attention_heads, "d_mlp": hf_config.intermediate_size // 2, "n_layers": hf_config.num_hidden_layers, - "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big + "n_ctx": hf_config.seq_length, "eps": hf_config.layer_norm_epsilon, "d_vocab": hf_config.vocab_size, "act_fn": "silu", @@ -942,6 +942,7 @@ def get_pretrained_model_config( device: Optional[str] = None, n_devices: int = 1, default_prepend_bos: bool = True, + max_context_length: Optional[int] = None, dtype: torch.dtype = torch.float32, **kwargs, ): @@ -974,6 +975,8 @@ def get_pretrained_model_config( so this empirically seems to give better results. To change the default behavior to False, pass in default_prepend_bos=False. Note that you can also locally override the default behavior by passing in prepend_bos=True/False when you call a method that processes the input string. + max_context_length (int, optional): The maximum value of `n_ctx` to use for the model. + If None, `n_ctx` is left uncapped. Defaults to None. dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. kwargs: Other optional arguments passed to HuggingFace's from_pretrained. Also given to other HuggingFace functions when compatible. @@ -1016,6 +1019,10 @@ def get_pretrained_model_config( cfg_dict["dtype"] = dtype + # If max_context_length is specified, use it to cap n_ctx + if max_context_length and "n_ctx" in cfg_dict: + cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], max_context_length) + if fold_ln: if cfg_dict["normalization_type"] in ["LN", "LNPre"]: cfg_dict["normalization_type"] = "LNPre" From 001ba5e8ce637f16b888e8252ffc0aa488c34472 Mon Sep 17 00:00:00 2001 From: collin Date: Tue, 23 Jan 2024 16:25:47 -0800 Subject: [PATCH 2/3] add docstring ref --- transformer_lens/loading_from_pretrained.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index a9897e77b..0da6b3b87 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -975,8 +975,8 @@ def get_pretrained_model_config( so this empirically seems to give better results. To change the default behavior to False, pass in default_prepend_bos=False. Note that you can also locally override the default behavior by passing in prepend_bos=True/False when you call a method that processes the input string. - max_context_length (int, optional): The maximum value of `n_ctx` to use for the model. - If None, `n_ctx` is left uncapped. Defaults to None. + max_context_length (int, optional): The maximum value of :attr:`transformer_lens.HookedTransformer.n_ctx` to use + for the model. If None, `n_ctx` is left unchanged. Defaults to None. dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. kwargs: Other optional arguments passed to HuggingFace's from_pretrained. Also given to other HuggingFace functions when compatible. From 4a5ceab114a7224b13e394d1a7666237e98d770f Mon Sep 17 00:00:00 2001 From: collin Date: Tue, 23 Jan 2024 16:40:10 -0800 Subject: [PATCH 3/3] add param to from_pre_no_processing --- transformer_lens/HookedTransformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index a78c5ea44..3ce5e9e25 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1331,6 +1331,7 @@ def from_pretrained_no_processing( dtype=torch.float32, default_prepend_bos=True, default_padding_side="right", + max_context_length=None, **from_pretrained_kwargs, ): """Wrapper for from_pretrained. @@ -1348,6 +1349,7 @@ def from_pretrained_no_processing( dtype=dtype, default_prepend_bos=default_prepend_bos, default_padding_side=default_padding_side, + max_context_length=max_context_length, **from_pretrained_kwargs, )