From e5cd8507ec7fa5630af47c5b69d9ed2dec2c15b3 Mon Sep 17 00:00:00 2001
From: Andy Arditi
Date: Wed, 24 Jan 2024 20:23:17 +0000
Subject: [PATCH 1/2] construct causal mask on-the-fly

---
 tests/unit/test_grouped_query_attention.py  |  6 +--
 transformer_lens/components.py              | 41 +++++++++++++--------
 transformer_lens/loading_from_pretrained.py |  2 +-
 3 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/tests/unit/test_grouped_query_attention.py b/tests/unit/test_grouped_query_attention.py
index 885ec39a0..7541b1ddb 100644
--- a/tests/unit/test_grouped_query_attention.py
+++ b/tests/unit/test_grouped_query_attention.py
@@ -51,10 +51,9 @@ def test_grouped_query_attention_output_is_correct():
         "b_K": b_K,
         "W_V": W_V,
         "b_V": b_V,
-        "mask": regular_attention.state_dict()["mask"],
         "IGNORE": regular_attention.state_dict()["IGNORE"],
     }
-    grouped_query_attemtion_state_dict = {
+    grouped_query_attention_state_dict = {
         "W_Q": W_Q,
         "b_Q": b_Q,
         "W_O": W_O,
@@ -63,12 +62,11 @@
         "_b_K": _b_K,
         "_W_V": _W_V,
         "_b_V": _b_V,
-        "mask": grouped_query_attention.state_dict()["mask"],
         "IGNORE": grouped_query_attention.state_dict()["IGNORE"],
     }
 
     regular_attention.load_state_dict(regular_attention_state_dict)
-    grouped_query_attention.load_state_dict(grouped_query_attemtion_state_dict)
+    grouped_query_attention.load_state_dict(grouped_query_attention_state_dict)
 
     query_input = torch.rand((1, 5, d_model))
     key_input = torch.rand((1, 5, d_model))
diff --git a/transformer_lens/components.py b/transformer_lens/components.py
index 942ec2819..200adc9c5 100644
--- a/transformer_lens/components.py
+++ b/transformer_lens/components.py
@@ -424,19 +424,8 @@ def __init__(
         self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype))
 
         self.attn_type = attn_type
-        # Create a max_ctx x max_ctx mask, with True iff that query position
-        # can attend to that key position (query is first axis, key is second axis)
-        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
-        if self.attn_type == "global":
-            # For global attention, this is a lower triangular matrix - key <= query
-            self.register_buffer("mask", causal_mask)
-        elif self.attn_type == "local":
-            # For local, this is banded, query - window_size < key <= query
-            assert isinstance(self.cfg.window_size, int)
-            self.register_buffer(
-                "mask", torch.triu(causal_mask, 1 - self.cfg.window_size)
-            )
-        else:
+
+        if self.attn_type not in ["global", "local"]:
             raise ValueError(f"Invalid attention type: {self.attn_type}")
 
         self.register_buffer("IGNORE", torch.tensor(-torch.inf))
@@ -723,10 +712,32 @@ def apply_causal_mask(
             query_ctx_length + past_kv_pos_offset == key_ctx_length
         ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
 
+        full_ctx_length = key_ctx_length
+
+        # Construct causal mask on-the-fly
+        if self.attn_type == "global":
+            # For global attention, this is a lower triangular matrix - key <= query
+            causal_mask = torch.tril(
+                torch.ones(
+                    (full_ctx_length, full_ctx_length), device=attn_scores.device
+                ).bool()
+            )
+        elif self.attn_type == "local":
+            # For local, this is banded, query - window_size < key <= query
+            assert isinstance(self.cfg.window_size, int)
+            causal_mask = torch.tril(
+                torch.ones(
+                    (full_ctx_length, full_ctx_length), device=attn_scores.device
+                ).bool()
+            )
+            causal_mask = torch.triu(causal_mask, 1 - self.cfg.window_size)
+        else:
+            raise ValueError(f"Invalid attention type: {self.attn_type}")
+
         # Index back to front to ensure local attention works
-        final_mask = self.mask[
+        final_mask = causal_mask[
             None, None, -query_ctx_length:, -key_ctx_length:
-        ]  # [1, 1, pos, pos]
+        ]  # [1, 1, query_ctx_length, key_ctx_length]
         if attention_mask is not None:
             # Apply a causal mask to the attention scores considering the padding
             einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index cb81276c3..b013ef9b4 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -869,7 +869,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "n_heads": hf_config.num_attention_heads,
             "d_mlp": hf_config.intermediate_size // 2,
             "n_layers": hf_config.num_hidden_layers,
-            "n_ctx": 2048,  # Capped bc the actual ctx length is 30k and the attn mask would be too big
+            "n_ctx": hf_config.seq_length,
             "eps": hf_config.layer_norm_epsilon,
             "d_vocab": hf_config.vocab_size,
             "act_fn": "silu",

From d7291910183163568fdb9fe2501a784ecfb50661 Mon Sep 17 00:00:00 2001
From: Bryce Meyer
Date: Fri, 12 Apr 2024 00:53:13 +0200
Subject: [PATCH 2/2] ran formatting

---
 transformer_lens/components.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/transformer_lens/components.py b/transformer_lens/components.py
index d121faa33..73e17886b 100644
--- a/transformer_lens/components.py
+++ b/transformer_lens/components.py
@@ -694,24 +694,22 @@ def apply_causal_mask(
         if self.attn_type == "global":
             # For global attention, this is a lower triangular matrix - key <= query
             causal_mask = torch.tril(
-                torch.ones(
-                    (full_ctx_length, full_ctx_length), device=attn_scores.device
-                ).bool()
+                torch.ones((full_ctx_length, full_ctx_length), device=attn_scores.device).bool()
             )
         elif self.attn_type == "local":
             # For local, this is banded, query - window_size < key <= query
             assert isinstance(self.cfg.window_size, int)
             causal_mask = torch.tril(
-                torch.ones(
-                    (full_ctx_length, full_ctx_length), device=attn_scores.device
-                ).bool()
+                torch.ones((full_ctx_length, full_ctx_length), device=attn_scores.device).bool()
             )
             causal_mask = torch.triu(causal_mask, 1 - self.cfg.window_size)
         else:
             raise ValueError(f"Invalid attention type: {self.attn_type}")
 
         # Index back to front to ensure local attention works
-        final_mask = causal_mask[None, None, -query_ctx_length:, -key_ctx_length:]  # [1, 1, query_ctx_length, key_ctx_length
+        final_mask = causal_mask[
+            None, None, -query_ctx_length:, -key_ctx_length:
+        ]  # [1, 1, query_ctx_length, key_ctx_length
         if attention_mask is not None:
             # Apply a causal mask to the attention scores considering the padding
             einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
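Reviewer note (not part of the patch series above): the core idea of PATCH 1 is to build the causal mask from the runtime key-context length instead of registering an n_ctx x n_ctx buffer at init. The following is a minimal standalone sketch of that construction for sanity-checking outside the library; the helper name make_causal_mask and its parameters are made up for illustration and are not TransformerLens API.

from typing import Optional

import torch


def make_causal_mask(
    full_ctx_length: int,
    attn_type: str = "global",
    window_size: Optional[int] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # Lower-triangular boolean mask: True where key position <= query position.
    causal_mask = torch.tril(
        torch.ones((full_ctx_length, full_ctx_length), device=device).bool()
    )
    if attn_type == "local":
        # Banded mask: query - window_size < key <= query.
        assert isinstance(window_size, int)
        causal_mask = torch.triu(causal_mask, 1 - window_size)
    elif attn_type != "global":
        raise ValueError(f"Invalid attention type: {attn_type}")
    return causal_mask


# Example: 5 key positions, 3 query positions (as when reading from a kv-cache),
# local attention with a window of 2. Indexing back to front mirrors the patch's
# final_mask slicing.
mask = make_causal_mask(5, attn_type="local", window_size=2)
final_mask = mask[None, None, -3:, -5:]  # [1, 1, query_ctx_length, key_ctx_length]
print(final_mask.shape)  # torch.Size([1, 1, 3, 5])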