diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index bb65f5f5c..7a5001948 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1656,8 +1656,48 @@ def load_and_process_state_dict(
             )
 
         state_dict = self.fill_missing_keys(state_dict)
 
+        if fold_ln:
+            if self.cfg.num_experts and self.cfg.num_experts > 1:
+                logging.warning(
+                    "You are using MoE, so the layer norm weights can't be folded! Skipping"
+                )
+                fold_ln = False
+            elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
+                logging.warning(
+                    "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
+                )
+                fold_ln = False
+            else:
+                ln_keys_present = any(
+                    k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict
+                )
+                if not ln_keys_present:
+                    logging.warning(
+                        "fold_ln=True but no LayerNorm weights found in state_dict. "
+                        "The model may have been saved with already-folded LayerNorms. "
+                        "Skipping fold."
+                    )
+                    fold_ln = False
+                else:
+                    if self.cfg.normalization_type == "LN":
+                        self.cfg.normalization_type = "LNPre"
+                        self.ln_final = LayerNormPre(self.cfg)
+                        for layer in self.blocks:
+                            layer.ln1 = LayerNormPre(self.cfg)
+                            layer.ln2 = LayerNormPre(self.cfg)
+                            if self.cfg.is_layer_norm_activation():
+                                layer.mlp.ln = LayerNormPre(self.cfg)
+                    elif self.cfg.normalization_type == "RMS":
+                        self.cfg.normalization_type = "RMSPre"
+                        self.ln_final = RMSNormPre(self.cfg)
+                        for layer in self.blocks:
+                            layer.ln1 = RMSNormPre(self.cfg)
+                            layer.ln2 = RMSNormPre(self.cfg)
+                            if self.cfg.is_layer_norm_activation():
+                                layer.mlp.ln = RMSNormPre(self.cfg)
         # Use the centralized ProcessWeights class for all weight processing
+        # (fold_ln is passed through; if folding was skipped above, it is now False)
         state_dict = ProcessWeights.process_weights(
             state_dict,
             self.cfg,
@@ -1678,6 +1718,9 @@ def load_and_process_state_dict(
                 self.load_state_dict({key: state_dict[key]}, strict=False)
                 del state_dict[key]
 
+        if fold_ln:
+            self.setup()
+
     def fill_missing_keys(self, state_dict):
         return loading.fill_missing_keys(self, state_dict)
 
@@ -1817,31 +1860,6 @@ def process_weights_(
         version of the same model.
        """
         state_dict = self.state_dict()
-        if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1:
-            # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing
-            # A warning is already issued in `load_and_process_state_dict`
-            pass
-        elif fold_ln and self.cfg.normalization_type == "LN":
-            # If we're folding the LN into the weights, we need to replace all the layernorm layers
-            # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky,
-            # but it's the easiest way to do it.
-            self.cfg.normalization_type = "LNPre"
-            self.ln_final = LayerNormPre(self.cfg)
-            for layer in self._get_blocks():
-                layer.ln1 = LayerNormPre(self.cfg)
-                layer.ln2 = LayerNormPre(self.cfg)
-                if self.cfg.is_layer_norm_activation():
-                    layer.mlp.ln = LayerNormPre(self.cfg)
-        elif fold_ln and self.cfg.normalization_type == "RMS":
-            # We do the same for RMSNorm if used
-            self.cfg.normalization_type = "RMSPre"
-            self.ln_final = RMSNormPre(self.cfg)
-            for layer in self._get_blocks():
-                layer.ln1 = RMSNormPre(self.cfg)
-                layer.ln2 = RMSNormPre(self.cfg)
-                if self.cfg.is_layer_norm_activation():
-                    layer.mlp.ln = RMSNormPre(self.cfg)
-
         self.load_and_process_state_dict(
             state_dict,
             fold_ln=fold_ln,
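
Note on the suffix check in the first hunk: a state dict saved after `fold_ln=True` processing no longer contains the LayerNorm scale parameters, so a second fold has nothing to operate on, which is why the loader now warns and skips rather than proceeding. A minimal sketch of that heuristic in isolation, assuming TransformerLens's state-dict key naming (per-block norms at `blocks.<i>.ln1.w` / `blocks.<i>.ln2.w`, final norm at `ln_final.w`); the helper name `ln_weights_present` is illustrative and not part of this PR:

```python
# Sketch of the fold_ln skip heuristic, under the key-naming assumption above.
def ln_weights_present(state_dict: dict) -> bool:
    # True only if any LayerNorm/RMSNorm scale weights are still in the dict.
    return any(k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict)

# A dict saved with already-folded norms lacks those keys entirely:
already_folded = {"blocks.0.attn.W_Q": 0, "unembed.W_U": 0}
unprocessed = {"blocks.0.ln1.w": 0, "blocks.0.attn.W_Q": 0}
assert not ln_weights_present(already_folded)
assert ln_weights_present(unprocessed)
```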