diff --git a/tests/test_analyze_traps.py b/tests/test_analyze_traps.py index 4ce08f1..23c1eca 100644 --- a/tests/test_analyze_traps.py +++ b/tests/test_analyze_traps.py @@ -52,7 +52,11 @@ def test_analyze_traps_columns(self): df = self.watcher.analyze_traps(plot=False, savefig=False) expected_cols = { "layer_id", "name", "trap_index", "perm_mode_index", - "sigma_perm", "mp_bulk_max", "left_top_mass", "right_top_mass" + "sigma_perm", "mp_bulk_max", "left_top_mass", "right_top_mass", + "top_5_mass", "top_10_mass", + "bulk_localization_mean", "bulk_localization_std", + "bulk_top_5_mass_mean", "bulk_top_5_mass_std", + "bulk_top_10_mass_mean", "bulk_top_10_mass_std", } self.assertTrue(expected_cols.issubset(set(df.columns))) diff --git a/tests/test_remove_traps.py b/tests/test_remove_traps.py index e17ee1d..7a785ed 100644 --- a/tests/test_remove_traps.py +++ b/tests/test_remove_traps.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pytest from weightwatcher.constants import CHANNELS, FRAMEWORK, LAYER_TYPE, DEFAULT_PARAMS @@ -214,6 +215,32 @@ def test_remove_traps_public_api_direct_call(monkeypatch): assert len(post_artifacts) == 0 +def test_remove_traps_accepts_traps_dataframe_and_returns_verify_df(monkeypatch): + W, _, _, _ = _single_trap_setup(seed=505) + ww_layer = make_ww_layer(W) + watcher = WeightWatcher(model={"dummy_weight": np.array([1.0])}) + + monkeypatch.setattr( + watcher, + "make_layer_iterator", + lambda model=None, layers=None, params=None, base_model=None: [ww_layer], + ) + + out_model, verify_df = watcher.remove_traps( + model={"dummy_weight": np.array([1.0])}, + layers=[], + traps=pd.DataFrame([{"layer_id": 0, "trap_index": 1}]), + seed=88, + pool=True, + plot=False, + verify_traps=True, + return_analyze=True, + ) + assert isinstance(out_model, dict) + assert isinstance(verify_df, pd.DataFrame) + assert "verify_passed" in verify_df.columns + + def test_remove_traps_invalid_indices_warns_and_skips(monkeypatch, caplog): W, _, _, _ = _single_trap_setup(seed=404) ww_layer = make_ww_layer(W) diff --git a/weightwatcher/remove_traps.py b/weightwatcher/remove_traps.py index cba8ce0..a8a6e36 100644 --- a/weightwatcher/remove_traps.py +++ b/weightwatcher/remove_traps.py @@ -1,6 +1,7 @@ import logging import numbers import numpy as np +import pandas as pd from .RMT_Util import svd_full, unpermute_matrix from .constants import DEFAULT_PARAMS, DEFAULT_START_ID, FAST_SVD, LAYER_TYPE, PEFT, PLOT, POOL, START_IDS, SVD_METHOD, DEFAULT_PEFT @@ -65,6 +66,18 @@ def identify_trap_mode_indices(ww, ww_layer): def analyze_single_trap(ww, ww_layer, trap_mode_index): + def _top_percent_abs_mass(mat, percent): + flat = np.abs(np.asarray(mat, dtype=float)).ravel() + if flat.size == 0: + return 0.0 + total = float(np.sum(flat)) + if total <= 0.0: + return 0.0 + k = int(np.ceil((float(percent) / 100.0) * flat.size)) + k = max(1, min(k, flat.size)) + top_sum = float(np.sum(np.partition(flat, -k)[-k:])) + return top_sum / total + W_perm = ww_layer.Wmats[0] U_perm, S_perm, Vh_perm = svd_full(W_perm) @@ -74,6 +87,8 @@ def analyze_single_trap(ww, ww_layer, trap_mode_index): T_perm = sigma_perm * np.outer(u_trap, v_trap) T_orig_norm = unpermute_matrix(T_perm, ww_layer.permute_ids[0]) U_orig, _, Vh_orig = svd_full(T_orig_norm) + top_5_mass = _top_percent_abs_mass(T_orig_norm, 5.0) + top_10_mass = _top_percent_abs_mass(T_orig_norm, 10.0) return { "trap_mode_index": trap_mode_index, @@ -84,6 +99,8 @@ def analyze_single_trap(ww, ww_layer, trap_mode_index): "v_trap": Vh_orig[0, :], "T_perm": T_perm, "T_orig_norm": T_orig_norm, + "top_5_mass": float(top_5_mass), + "top_10_mass": float(top_10_mass), } @@ -209,10 +226,32 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N return ww_layer -def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True, - start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT): +def _trap_indices_from_traps_df(traps): + """Extract unique 1-based trap indices from a traps DataFrame-like input.""" + if traps is None: + return None + if isinstance(traps, pd.DataFrame): + trap_df = traps + else: + trap_df = pd.DataFrame(traps) + if "trap_index" not in trap_df.columns: + raise ValueError("traps must include a 'trap_index' column") + indices = trap_df["trap_index"].dropna().astype(int).tolist() + indices = sorted(set(indices)) + if len(indices) == 0: + raise ValueError("traps did not contain any valid trap_index values") + return indices + + +def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True, + verify_traps=False, return_analyze=False, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, + base_model=None, peft=DEFAULT_PEFT): + # PR359 compatibility path: passing traps= instead of trap_indices=[...] + if trap_indices is None and traps is not None: + trap_indices = _trap_indices_from_traps_df(traps) + if trap_indices is None or len(trap_indices) == 0: - raise ValueError("trap_indices must be provided and non-empty") + raise ValueError("trap_indices must be provided and non-empty (or pass traps with trap_index column)") ww.set_model_(model) params = DEFAULT_PARAMS.copy() @@ -230,8 +269,28 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=No params = ww.normalize_params(params) layer_iterator = ww.make_layer_iterator(model=ww.model, layers=layers, params=params, base_model=base_model) + verify_rows = [] for ww_layer in layer_iterator: if not ww_layer.skipped and ww_layer.has_weights: apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"]) - + if verify_traps: + remaining = collect_trap_artifacts( + ww, + ww_layer, + params=params, + seed=None if params["rng"] is not None else seed, + rng=params["rng"], + ) + verify_rows.append( + { + "layer_id": int(ww_layer.layer_id), + "requested_trap_indices": list(trap_indices), + "remaining_traps": len(remaining), + "verify_passed": len(remaining) == 0, + } + ) + + if return_analyze: + verify_df = pd.DataFrame.from_records(verify_rows) + return model, verify_df return model diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 0da873b..f4ac58a 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3760,6 +3760,11 @@ def _trap_result_columns(self): "rank1_mass_after_unpermute", "sigma_trap_top", "left_top_mode", "right_top_mode", "left_top_mass", "right_top_mass", "left_overlap_entropy", "right_overlap_entropy", "left_overlap_ipr", "right_overlap_ipr", + "top_5_mass", "top_10_mass", + "bulk_mode_count", + "bulk_localization_mean", "bulk_localization_std", + "bulk_top_5_mass_mean", "bulk_top_5_mass_std", + "bulk_top_10_mass_mean", "bulk_top_10_mass_std", "u_length", "u_entropy", "u_discrete_entropy", "u_localization_ratio", "u_participation_ratio", "v_length", "v_entropy", "v_discrete_entropy", "v_localization_ratio", "v_participation_ratio", "u_l2_fourth_moment", "u_l2_sixth_moment", "u_effective_support", "u_gini_abs", @@ -3785,6 +3790,7 @@ def apply_analyze_traps(self, ww_layer, params=None): self.apply_permute_W(ww_layer, params) self.apply_trap_mp_fit(ww_layer, params=params) trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params) + bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params) trap_rows = [] for trap_index, mode_index in enumerate(trap_mode_indices): @@ -3795,6 +3801,7 @@ def apply_analyze_traps(self, ww_layer, params=None): params=params, trap_index=trap_index, ) + trap_row.update(bulk_stats) trap_rows.append(trap_row) self.apply_unpermute_W(ww_layer, params) @@ -3815,6 +3822,78 @@ def identify_trap_mode_indices(self, ww_layer, params=None): # counts/indices for a given layer and seed. return remove_traps_ops.identify_trap_mode_indices(self, ww_layer) + def _top_percent_abs_mass(self, mat, percent): + flat = np.abs(np.asarray(mat, dtype=float)).ravel() + if flat.size == 0: + return 0.0 + total = float(np.sum(flat)) + if total <= 0.0: + return 0.0 + k = int(np.ceil((float(percent) / 100.0) * flat.size)) + k = max(1, min(k, flat.size)) + top_sum = float(np.sum(np.partition(flat, -k)[-k:])) + return top_sum / total + + def compute_bulk_trap_reference_metrics(self, ww_layer, trap_mode_indices, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + if len(ww_layer.Wmats) != 1 or len(ww_layer.permute_ids) == 0: + return { + "bulk_mode_count": 0, + "bulk_localization_mean": np.nan, + "bulk_localization_std": np.nan, + "bulk_top_5_mass_mean": np.nan, + "bulk_top_5_mass_std": np.nan, + "bulk_top_10_mass_mean": np.nan, + "bulk_top_10_mass_std": np.nan, + } + + W_perm = ww_layer.Wmats[0].astype(float) + p_ids = ww_layer.permute_ids[0] + U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD]) + trap_set = set(int(i) for i in trap_mode_indices) + bulk_indices = [i for i in range(len(S_perm)) if i not in trap_set] + + if len(bulk_indices) == 0: + return { + "bulk_mode_count": 0, + "bulk_localization_mean": np.nan, + "bulk_localization_std": np.nan, + "bulk_top_5_mass_mean": np.nan, + "bulk_top_5_mass_std": np.nan, + "bulk_top_10_mass_mean": np.nan, + "bulk_top_10_mass_std": np.nan, + } + + bulk_localization = [] + bulk_top_5_mass = [] + bulk_top_10_mass = [] + + for mode_idx in bulk_indices: + u_mode = U_perm[:, mode_idx] + v_mode = Vh_perm[mode_idx, :] + u_metrics = self._trap_vector_metrics(u_mode) + v_metrics = self._trap_vector_metrics(v_mode) + loc = 0.5 * ( + float(u_metrics.get("localization_ratio", np.nan)) + + float(v_metrics.get("localization_ratio", np.nan)) + ) + + T_perm = float(S_perm[mode_idx]) * np.outer(u_mode, v_mode) + T_orig = unpermute_matrix(T_perm, p_ids) + bulk_localization.append(loc) + bulk_top_5_mass.append(self._top_percent_abs_mass(T_orig, 5.0)) + bulk_top_10_mass.append(self._top_percent_abs_mass(T_orig, 10.0)) + + return { + "bulk_mode_count": int(len(bulk_indices)), + "bulk_localization_mean": float(np.nanmean(bulk_localization)), + "bulk_localization_std": float(np.nanstd(bulk_localization)), + "bulk_top_5_mass_mean": float(np.nanmean(bulk_top_5_mass)), + "bulk_top_5_mass_std": float(np.nanstd(bulk_top_5_mass)), + "bulk_top_10_mass_mean": float(np.nanmean(bulk_top_10_mass)), + "bulk_top_10_mass_std": float(np.nanstd(bulk_top_10_mass)), + } + def compute_original_basis_for_traps(self, ww_layer, params=None): if params is None: params = DEFAULT_PARAMS.copy() @@ -3849,6 +3928,9 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No T_perm = sigma_perm * np.outer(u_perm, v_perm) T_orig = unpermute_matrix(T_perm, p_ids) + top_5_mass = self._top_percent_abs_mass(T_orig, 5.0) + top_10_mass = self._top_percent_abs_mass(T_orig, 10.0) + Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD]) u_trap = Ut[:, 0] v_trap = Vht.T[:, 0] @@ -3906,6 +3988,8 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No "right_overlap_entropy": right_overlap_entropy, "left_overlap_ipr": left_overlap_ipr, "right_overlap_ipr": right_overlap_ipr, + "top_5_mass": float(top_5_mass), + "top_10_mass": float(top_10_mass), "trap_detected": True, "trap_eval_minus_bulk": float(eval_perm - ww_layer.bulk_max), } @@ -5658,12 +5742,14 @@ def apply_remove_traps(self, ww_layer, trap_indices, params=None, seed=None, rng """Remove selected traps from one dense WWLayer and replace with matched random matrices.""" return remove_traps_ops.apply_remove_traps(self, ww_layer, trap_indices, params=params, seed=seed, rng=rng) - def remove_traps(self, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True, - start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT): + def remove_traps(self, model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True, + verify_traps=False, return_analyze=False, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, + base_model=None, peft=DEFAULT_PEFT): """Remove selected randomized MP/TW traps from dense layers.""" return remove_traps_ops.remove_traps( - self, model=model, layers=layers, trap_indices=trap_indices, seed=seed, rng=rng, - pool=pool, plot=plot, start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft + self, model=model, layers=layers, trap_indices=trap_indices, traps=traps, seed=seed, rng=rng, + pool=pool, plot=plot, verify_traps=verify_traps, return_analyze=return_analyze, + start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft )