7 changes: 7 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -1089,6 +1089,7 @@ def from_pretrained(
fold_value_biases: Optional[bool] = True,
default_prepend_bos: Optional[bool] = True,
default_padding_side: Optional[Literal["left", "right"]] = "right",
max_context_length: Optional[int] = 2048,
dtype="float32",
**from_pretrained_kwargs,
) -> "HookedTransformer":
@@ -1219,6 +1220,9 @@ def from_pretrained(
other HuggingFace functions when compatible. For some models or arguments it doesn't
work, especially for models that are not internally loaded with HuggingFace's
from_pretrained (e.g. SoLU models).
max_context_length: The maximum context length to use for the model. Defaults to 2048. Can be set to
None to use the full context length of the model. Unless a larger context length is needed, it is
recommended to use the default value, as longer context lengths are highly memory intensive.

Collaborator

This is a complex parameter, so I think the docstring should be more detailed. I would rephrase it to something like:

This allows us to override the TransformerLens default max context length for the model. TransformerLens may default to a smaller max context length than a model was trained with; you can use this parameter to restore a higher max if needed for your use case. Note that this will only work with positional encodings, like rotary and sinusoidal, that support multiple max context lengths; it will not work with absolute positional encodings. If you set the max context to be larger than what the model was trained with, it is unlikely that the model can make full use of the extra context.
For example, Mistral 7B was trained with a 32K context length, but TransformerLens defaults to 2K, because each attention layer has an n_ctx x n_ctx attention mask attached, which can get very memory intensive. If you need the full context, you can set this to a higher value.
dtype: What data type to load the model in (also sets the dtype of
the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
the model.
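A minimal usage sketch of the parameter this hunk adds, not part of the diff itself; the model name is an assumption for illustration, and whether None restores the full trained context depends on the per-model entry in convert_hf_model_config exposing it (as the Qwen entry does later in this diff):

from transformer_lens import HookedTransformer

# Default: n_ctx is capped at 2048, keeping the per-layer n_ctx x n_ctx attention masks small.
model = HookedTransformer.from_pretrained("Qwen/Qwen-7B")

# Pass None to leave n_ctx at whatever the parsed config reports (memory intensive),
# or an explicit value to cap it somewhere in between.
model_full = HookedTransformer.from_pretrained("Qwen/Qwen-7B", max_context_length=None)
model_8k = HookedTransformer.from_pretrained("Qwen/Qwen-7B", max_context_length=8192)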
@@ -1260,6 +1264,7 @@ def from_pretrained(
device=device,
n_devices=n_devices,
default_prepend_bos=default_prepend_bos,
max_context_length=max_context_length,
dtype=dtype,
**from_pretrained_kwargs,
)
@@ -1326,6 +1331,7 @@ def from_pretrained_no_processing(
dtype=torch.float32,
default_prepend_bos=True,
default_padding_side="right",
max_context_length=None,
**from_pretrained_kwargs,
):
"""Wrapper for from_pretrained.
@@ -1343,6 +1349,7 @@ def from_pretrained_no_processing(
dtype=dtype,
default_prepend_bos=default_prepend_bos,
default_padding_side=default_padding_side,
max_context_length=max_context_length,
**from_pretrained_kwargs,
)

9 changes: 8 additions & 1 deletion transformer_lens/loading_from_pretrained.py
@@ -869,7 +869,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size // 2,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
"n_ctx": hf_config.seq_length,
Collaborator

Why did you change this?

Collaborator

And why didn't you change the Mistral entry?

Contributor Author

My thinking was to separate the parsing of the configs from the filtering of them. It's functionally the same; the filter is just applied at a different stage.

Collaborator

Hmm. I lean towards doing it here? My reasoning is that, if someone wants to understand the Qwen config, they can just come here and read this block of code. If there's later code that adjusts it, that's hard to notice, and may be actively misleading, because they'll see 32K here and expect that to end up in the model.

A counter-argument is that adding an "if n_ctx > 4K and override_max_ctx is None: n_ctx=4K" statement can be done once and work for all present and future models, rather than needing people to notice each time (a rough sketch of that follows this thread). But on net I still prefer all the "understanding Qwen config" code to be in one place.

Contributor Author

That's a fair point; another option would be to add a comment mentioning that it may be overridden by get_pretrained_model_config.

I'm fine with either, let me know which you prefer.

Contributor Author (@collingray, Jan 24, 2024)

Oh I forgot the other reason I went with that approach, which is that someone using the package not from source would have to directly modify the config in order to change n_ctx.

Though this doesn't really matter if you can also extend n_ctx with the parameter.
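For concreteness, a rough sketch of the one-time cap floated above, applied once in get_pretrained_model_config rather than per model; the 4K threshold and the override_max_ctx name come from the comment and are illustrative, not code in this PR:

# Illustrative only: a single cap applied to every parsed config.
DEFAULT_N_CTX_CAP = 4096      # assumed default cap (the "4K" in the comment)
override_max_ctx = None       # hypothetical user override
cfg_dict = {"n_ctx": 32768}   # e.g. a large seq_length parsed from an HF config

if cfg_dict["n_ctx"] > DEFAULT_N_CTX_CAP and override_max_ctx is None:
    cfg_dict["n_ctx"] = DEFAULT_N_CTX_CAP  # cap unless the user explicitly overrides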

"eps": hf_config.layer_norm_epsilon,
"d_vocab": hf_config.vocab_size,
"act_fn": "silu",
@@ -942,6 +942,7 @@ def get_pretrained_model_config(
device: Optional[str] = None,
n_devices: int = 1,
default_prepend_bos: bool = True,
max_context_length: Optional[int] = None,
Contributor Author (@collingray, Jan 24, 2024)

I'm not positive that this is the correct default to use here.

My thinking is that "get_pretrained_model_config" implies that it will return the config mostly unaltered.

Collaborator

I would call it override_default_context_length, maybe? That has the problem of sounding like a boolean, but combined with the type being Optional[int], I think it's pretty obvious what's going on.

Contributor Author

My reasoning for max vs. override was that max makes it clear that the context length will only be changed if it's too big, whereas there is some ambiguity about whether override would always set it to a fixed value.

Contributor Author

Perhaps override_max_context_length would actually be best.

Collaborator

override_max_context_length sounds good to me

dtype: torch.dtype = torch.float32,
**kwargs,
):
@@ -974,6 +975,8 @@ def get_pretrained_model_config(
so this empirically seems to give better results. To change the default behavior to False, pass in
default_prepend_bos=False. Note that you can also locally override the default behavior by passing
in prepend_bos=True/False when you call a method that processes the input string.
max_context_length (int, optional): The maximum value of :attr:`transformer_lens.HookedTransformer.n_ctx` to use
for the model. If None, `n_ctx` is left unchanged. Defaults to None.
dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
Also given to other HuggingFace functions when compatible.
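A small sketch of the documented behaviour at the config level; the model name is an assumption for illustration:

from transformer_lens.loading_from_pretrained import get_pretrained_model_config

cfg = get_pretrained_model_config("Qwen/Qwen-7B", max_context_length=2048)
# n_ctx ends up as min(value parsed from the HF config, 2048); with None it is left unchanged.
print(cfg.n_ctx)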
@@ -1016,6 +1019,10 @@ def get_pretrained_model_config(

cfg_dict["dtype"] = dtype

# If max_context_length is specified, use it to cap n_ctx
if max_context_length and "n_ctx" in cfg_dict:
cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], max_context_length)
Collaborator

IMO if the user sets max_context_length higher than n_ctx, we should allow it to increase n_ctx

Contributor Author (@collingray, Jan 24, 2024)

Would this be possible, assuming that n_ctx is already the maximum context length of the model?

Edit: oh, I see what you're saying; yeah, I can add this.

Collaborator

Yeah, rotary models allow you to increase n_ctx beyond what a model was trained with; it may just go completely off the rails. Generally what happens, I think, is that it works fine but doesn't get any better at predicting the 64Kth token than the 32Kth token (though it is better at predicting the 32Kth than the 16Kth).

Contributor Author (@collingray, Jan 24, 2024)

It may actually make sense to split this into multiple parameters (perhaps override_max_context_length and extend_context_length) in order to still use a default maximum n_ctx; if it always sets n_ctx directly to the provided value, then smaller models will get raised up to that n_ctx without this being clear to the user.

Collaborator

I don't think it's worth having two parameters for such a subtle distinction. IMO if you use this parameter you should know what you're doing. Maybe have it output a warning?

Contributor Author

I had set the default to 2048 in HookedTransformer.from_pretrained, the thinking being that people using larger contexts likely know the library better, whereas new users may get tripped up trying to use Mistral and having to change a flag to get it to load.

A warning for when n_ctx is being extended is a good idea, though.
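One way the behaviour converged on here could look inside get_pretrained_model_config; this is a sketch of the discussion, not the code in the diff, and it also illustrates the smaller-model concern raised above:

import logging

max_context_length = 2048    # hypothetical value passed by the caller
cfg_dict = {"n_ctx": 1024}   # e.g. a model trained with a 1K context

if max_context_length is not None and "n_ctx" in cfg_dict:
    if max_context_length > cfg_dict["n_ctx"]:
        # Warn when extending past the parsed context length; only rotary/sinusoidal
        # positional embeddings can use positions beyond what the model was trained on.
        logging.warning(
            "Raising n_ctx from %d to %d", cfg_dict["n_ctx"], max_context_length
        )
    cfg_dict["n_ctx"] = max_context_length
# With a non-None default (2048 in from_pretrained), a 1K model would end up with
# n_ctx=2048 here, which is exactly the subtlety the warning is meant to surface.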


if fold_ln:
if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
cfg_dict["normalization_type"] = "LNPre"