Add max_context_length parameter to HookedTransformer #491
```diff
@@ -869,7 +869,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "n_heads": hf_config.num_attention_heads,
             "d_mlp": hf_config.intermediate_size // 2,
             "n_layers": hf_config.num_hidden_layers,
-            "n_ctx": 2048,  # Capped bc the actual ctx length is 30k and the attn mask would be too big
+            "n_ctx": hf_config.seq_length,
```
Collaborator: Why did you change this?
Collaborator: And why didn't you change the Mistral entry?
Contributor (Author): My thinking was to separate the parsing of the configs from the filtering of them. It's functionally the same; the filter is just applied at a different stage.
Collaborator: Hmm. I lean towards doing it here? My reasoning is that, if someone wants to understand the Qwen config, they can just come here and read this block of code. If there's later code that adjusts it, that's hard to notice, and may be actively misleading, because they'll see 32K here and expect that to end up in the model. A counter-argument is that adding an "if n_ctx > 4K and override_max_ctx is None: n_ctx = 4K" statement can be done once and work for all present and future models, rather than needing people to notice each time. But on net I still prefer all the "understanding Qwen config" code to be in one place.
Contributor (Author): That's a fair point. Another option would be to add a comment mentioning that it may be overridden by get_pretrained_model_config. I'm fine with either, let me know which you prefer.
Contributor (Author): Oh, I forgot the other reason I went with that approach, which is that someone using the package not from source would have to directly modify the config in order to change n_ctx. Though this doesn't really matter if you can also extend n_ctx with the parameter.
```diff
             "eps": hf_config.layer_norm_epsilon,
             "d_vocab": hf_config.vocab_size,
             "act_fn": "silu",
```
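For reference on why the removed hard-coded cap existed at all: a dense causal attention mask grows quadratically with n_ctx, so a ~32K context is orders of magnitude heavier than a 2K one. A rough back-of-the-envelope calculation (sizes illustrative, not measured from TransformerLens):

```python
def dense_mask_bytes(n_ctx: int, bytes_per_element: int = 1) -> int:
    # A dense (n_ctx, n_ctx) causal mask, e.g. stored as bool (1 byte per element).
    return n_ctx * n_ctx * bytes_per_element

print(f"{dense_mask_bytes(2048) / 1e6:.0f} MB")   # ~4 MB at a 2K context
print(f"{dense_mask_bytes(32768) / 1e9:.1f} GB")  # ~1.1 GB at 32K, before any float or per-head copies
```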
```diff
@@ -942,6 +942,7 @@ def get_pretrained_model_config(
     device: Optional[str] = None,
     n_devices: int = 1,
     default_prepend_bos: bool = True,
+    max_context_length: Optional[int] = None,
```
Contributor (Author): I'm not positive that this is the correct default to use here. My thinking is that "get_pretrained_model_config" implies that it will return the config mostly unaltered.
Collaborator: I would call it
Contributor (Author): My reasoning for max vs. override was that max makes it clear that the context length will only be changed if it's too big, whereas there is some ambiguity about whether override would always set it to a fixed value.
Contributor (Author): perhaps
```diff
     dtype: torch.dtype = torch.float32,
     **kwargs,
 ):
```
```diff
@@ -974,6 +975,8 @@ def get_pretrained_model_config(
             so this empirically seems to give better results. To change the default behavior to False, pass in
             default_prepend_bos=False. Note that you can also locally override the default behavior by passing
             in prepend_bos=True/False when you call a method that processes the input string.
+        max_context_length (int, optional): The maximum value of :attr:`transformer_lens.HookedTransformer.n_ctx` to use
+            for the model. If None, `n_ctx` is left unchanged. Defaults to None.
         dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
         kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
             Also given to other HuggingFace functions when compatible.
```
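For illustration, how the documented parameter might be used at the config level (the model name is illustrative, and the attribute access assumes the usual HookedTransformerConfig return value):

```python
from transformer_lens import loading_from_pretrained as loading

# Cap the context length while keeping the rest of the pretrained config intact.
cfg = loading.get_pretrained_model_config("Qwen/Qwen-7B", max_context_length=2048)
print(cfg.n_ctx)  # min(original n_ctx, 2048) under this PR's "max" semantics
```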
```diff
@@ -1016,6 +1019,10 @@ def get_pretrained_model_config(
 
     cfg_dict["dtype"] = dtype
 
+    # If max_context_length is specified, use it to cap n_ctx
+    if max_context_length and "n_ctx" in cfg_dict:
+        cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], max_context_length)
+
```
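And at the top level, a usage sketch assuming the new parameter is threaded through HookedTransformer.from_pretrained as the PR title suggests (the model name is illustrative):

```python
from transformer_lens import HookedTransformer

# Without a cap, a long-context model's buffers are sized to the full pretrained
# context; capping keeps memory use manageable for interpretability work.
model = HookedTransformer.from_pretrained("mistral-7b", max_context_length=2048)
print(model.cfg.n_ctx)  # capped at 2048
```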
Collaborator: IMO if the user sets max_context_length higher than n_ctx, we should allow it to increase n_ctx.
Contributor (Author): Would this be possible, assuming that n_ctx is already the maximum context length of the model? Edit: oh, I see what you're saying, yeah, I can add this.
Collaborator: Yeah, rotary models allow you to increase n_ctx beyond what a model was trained with; it may just go completely off the rails. Generally what happens, I think, is that it works fine, but doesn't get any better at predicting the 64Kth token than the 32Kth token (though it is better at predicting the 32Kth than the 16Kth).
Contributor (Author): It may actually make sense to split it into multiple parameters for this (perhaps
Collaborator: I don't think it's worth having two parameters for such a subtle distinction. IMO if you use this parameter you should know what you're doing. Maybe have it output a warning?
Contributor (Author): I had set the default for it at 2048 in HookedTransformer.from_pretrained, the thinking being that people using larger contexts likely know the library better, but new users may get tripped up trying to use Mistral and having to change a flag to get it to load. A warning for when n_ctx is being extended is a good idea, though.
```diff
 
     if fold_ln:
         if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
             cfg_dict["normalization_type"] = "LNPre"
```
This is a complex parameter, so I think the docstring should be more detailed. I would rephrase the comment to something like: