diff --git a/tests/integration/test_grouped_query_attention.py b/tests/integration/test_grouped_query_attention.py
index e5e603454..0a46fc2c4 100644
--- a/tests/integration/test_grouped_query_attention.py
+++ b/tests/integration/test_grouped_query_attention.py
@@ -52,10 +52,9 @@ def test_grouped_query_attention_output_is_correct():
         "b_K": b_K,
         "W_V": W_V,
         "b_V": b_V,
-        "mask": regular_attention.state_dict()["mask"],
         "IGNORE": regular_attention.state_dict()["IGNORE"],
     }
-    grouped_query_attemtion_state_dict = {
+    grouped_query_attention_state_dict = {
         "W_Q": W_Q,
         "b_Q": b_Q,
         "W_O": W_O,
@@ -64,12 +63,11 @@ def test_grouped_query_attention_output_is_correct():
         "_b_K": _b_K,
         "_W_V": _W_V,
         "_b_V": _b_V,
-        "mask": grouped_query_attention.state_dict()["mask"],
         "IGNORE": grouped_query_attention.state_dict()["IGNORE"],
     }
 
     regular_attention.load_state_dict(regular_attention_state_dict)
-    grouped_query_attention.load_state_dict(grouped_query_attemtion_state_dict)
+    grouped_query_attention.load_state_dict(grouped_query_attention_state_dict)
 
     query_input = torch.rand((1, 5, d_model))
     key_input = torch.rand((1, 5, d_model))
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 45000efbf..7923a6caa 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -1039,7 +1039,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "n_heads": hf_config.num_attention_heads,
             "d_mlp": hf_config.intermediate_size // 2,
             "n_layers": hf_config.num_hidden_layers,
-            "n_ctx": 2048,  # Capped bc the actual ctx length is 30k and the attn mask would be too big
+            "n_ctx": hf_config.seq_length,
             "eps": hf_config.layer_norm_epsilon,
             "d_vocab": hf_config.vocab_size,
             "act_fn": "silu",
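
For reference, the `n_ctx` change swaps a hard-coded 2048 cap for the context length advertised by the checkpoint's own config. Below is a minimal sketch of the value this now relies on, assuming a Qwen-style Hugging Face config that exposes a `seq_length` field; the model name is illustrative only, and this is not the library's actual code path:

```python
# Minimal sketch, assuming a Qwen-style HF config with a `seq_length`
# field. The model name is illustrative, not taken from the patch.
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)

cfg_dict = {
    # n_ctx now follows the checkpoint's configured context length
    # instead of the previous hard-coded cap of 2048.
    "n_ctx": hf_config.seq_length,
}
print(cfg_dict["n_ctx"])
```

Run as-is, this just prints the configured context length; in the patch above, that same value flows into the `cfg_dict` built by `convert_hf_model_config`.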