From 26850fb8eaee9fcc5c4543521e709f7b06a10da8 Mon Sep 17 00:00:00 2001 From: collin Date: Wed, 24 Jan 2024 16:56:30 -0800 Subject: [PATCH 1/3] cap n_ctx --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index cb81276c3..bdf71a5c4 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -815,7 +815,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_heads": 32, "d_mlp": 14336, "n_layers": 32, - "n_ctx": 32768, + "n_ctx": 2048, "d_vocab": 32000, "act_fn": "silu", "normalization_type": "RMS", From 33e5e650621207c6f7461773e0c07565f02810ca Mon Sep 17 00:00:00 2001 From: collin Date: Wed, 24 Jan 2024 17:07:25 -0800 Subject: [PATCH 2/3] add comment --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index bdf71a5c4..93274d423 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -815,7 +815,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_heads": 32, "d_mlp": 14336, "n_layers": 32, - "n_ctx": 2048, + "n_ctx": 2048, # Capped due to memory issues "d_vocab": 32000, "act_fn": "silu", "normalization_type": "RMS", From f60296793ea440710cddf473ec1a2aba8b1a6434 Mon Sep 17 00:00:00 2001 From: collin Date: Wed, 24 Jan 2024 17:54:44 -0800 Subject: [PATCH 3/3] formatting --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 93274d423..9506beeb6 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -815,7 +815,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_heads": 32, "d_mlp": 14336, "n_layers": 32, - "n_ctx": 2048, # Capped due to memory issues + "n_ctx": 2048, # Capped due to memory issues "d_vocab": 32000, "act_fn": "silu", "normalization_type": "RMS",