diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index cb81276c3..9506beeb6 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -815,7 +815,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_heads": 32, "d_mlp": 14336, "n_layers": 32, - "n_ctx": 32768, + "n_ctx": 2048, # Capped from 32768 due to memory issues "d_vocab": 32000, "act_fn": "silu", "normalization_type": "RMS",