diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index 7071f100f..ae5a97032 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -26,7 +26,7 @@
 from rich import print as rprint
 from transformers import AutoTokenizer
 
-from transformer_lens.FactoredMatrix import FactoredMatrix
+from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformer
 
 CACHE_DIR = transformers.TRANSFORMERS_CACHE
 USE_DEFAULT_VALUE = None
@@ -1202,6 +1202,112 @@ def get_tokens_with_bos_removed(tokenizer, tokens):
     return tokens[tokens != -100].view(*bos_removed_shape)
 
 
+def DLA(
+    model: HookedTransformer,
+    prompts: List[str],
+    answer_tokens: Int[torch.Tensor, "batch answers"],
+    accumulated: bool = False,
+) -> Tuple[Float[torch.Tensor, "component"], List[str]]:
+    """Calculate the direct logit attribution (DLA), either accumulated or per layer, for a given list of prompts and answer tokens.
+
+    Args:
+        model (HookedTransformer): model to run the prompts through
+        prompts (List[str]): list of prompts
+        answer_tokens (Int[torch.Tensor, "batch answers"]): per prompt, either a single answer token or a (correct, incorrect) pair of tokens
+        accumulated (bool): whether to return the accumulated DLA or the per-layer DLA
+
+    Returns:
+        Float[torch.Tensor, "component"]: DLA per component (per layer, or accumulated along the residual stream if `accumulated` is True)
+        List[str]: label for each component
+    """
+    assert len(prompts) == answer_tokens.shape[0]
+    assert answer_tokens.shape[1] == 1 or answer_tokens.shape[1] == 2
+    answer_residual_directions: Float[
+        torch.Tensor, "batch answers d_model"
+    ] = model.tokens_to_residual_directions(answer_tokens)
+
+    if (
+        answer_tokens.numel() == 1
+    ):  # special case as tokens_to_residual_directions returns Float[Tensor, "d_model"]
+        logit_diff_directions: Float[torch.Tensor, "batch d_model"] = torch.unsqueeze(
+            answer_residual_directions, dim=0
+        )
+    elif answer_residual_directions.shape[1] == 1:
+        logit_diff_directions: Float[
+            torch.Tensor, "batch d_model"
+        ] = answer_residual_directions[:, 0, :]
+    else:
+        (
+            correct_residual_directions,
+            incorrect_residual_directions,
+        ) = answer_residual_directions.unbind(dim=1)
+        logit_diff_directions: Float[torch.Tensor, "batch d_model"] = (
+            correct_residual_directions - incorrect_residual_directions
+        )
+
+    def residual_stack_to_logit_diff(
+        residual_stack: Float[torch.Tensor, "... batch d_model"],
+        cache: ActivationCache,
+        logit_diff_directions: Float[torch.Tensor, "batch d_model"],
+    ) -> Float[torch.Tensor, "..."]:
+        batch_size = residual_stack.size(-2)
+        scaled_residual_stack = cache.apply_ln_to_stack(
+            residual_stack, layer=-1, pos_slice=-1
+        )
+        return (
+            einops.einsum(
+                scaled_residual_stack,
+                logit_diff_directions,
+                "... batch d_model, batch d_model -> ...",
+            )
+            / batch_size
+        )
+
+    if accumulated:
+        n_layers = model.cfg.n_layers
+        _, cache = model.run_with_cache(
+            prompts,
+            return_type=None,
+            names_filter=lambda x: x == get_act_name("resid_post", n_layers - 1)
+            or x == get_act_name("ln_final.hook_scale")
+            or x.endswith("resid_pre")
+            or x.endswith("resid_mid"),
+        )
+
+        accumulated_residual, labels = cache.accumulated_resid(
+            layer=-1, pos_slice=-1, incl_mid=True, return_labels=True
+        )
+
+        logit_lens_logit_diffs: Float[
+            torch.Tensor, "component"
+        ] = residual_stack_to_logit_diff(
+            accumulated_residual, cache, logit_diff_directions
+        )
+
+        return logit_lens_logit_diffs, labels
+
+    else:
+        _, cache = model.run_with_cache(
+            prompts,
+            return_type=None,
+            names_filter=lambda x: x == get_act_name("ln_final.hook_scale")
+            or x.endswith("embed")
+            or x.endswith("attn_out")
+            or x.endswith("mlp_out"),
+        )
+
+        per_layer_residual, labels = cache.decompose_resid(
+            layer=-1, pos_slice=-1, return_labels=True
+        )
+        per_layer_logit_diffs: Float[
+            torch.Tensor, "component"
+        ] = residual_stack_to_logit_diff(
+            per_layer_residual, cache, logit_diff_directions
+        )
+
+        return per_layer_logit_diffs, labels
+
+
 try:
     import pytest