diff --git a/README.md b/README.md
index 248f568..36ecee8 100644
--- a/README.md
+++ b/README.md
@@ -119,6 +119,73 @@ New in v0.8.0: trap-level randomized diagnostics:
 trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images")
 ```
+`analyze_traps()` now reports paper-aligned trap metrics from the NeurIPS trap workflow,
+including normalized spectral excess (`trap_delta`), localization (`trap_q` / `trap_ipr`),
+top-sector overlap (`trap_top_sector_overlap` with configurable `top_sector_l`), and
+the primary paper scalar `trap_variance_burden`.
+
+`trap_q` is now Porter-Thomas-centered relative to a random-vector baseline
+(`E[IPR] ≈ 3/(m+2)` for real vectors), and `trap_variance_burden` uses this
+Porter-Thomas-centered `trap_q`.
+
+For comparison, uniform-centered localization is still available as
+`trap_q_uniform` (and `trap_diffuseness_uniform`), while
+`trap_diffuseness_score` remains the separate heuristic diagnostic.
+
+The legacy heuristic diagnostics are still returned for backward compatibility
+(`trap_diffuseness_score`, `trap_risk_score`, `trap_assessment`, plus legacy overlap/excess
+fields), but the paper-facing metrics above are the primary outputs for trap interpretation.
+
+For burden-variant experiments (e.g., comparing Porter-Thomas vs. uniform localization or
+different spectral normalizations), `analyze_traps` also supports:
+
+- `burden_variants=None` (default): keep the current PR-358 burden behavior.
+- `burden_variants="default"`: add a standard sweep of variant columns
+  (`trap_variance_burden__*`).
+- `return_burden_components=True`: expose scalar burden components for notebook analysis.
+- `return_burden_raw=True`: expose raw vectors/overlaps (`u_perm`, `v_perm`, `u_trap`,
+  `v_trap`, `left_overlaps`, `right_overlaps`, `perm_evals_sorted`).
+
+Fourier/FFT diagnostics are also available:
+
+- `fft=False` (default): keeps standard-basis behavior unchanged.
+- `trap_fft=True`: adds Fourier-space localization and Fourier-frequency mass diagnostics for trap vectors.
+- `trap_fft_config={...}`: optional FFT diagnostics config (partial dicts are supported).
+
+The legacy `fft=True` argument is still reserved for the older layer-matrix FFT path; use
+`trap_fft=True` for trap-vector Fourier diagnostics. A combined call is sketched below.
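+
+This is a minimal sketch of a combined call (using the same placeholder `my_model` as the
+examples below; all keyword arguments are the ones documented above):
+
+```python
+import weightwatcher as ww
+
+watcher = ww.WeightWatcher(model=my_model)
+
+# burden_variants="default" adds the trap_variance_burden__* sweep, and
+# trap_fft=True adds trap-vector Fourier diagnostics (plus fft_* burden variants).
+trap_df = watcher.analyze_traps(
+    layers=[3, 5],
+    plot=False,
+    trap_fft=True,
+    burden_variants="default",
+    return_burden_components=True,
+)
+fft_cols = [c for c in trap_df.columns if c.startswith("trap_fft_")]
+```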
+ +Recommended `trap_fft_config` fields: + +```python +{ + "sides": "both", # "left" | "right" | "both" + "vectors": "both", # "perm" | "trap" | "both" + "fold_conjugates": True, # combine +k/-k bins for real vectors + "exclude_dc": False, # exclude freq-0 from effective mass diagnostics + "top_frequency_l": 1, # top-k frequencies for top-frequency mass + "selected_frequencies": None, # optional explicit frequency bins + "normalization": "ortho", # np.fft.fft(..., norm="ortho") + "baseline": "uniform", # "uniform" | "pt_real_mc" | "pt_complex" + "mc_samples": 2048, # used by baseline="pt_real_mc" + "mc_seed": 123, + "modulus": None, # optional modular-arithmetic length hint + "apply_only_if_length_matches_modulus": False, + "layer_fft_map": None, +} +``` + +When `burden_variants="default"` and `trap_fft=True`, FFT burden variants are added +(`trap_variance_burden__fft_*`), for example: + +- `trap_variance_burden__fft_uniform_right_current_spectral` +- `trap_variance_burden__fft_uniform_lr_geom_fft_topmass` +- `trap_variance_burden__fft_pt_lr_geom_fft_topmass` + +Note: Fourier overlap is intentionally based on Fourier-frequency mass concentration +(`trap_fft_top_frequency_mass_*` / `trap_fft_selected_frequency_mass_*`), not +unitary-transformed vector overlap. + See the new usage guide: [Correlation Trap Workflow (`analyze_traps` + `remove_traps`)](./docs_trap_features.md) ## PEFT / LORA models (experimental) @@ -289,6 +356,18 @@ watcher = ww.WeightWatcher(model=my_model) trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images") ``` +The implementation is designed to follow the trap definitions in the NeurIPS paper first, +while preserving older diagnostic fields for compatibility with existing scripts/notebooks. +If you want the paper metrics directly, use: + +- `trap_delta` +- `trap_ipr` +- `trap_q` +- `trap_diffuseness` (`1 - trap_q`) +- `trap_top_sector_overlap` (cumulative overlap over first `top_sector_l` modes) +- `trap_variance_burden` +- `layer_trap_variance_burden` + For a complete walkthrough (including `remove_traps`), see: [Correlation Trap Workflow (`analyze_traps` + `remove_traps`)](./docs_trap_features.md) Fig (a) is well trained; Fig (b) may be over-fit. 
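+For orientation, here is a minimal sketch of the paper scalar's arithmetic, mirroring
+`compute_trap_variance_burden` in `weightwatcher/trap_analysis.py` (the numbers are
+illustrative, not taken from a real model):
+
+```python
+import numpy as np
+
+def trap_variance_burden(trap_delta, trap_q, trap_top_sector_overlap):
+    # burden = delta^2 * q * overlap^2; NaN if any component is non-finite
+    if not np.all(np.isfinite([trap_delta, trap_q, trap_top_sector_overlap])):
+        return float(np.nan)
+    return float((trap_delta ** 2) * trap_q * (trap_top_sector_overlap ** 2))
+
+# e.g. delta=0.2, q=0.5, overlap=0.3 -> 0.04 * 0.5 * 0.09 = 0.0018
+assert abs(trap_variance_burden(0.2, 0.5, 0.3) - 0.0018) < 1e-12
+```
+
+Per layer, `layer_trap_variance_burden` is the sum of `trap_variance_burden` over the
+traps found in that layer.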
diff --git a/tests/test_analyze_traps.py b/tests/test_analyze_traps.py index 4ce08f1..6230e29 100644 --- a/tests/test_analyze_traps.py +++ b/tests/test_analyze_traps.py @@ -9,6 +9,76 @@ TORCH_AVAILABLE = False import weightwatcher as ww +import weightwatcher.trap_analysis as trap_analysis + + +class TestTrapMetricHelpers(unittest.TestCase): + + def setUp(self): + self.watcher = ww.WeightWatcher() + + def test_compute_trap_delta(self): + self.assertAlmostEqual(self.watcher.compute_trap_delta(12.0, 10.0), 0.2) + self.assertAlmostEqual(self.watcher.compute_trap_delta(8.0, 10.0), 0.0) + self.assertTrue(np.isnan(self.watcher.compute_trap_delta(12.0, 0.0))) + + def test_compute_trap_ipr_q_porter_thomas_baseline_vector(self): + # m=10 => E_PT[IPR]=3/(m+2)=0.25 + v = np.array([0.5, 0.5, 0.5, 0.5] + [0.0] * 6) + ipr, q = self.watcher.compute_trap_ipr_q(v) + self.assertAlmostEqual(ipr, 0.25) + self.assertAlmostEqual(q, 0.0, places=7) + + def test_compute_trap_ipr_q_uniform_vector_clips_to_zero_under_pt(self): + v = np.ones(10) / np.sqrt(10) + ipr, q = self.watcher.compute_trap_ipr_q(v) + self.assertAlmostEqual(ipr, 0.1) + self.assertAlmostEqual(q, 0.0) + + def test_compute_trap_ipr_q_one_hot_vector(self): + v = np.zeros(10) + v[0] = 1.0 + ipr, q = self.watcher.compute_trap_ipr_q(v) + self.assertAlmostEqual(ipr, 1.0) + self.assertAlmostEqual(q, 1.0) + + def test_compute_trap_ipr_q_uniform_legacy_field_behavior(self): + v = np.ones(10) / np.sqrt(10) + ipr, q = self.watcher.compute_trap_ipr_q_uniform(v) + self.assertAlmostEqual(ipr, 0.1) + self.assertAlmostEqual(q, 0.0) + + onehot = np.zeros(10) + onehot[0] = 1.0 + ipr, q = self.watcher.compute_trap_ipr_q_uniform(onehot) + self.assertAlmostEqual(ipr, 1.0) + self.assertAlmostEqual(q, 1.0) + + def test_compute_top_sector_overlap(self): + overlaps = np.array([0.25, 0.10, 0.05, 0.60]) + + overlap_1, ell_eff = self.watcher.compute_top_sector_overlap(overlaps, 1) + self.assertAlmostEqual(overlap_1, 0.25) + self.assertEqual(ell_eff, 1) + + overlap_2, ell_eff = self.watcher.compute_top_sector_overlap(overlaps, 2) + self.assertAlmostEqual(overlap_2, 0.35) + self.assertEqual(ell_eff, 2) + + def test_compute_trap_variance_burden(self): + burden = self.watcher.compute_trap_variance_burden( + trap_delta=0.2, + trap_q=0.5, + trap_top_sector_overlap=0.3, + ) + self.assertAlmostEqual(burden, 0.2**2 * 0.5 * 0.3**2) + + def test_plot_flag_coercion(self): + self.assertFalse(trap_analysis._coerce_plot_flag(False)) + self.assertFalse(trap_analysis._coerce_plot_flag("False")) + self.assertFalse(trap_analysis._coerce_plot_flag("0")) + self.assertTrue(trap_analysis._coerce_plot_flag(True)) + self.assertTrue(trap_analysis._coerce_plot_flag("True")) if TORCH_AVAILABLE: @@ -137,6 +207,64 @@ def test_order_invariant_stats_are_finite(self): ]: self.assertTrue(np.isfinite(row[col])) + def test_analyze_traps_contains_paper_metric_columns(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337) + required = { + "top_sector_l", + "top_sector_l_effective", + "trap_seed", + "n_traps", + "perm_signature", + "permutation_n", + "trap_identity_key", + "trap_delta", + "trap_ipr", + "trap_q", + "trap_diffuseness", + "trap_q_uniform", + "trap_diffuseness_uniform", + "trap_top_sector_overlap", + "trap_variance_burden", + "layer_trap_variance_burden", + } + self.assertTrue(required.issubset(set(df.columns))) + + def test_analyze_traps_respects_top_sector_l_argument(self): + df1 = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337, top_sector_l=1) + df2 = 
self.watcher.analyze_traps(plot=False, savefig=False, rng=1337, top_sector_l=2) + + if len(df1) == 0 or len(df2) == 0: + self.skipTest("No traps detected in this environment") + + self.assertTrue((df1["top_sector_l"] == 1).all()) + self.assertTrue((df2["top_sector_l"] == 2).all()) + self.assertTrue((df1["top_sector_l_effective"] >= 1).all()) + self.assertTrue((df1["top_sector_l_effective"] <= df1["top_sector_l"]).all()) + self.assertTrue((df2["top_sector_l_effective"] >= 1).all()) + self.assertTrue((df2["top_sector_l_effective"] <= df2["top_sector_l"]).all()) + + def test_trap_variance_burden_formula_rowwise(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337) + if len(df) == 0: + self.skipTest("No traps detected in this environment") + + for _, row in df.iterrows(): + components = [row["trap_delta"], row["trap_q"], row["trap_top_sector_overlap"]] + if not np.all(np.isfinite(components)): + continue + expected = (row["trap_delta"] ** 2) * row["trap_q"] * (row["trap_top_sector_overlap"] ** 2) + self.assertAlmostEqual(row["trap_variance_burden"], expected) + + def test_layer_trap_variance_burden_aggregate(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337) + if len(df) == 0: + self.skipTest("No traps detected in this environment") + + for layer_id, subdf in df.groupby("layer_id"): + expected = subdf["trap_variance_burden"].sum() + observed = subdf["layer_trap_variance_burden"].iloc[0] + self.assertAlmostEqual(observed, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_remove_traps.py b/tests/test_remove_traps.py index e17ee1d..d4339e0 100644 --- a/tests/test_remove_traps.py +++ b/tests/test_remove_traps.py @@ -244,6 +244,148 @@ def test_remove_traps_invalid_indices_warns_and_skips(monkeypatch, caplog): assert len(post_artifacts) == 0 +def test_remove_traps_return_analyze_returns_verification_dataframe(): + watcher = WeightWatcher(model=None) + W, _, _, _ = _single_trap_setup(seed=505) + ww_layer = make_ww_layer(W) + + def _iterator(model=None, layers=None, params=None, base_model=None): + return [ww_layer] + + watcher.make_layer_iterator = _iterator + model_out, verify_df = watcher.remove_traps( + model={"dummy_weight": np.array([1.0])}, + layers=[], + trap_indices=[1], + seed=44, + pool=True, + plot=False, + return_analyze=True, + ) + assert isinstance(model_out, dict) + assert len(verify_df) == 1 + assert bool(verify_df.iloc[0]["trap_verified"]) + assert bool(verify_df.iloc[0]["removed"]) + + +def test_remove_traps_verify_traps_true_raises_on_mismatch(monkeypatch): + watcher = WeightWatcher(model=None) + W, _, _, _ = _single_trap_setup(seed=606) + ww_layer = make_ww_layer(W) + + monkeypatch.setattr( + watcher, + "make_layer_iterator", + lambda model=None, layers=None, params=None, base_model=None: [ww_layer], + ) + + traps_df = watcher.analyze_traps(model=None, layers=[], seed=11, plot=False, savefig=False, pool=True) + traps_df = traps_df.copy() + traps_df.loc[:, "perm_signature"] = "deadbeef" + + with pytest.raises(RuntimeError, match="Trap verification failed"): + watcher.remove_traps( + model={"dummy_weight": np.array([1.0])}, + layers=[], + trap_indices=[1], + seed=11, + pool=True, + plot=False, + verify_traps=True, + traps=traps_df, + ) + + +def test_remove_traps_accepts_analyze_dataframe_without_explicit_indices(monkeypatch): + watcher = WeightWatcher(model=None) + W, _, _, _ = _single_trap_setup(seed=707) + ww_layer = make_ww_layer(W) + + monkeypatch.setattr( + watcher, + "make_layer_iterator", + lambda 
model=None, layers=None, params=None, base_model=None: [ww_layer], + ) + + traps_df = watcher.analyze_traps(model=None, layers=[], seed=22, plot=False, savefig=False, pool=True) + model_out, verify_df = watcher.remove_traps( + model={"dummy_weight": np.array([1.0])}, + layers=[], + trap_indices=None, + seed=22, + pool=True, + plot=False, + return_analyze=True, + traps=traps_df, + ) + assert isinstance(model_out, dict) + assert len(verify_df) == 1 + assert bool(verify_df.iloc[0]["trap_verified"]) + + +def test_remove_traps_verify_true_with_single_trap_row_subset(monkeypatch): + watcher = WeightWatcher(model=None) + seed = 24 + W, _, _, _ = _single_trap_setup(seed=909) + ww_layer = make_ww_layer(W) + + monkeypatch.setattr( + watcher, + "make_layer_iterator", + lambda model=None, layers=None, params=None, base_model=None: [ww_layer], + ) + + trap_df = watcher.analyze_traps(model=None, layers=[], seed=seed, plot=False, savefig=False, pool=True) + k = 0 + model_out, verify_df = watcher.remove_traps( + model={"dummy_weight": np.array([1.0])}, + layers=[], + traps=trap_df.iloc[[k]], + seed=seed, + pool=True, + plot=False, + verify_traps=True, + return_analyze=True, + ) + + assert isinstance(model_out, dict) + assert len(verify_df) == 1 + assert bool(verify_df.iloc[0]["perm_match"]) + assert bool(verify_df.iloc[0]["trap_verified"]) + + +def test_remove_traps_verification_regenerates_layer_local_trap_df(monkeypatch): + watcher = WeightWatcher(model=None) + W, _, _, _ = _single_trap_setup(seed=808) + ww_layer = make_ww_layer(W) + + monkeypatch.setattr( + watcher, + "make_layer_iterator", + lambda model=None, layers=None, params=None, base_model=None: [ww_layer], + ) + + original_analyze = watcher.analyze_traps + recorded_layers = [] + + def _recording_analyze(*args, **kwargs): + recorded_layers.append(tuple(kwargs.get("layers", []))) + return original_analyze(*args, **kwargs) + + monkeypatch.setattr(watcher, "analyze_traps", _recording_analyze) + watcher.remove_traps( + model={"dummy_weight": np.array([1.0])}, + layers=[], + trap_indices=[1], + seed=33, + pool=True, + plot=False, + return_analyze=True, + ) + + assert any(call == (int(ww_layer.layer_id),) for call in recorded_layers) + + @pytest.mark.skipif(torch is None, reason="PyTorch not installed") def test_trap_rng_consistency_analyze_vs_collect_single_and_multi_layer(): model = torch.nn.Sequential( diff --git a/tests/test_trap_burden_variants.py b/tests/test_trap_burden_variants.py new file mode 100644 index 0000000..5773fd7 --- /dev/null +++ b/tests/test_trap_burden_variants.py @@ -0,0 +1,179 @@ +import unittest +import numpy as np + +try: + import torch + import torch.nn as nn + TORCH_AVAILABLE = True +except Exception: + TORCH_AVAILABLE = False + +import weightwatcher as ww +import weightwatcher.trap_burden_variants as tbv + + +class TestTrapBurdenVariantMath(unittest.TestCase): + def test_current_pr358_variant_formula(self): + components = { + "trap_spectral_edge_ratio_current": 0.2, + "trap_q_pt_right_perm": 0.5, + "trap_top_sector_overlap_right": 0.3, + } + cfg = [c for c in tbv.DEFAULT_BURDEN_VARIANTS if c["name"] == "current_pr358"][0] + v = tbv.compute_burden_variant(components, cfg) + self.assertAlmostEqual(v, 0.2 ** 2 * 0.5 * 0.3 ** 2) + + def test_uniform_localization_basics(self): + v = np.ones(10) / np.sqrt(10) + ipr, q = tbv.localization_uniform_centered(v) + self.assertAlmostEqual(ipr, 0.1) + self.assertAlmostEqual(q, 0.0) + + w = np.zeros(10) + w[0] = 1.0 + ipr, q = tbv.localization_uniform_centered(w) + 
self.assertAlmostEqual(ipr, 1.0) + self.assertAlmostEqual(q, 1.0) + + def test_porter_thomas_localization(self): + # n=10 -> expected real PT IPR is 3/(10+2)=0.25 + # vector with 4 equal non-zero entries has IPR=0.25 exactly + v = np.array([0.5, 0.5, 0.5, 0.5] + [0.0] * 6) + ipr, q = tbv.localization_porter_thomas_centered(v, beta="real") + self.assertAlmostEqual(ipr, 3.0 / 12.0) + self.assertAlmostEqual(q, 0.0, places=7) + + w = np.zeros(10) + w[0] = 1.0 + _, q1 = tbv.localization_porter_thomas_centered(w, beta="real") + self.assertTrue(q1 <= 1.0 and q1 >= 0.0) + self.assertAlmostEqual(q1, 1.0) + + def test_spectral_modes(self): + eval_perm = 12.0 + mp_bulk_max = 10.0 + total = 100.0 + self.assertAlmostEqual( + tbv.spectral_excess(eval_perm, mp_bulk_max, total, mode="edge_ratio_current"), + 0.2, + ) + self.assertAlmostEqual( + tbv.spectral_excess(eval_perm, mp_bulk_max, total, mode="total_excess"), + 0.02, + ) + self.assertAlmostEqual( + tbv.spectral_excess(eval_perm, mp_bulk_max, total, mode="total_fraction"), + 0.12, + ) + + def test_combine_lr(self): + self.assertAlmostEqual(tbv.combine_lr(0.2, 0.8, "geom"), 0.4) + self.assertAlmostEqual(tbv.combine_lr(0.2, 0.8, "min"), 0.2) + self.assertAlmostEqual(tbv.combine_lr(0.2, 0.8, "max"), 0.8) + self.assertAlmostEqual(tbv.combine_lr(0.2, 0.8, "mean"), 0.5) + self.assertAlmostEqual(tbv.combine_lr(0.2, 0.8, "product"), 0.16) + self.assertTrue(np.isnan(tbv.combine_lr(np.nan, 0.8, "mean"))) + + +if TORCH_AVAILABLE: + class TinyTrapNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(16, 12, bias=False) + self.fc2 = nn.Linear(12, 10, bias=False) + with torch.no_grad(): + u = torch.linspace(1.0, 2.0, steps=12) + v = torch.linspace(-2.0, 1.0, steps=16) + self.fc1.weight.copy_(35.0 * torch.outer(u, v)) + + u2 = torch.linspace(1.0, 1.5, steps=10) + v2 = torch.linspace(-1.0, 2.0, steps=12) + self.fc2.weight.copy_(20.0 * torch.outer(u2, v2)) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +@unittest.skipUnless(TORCH_AVAILABLE, "torch is required for analyze_traps variant tests") +class TestTrapBurdenVariantAPI(unittest.TestCase): + def setUp(self): + self.watcher = ww.WeightWatcher(model=TinyTrapNet()) + + def test_analyze_traps_default_has_no_variant_cols(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337) + self.assertFalse(any(c.startswith("trap_variance_burden__") for c in df.columns)) + + def test_analyze_traps_default_variant_sweep_columns(self): + df = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + burden_variants="default", + ) + variant_cols = [c for c in df.columns if c.startswith("trap_variance_burden__")] + self.assertTrue(len(variant_cols) > 0) + self.assertIn("trap_variance_burden__current_pr358", variant_cols) + + def test_analyze_traps_component_columns(self): + df = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + burden_variants="default", + return_burden_components=True, + ) + required = { + "trap_q_pt_left_perm", + "trap_q_pt_right_perm", + "trap_q_pt_perm_lr_geom", + "trap_top_sector_overlap_left", + "trap_top_sector_overlap_right", + "trap_spectral_total_excess", + "trap_perm_total_variance", + } + self.assertTrue(required.issubset(set(df.columns))) + + def test_analyze_traps_raw_columns_control(self): + df = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + burden_variants="default", + return_burden_raw=False, + ) + for c in ["u_perm", "v_perm", "u_trap", "v_trap", 
"left_overlaps", "right_overlaps", "perm_evals_sorted"]: + self.assertNotIn(c, df.columns) + + df_raw = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + burden_variants="default", + return_burden_raw=True, + ) + if len(df_raw) == 0: + self.skipTest("No traps detected in this environment") + for c in ["u_perm", "v_perm", "u_trap", "v_trap", "left_overlaps", "right_overlaps", "perm_evals_sorted"]: + self.assertIn(c, df_raw.columns) + + def test_current_variant_matches_base_burden(self): + df = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + burden_variants="default", + return_burden_components=True, + ) + if len(df) == 0: + self.skipTest("No traps detected in this environment") + self.assertTrue(np.allclose( + df["trap_variance_burden"], + df["trap_variance_burden__current_pr358"], + equal_nan=True, + )) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trap_fft.py b/tests/test_trap_fft.py new file mode 100644 index 0000000..92b9681 --- /dev/null +++ b/tests/test_trap_fft.py @@ -0,0 +1,129 @@ +import unittest +from unittest.mock import patch +import numpy as np + +import weightwatcher as ww +import weightwatcher.trap_fourier as tf +import weightwatcher.trap_burden_variants as tbv + +try: + import torch + import torch.nn as nn + TORCH_AVAILABLE = True +except Exception: + TORCH_AVAILABLE = False + + +class TestTrapFourierHelpers(unittest.TestCase): + def test_mass_sums_to_one(self): + rng = np.random.RandomState(123) + v = rng.normal(size=31) + raw = tf.fourier_mass(v, fold_conjugates=False) + self.assertAlmostEqual(float(np.sum(raw["mass"])), 1.0, places=7) + + folded = tf.fourier_mass(v, fold_conjugates=True) + self.assertAlmostEqual(float(np.sum(folded["folded_mass"])), 1.0, places=7) + + def test_one_hot_is_fourier_delocalized(self): + n = 32 + v = np.zeros(n) + v[0] = 1.0 + ipr_fft, q_fft = tf.fourier_uniform_centered_q(v, fold_conjugates=False) + self.assertAlmostEqual(ipr_fft, 1.0 / n, places=6) + self.assertAlmostEqual(q_fft, 0.0, places=6) + + def test_sinusoid_is_fourier_localized(self): + n = 64 + k = 7 + j = np.arange(n) + v = np.cos(2.0 * np.pi * k * j / n) + v /= np.linalg.norm(v) + + top_mass, idx, _ = tf.fourier_top_frequency_mass(v, top_frequency_l=1, fold_conjugates=True) + self.assertGreater(top_mass, 0.9) + self.assertTrue(len(idx) > 0) + + _, q_fft = tf.fourier_uniform_centered_q(v, fold_conjugates=True) + self.assertGreater(q_fft, 0.5) + + def test_default_variants_have_fft_only_when_enabled(self): + base = tbv.resolve_burden_variant_configs("default", trap_fft=False) + self.assertFalse(any(cfg["name"].startswith("fft_") for cfg in base)) + + with_fft = tbv.resolve_burden_variant_configs("default", trap_fft=True) + self.assertTrue(any(cfg["name"].startswith("fft_") for cfg in with_fft)) + + +if TORCH_AVAILABLE: + class TinyTrapNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(16, 12, bias=False) + self.fc2 = nn.Linear(12, 10, bias=False) + with torch.no_grad(): + u = torch.linspace(1.0, 2.0, steps=12) + v = torch.linspace(-2.0, 1.0, steps=16) + self.fc1.weight.copy_(35.0 * torch.outer(u, v)) + + u2 = torch.linspace(1.0, 1.5, steps=10) + v2 = torch.linspace(-1.0, 2.0, steps=12) + self.fc2.weight.copy_(20.0 * torch.outer(u2, v2)) + + +@unittest.skipUnless(TORCH_AVAILABLE, "torch is required for trap FFT integration tests") +class TestTrapFFTAnalyzeTrapsIntegration(unittest.TestCase): + def setUp(self): + self.watcher = ww.WeightWatcher(model=TinyTrapNet()) + + def 
test_trap_fft_true_does_not_call_matrix_fft(self): + with patch.object(self.watcher, "apply_FFT", side_effect=RuntimeError("apply_FFT should not be called")): + df = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + fft=False, + trap_fft=True, + ) + self.assertIsNotNone(df) + + def test_trap_fft_columns_present(self): + df = self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + fft=False, + trap_fft=True, + burden_variants="default", + ) + required = { + "trap_fft_ipr_right_perm", + "trap_fft_q_uniform_right_perm", + "trap_fft_q_pt_right_perm", + "trap_fft_top_frequency_mass_right_perm", + "trap_fft_peak_frequency_right_perm", + "trap_fft_peak_mass_right_perm", + "trap_variance_burden__fft_uniform_lr_geom_fft_topmass", + } + self.assertTrue(required.issubset(set(df.columns))) + + def test_fft_true_preserves_old_matrix_fft_behavior(self): + called = {"n": 0} + + def _counting_fft(*args, **kwargs): + called["n"] += 1 + return ww.WeightWatcher.apply_FFT(self.watcher, *args, **kwargs) + + with patch.object(self.watcher, "apply_FFT", side_effect=_counting_fft): + self.watcher.analyze_traps( + plot=False, + savefig=False, + rng=1337, + fft=True, + trap_fft=False, + ) + + self.assertGreater(called["n"], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/weightwatcher/remove_traps.py b/weightwatcher/remove_traps.py index cba8ce0..3ba8706 100644 --- a/weightwatcher/remove_traps.py +++ b/weightwatcher/remove_traps.py @@ -1,12 +1,14 @@ import logging import numbers import numpy as np +import pandas as pd from .RMT_Util import svd_full, unpermute_matrix from .constants import DEFAULT_PARAMS, DEFAULT_START_ID, FAST_SVD, LAYER_TYPE, PEFT, PLOT, POOL, START_IDS, SVD_METHOD, DEFAULT_PEFT from .constants import LAYERS from .constants import WW_NAME from .trap_histograms import plot_layer_trap_weight_histogram +from . 
import trap_identity logger = logging.getLogger(WW_NAME) @@ -104,11 +106,34 @@ def collect_trap_artifacts(ww, ww_layer, params=None, seed=None, rng=None): ww.apply_permute_W(analysis_layer, params, rng=rng) apply_trap_mp_fit(ww, analysis_layer, params) trap_mode_indices = identify_trap_mode_indices(ww, analysis_layer) + n_traps = int(len(trap_mode_indices)) + perm_ids = analysis_layer.permute_ids[0] if len(analysis_layer.permute_ids) > 0 else np.array([], dtype=int) + perm_sig = trap_identity.permutation_signature(perm_ids) artifacts = [] for i, trap_mode_index in enumerate(trap_mode_indices, start=1): artifact = analyze_single_trap(ww, analysis_layer, trap_mode_index) artifact["trap_index"] = i + artifact["trap_seed"] = seed + artifact["n_traps"] = n_traps + artifact["perm_signature"] = perm_sig + artifact["permutation_n"] = int(len(np.asarray(perm_ids).ravel())) + artifact["permutation_mode"] = "index_permutation" + artifact["trap_identity_key"] = trap_identity.make_trap_identity_key( + layer_id=analysis_layer.layer_id, + seed=seed, + trap_index=i - 1, + n_traps=n_traps, + perm_signature=perm_sig, + ) + artifact["layer_id"] = analysis_layer.layer_id + artifact["eval_perm"] = float(artifact["sigma_perm"] ** 2) + artifact["mp_bulk_max"] = float(getattr(analysis_layer, "bulk_max", np.nan)) + mp_bulk_max = float(getattr(analysis_layer, "bulk_max", np.nan)) + if np.isfinite(mp_bulk_max) and mp_bulk_max > 0: + artifact["trap_delta"] = float(max(artifact["eval_perm"] - mp_bulk_max, 0.0) / mp_bulk_max) + else: + artifact["trap_delta"] = float(np.nan) artifact["T_orig_raw"] = artifact["T_orig_norm"] / analysis_layer.w_norm artifacts.append(artifact) @@ -209,8 +234,30 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N return ww_layer -def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True, - start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT): +def remove_traps( + ww, + model=None, + layers=[], + trap_indices=None, + seed=None, + rng=None, + pool=True, + plot=True, + start_ids=DEFAULT_START_ID, + svd_method=FAST_SVD, + base_model=None, + peft=DEFAULT_PEFT, + verify_traps=False, + return_analyze=False, + traps=None, + rtol=1e-4, + atol=1e-6, + min_vector_cosine=0.999, +): + traps_df = trap_identity.coerce_traps_dataframe(traps) + if traps_df is not None and len(traps_df) > 0 and (trap_indices is None or len(trap_indices) == 0): + trap_indices = [int(i) + 1 for i in traps_df["trap_index"].astype(int).tolist()] + if trap_indices is None or len(trap_indices) == 0: raise ValueError("trap_indices must be provided and non-empty") @@ -229,9 +276,86 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=No raise Exception(f"Error, params not valid: \n {params}") params = ww.normalize_params(params) + remove_rows = [] + needs_verify = bool(verify_traps or return_analyze or traps_df is not None) + layer_iterator = ww.make_layer_iterator(model=ww.model, layers=layers, params=params, base_model=base_model) for ww_layer in layer_iterator: if not ww_layer.skipped and ww_layer.has_weights: - apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"]) + layer_analyze_df = None + layer_remove_df = None + if needs_verify: + layer_remove_df = ww.analyze_traps( + model=model, + layers=[int(ww_layer.layer_id)], + plot=False, + savefig=False, + pool=pool, + start_ids=start_ids, + peft=peft, + seed=seed, + rng=None, + 
return_burden_raw=True, + ) + if traps_df is not None and len(traps_df) > 0: + layer_analyze_df = traps_df[traps_df["layer_id"].astype(int) == int(ww_layer.layer_id)].copy() + else: + layer_analyze_df = layer_remove_df.copy() + + layer_indices = sorted(set([int(i) for i in trap_indices])) + removed_flag = True + removal_error = None + if needs_verify: + for idx in layer_indices: + if idx < 1: + continue + trap_index_zero = idx - 1 + analyze_candidates = layer_analyze_df[layer_analyze_df["trap_index"].astype(int) == int(trap_index_zero)] + if len(analyze_candidates) == 0: + analyze_row = pd.Series(dtype=float) + else: + analyze_row = analyze_candidates.iloc[0] + + analyze_seed = analyze_row.get("trap_seed", seed) + analyze_n_traps = analyze_row.get("n_traps", np.nan) + analyze_perm = analyze_row.get("perm_signature", "") + + remove_candidates = layer_remove_df[ + (layer_remove_df["layer_id"].astype(int) == int(ww_layer.layer_id)) + & (layer_remove_df["trap_index"].astype(int) == int(trap_index_zero)) + & (layer_remove_df["trap_seed"].astype(str) == str(analyze_seed)) + & (layer_remove_df["n_traps"].astype(float) == float(analyze_n_traps)) + & (layer_remove_df["perm_signature"].astype(str) == str(analyze_perm)) + ] + remove_row = remove_candidates.iloc[0] if len(remove_candidates) > 0 else pd.Series(dtype=float) + + verify = trap_identity.verify_trap_rows( + analyze_row, + remove_row, + rtol=rtol, + atol=atol, + min_vector_cosine=min_vector_cosine, + ) + vrow = trap_identity.build_trap_verification_row( + analyze_row=analyze_row, + remove_row=remove_row, + verify_dict=verify, + removed=False, + removal_error=None, + ) + remove_rows.append(vrow) + if verify_traps and (not verify.get("trap_verified", False)): + raise RuntimeError( + f"Trap verification failed for layer {ww_layer.layer_id}, trap_index={idx - 1}" + ) + apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"]) + if needs_verify and len(remove_rows) > 0: + for i in range(len(remove_rows)): + if int(remove_rows[i].get("layer_id", -1)) == int(ww_layer.layer_id): + remove_rows[i]["removed"] = removed_flag + remove_rows[i]["removal_error"] = removal_error + if needs_verify: + remove_meta_df = pd.DataFrame(remove_rows) + return model, remove_meta_df return model diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py index c35bdd6..948f5f2 100644 --- a/weightwatcher/trap_analysis.py +++ b/weightwatcher/trap_analysis.py @@ -4,6 +4,19 @@ from . import remove_traps as remove_traps_ops from . import weightwatcher as wwcore from .trap_histograms import plot_layer_trap_weight_histogram +from . import trap_fourier +from . 
import trap_identity + + +def _coerce_plot_flag(plot): + """Normalize user plot flags so analyze_traps(plot=False) is always respected.""" + if isinstance(plot, str): + p = plot.strip().lower() + if p in {"false", "0", "no", "off", ""}: + return False + if p in {"true", "1", "yes", "on"}: + return True + return bool(plot) def analyze_traps( @@ -23,12 +36,19 @@ def analyze_traps( pool=wwcore.DEFAULT_POOL, conv2d_fft=False, fft=False, + trap_fft=False, + trap_fft_config=None, channels=None, svd_method=wwcore.FAST_SVD, start_ids=wwcore.DEFAULT_START_ID, base_model=None, peft=wwcore.DEFAULT_PEFT, + seed=None, rng=None, + top_sector_l=1, + burden_variants=None, + return_burden_components=False, + return_burden_raw=False, ): """Externalized implementation for WeightWatcher.analyze_traps().""" if layers is None: @@ -49,7 +69,7 @@ def analyze_traps( params[wwcore.MAX_EVALS] = max_evals params[wwcore.MAX_N] = max_N - params[wwcore.PLOT] = plot + params[wwcore.PLOT] = _coerce_plot_flag(plot) params[wwcore.RANDOMIZE] = True params[wwcore.MP_FIT] = True params[wwcore.GLOROT_FIT] = glorot_fix @@ -72,7 +92,16 @@ def analyze_traps( params[wwcore.SAVEFIG] = savefig params[wwcore.PEFT] = peft params[wwcore.INVERSE] = False - params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng) + params["seed"] = seed + params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng, seed=seed) + params["top_sector_l"] = int(top_sector_l) + params["burden_variants"] = burden_variants + params["return_burden_components"] = bool(return_burden_components) + params["return_burden_raw"] = bool(return_burden_raw) + params["trap_fft"] = bool(trap_fft) + params["trap_fft_config"] = trap_fourier.resolve_trap_fft_config(trap_fft_config) + if int(top_sector_l) < 1: + raise ValueError("top_sector_l must be >= 1") wwcore.logger.debug("params {}".format(params)) if not watcher.valid_params(params): @@ -94,9 +123,10 @@ def analyze_traps( layer_params = dict(params) layer_params["_keep_trap_matrix"] = bool(params.get(wwcore.PLOT, False)) + layer_params["_layer_n_traps"] = 0 layer_rows = watcher.apply_analyze_traps(ww_layer, params=layer_params) if layer_rows: - if params.get(wwcore.PLOT, False): + if bool(params.get(wwcore.PLOT, False)): trap_infos = [] for row in layer_rows: trap_idx_zero_based = int(row.get("trap_index", -1)) @@ -135,9 +165,29 @@ def analyze_traps( details = pd.DataFrame.from_records(trap_rows) else: details = pd.DataFrame(columns=watcher._trap_result_columns()) + from .trap_burden_variants import resolve_burden_variant_configs + variant_configs = resolve_burden_variant_configs( + burden_variants, trap_fft=bool(params.get("trap_fft", False)) + ) + if variant_configs is not None: + for cfg in variant_configs: + details[f"trap_variance_burden__{cfg['name']}"] = pd.Series(dtype=float) + details[f"layer_trap_variance_burden__{cfg['name']}"] = pd.Series(dtype=float) + if bool(params.get("trap_fft", False)): + for col in _trap_fft_result_columns(): + details[col] = pd.Series(dtype=float) trap_cols = watcher._trap_result_columns() details = details.reindex(columns=trap_cols + [c for c in details.columns if c not in trap_cols]) + if len(details) > 0 and "trap_variance_burden" in details.columns: + details["layer_trap_variance_burden"] = ( + details.groupby("layer_id")["trap_variance_burden"].transform("sum") + ) + variant_cols = [c for c in details.columns if c.startswith("trap_variance_burden__")] + for col in variant_cols: + suffix = col.replace("trap_variance_burden__", "", 1) + layer_col = 
f"layer_trap_variance_burden__{suffix}" + details[layer_col] = details.groupby("layer_id")[col].transform("sum") if len(details) > 0: lead_cols = ["layer_id", "name"] @@ -159,6 +209,380 @@ def analyze_traps( return details +def compute_original_basis_for_traps(watcher, ww_layer, params=None): + if params is None: + params = wwcore.DEFAULT_PARAMS.copy() + if len(ww_layer.Wmats) != 1: + return None + + W_true = ww_layer.Wmats[0].astype(float) + U0, S0, V0h = wwcore.svd_full(W_true, method=params[wwcore.SVD_METHOD]) + return { + "W_true": W_true, + "U0": U0, + "S0": S0, + "V0": V0h.T, + } + + +def compute_trap_delta(eval_perm, mp_bulk_max): + eval_perm = float(eval_perm) + mp_bulk_max = float(mp_bulk_max) + if (not np.isfinite(eval_perm)) or (not np.isfinite(mp_bulk_max)) or mp_bulk_max <= 0.0: + return float(np.nan) + return float(max(eval_perm - mp_bulk_max, 0.0) / mp_bulk_max) + + +def compute_trap_ipr_q_uniform(vec): + v = np.asarray(vec, dtype=float).ravel() + if v.size == 0: + return float(np.nan), float(np.nan) + norm = np.linalg.norm(v) + if (not np.isfinite(norm)) or norm <= 0.0: + return float(np.nan), float(np.nan) + + v = v / norm + ipr = float(np.sum(v ** 4)) + m = int(len(v)) + if m <= 1: + q = 1.0 + else: + q = (m * ipr - 1.0) / (m - 1.0) + q = float(np.clip(q, 0.0, 1.0)) + return ipr, float(q) + + +def compute_trap_ipr_q_porter_thomas(vec): + v = np.asarray(vec, dtype=float).ravel() + if v.size == 0: + return float(np.nan), float(np.nan) + norm = np.linalg.norm(v) + if (not np.isfinite(norm)) or norm <= 0.0: + return float(np.nan), float(np.nan) + + v = v / norm + ipr = float(np.sum(v ** 4)) + m = int(len(v)) + if m <= 1: + q = 1.0 + else: + expected_ipr_pt = 3.0 / (m + 2.0) + q = (ipr - expected_ipr_pt) / (1.0 - expected_ipr_pt) + q = float(np.clip(q, 0.0, 1.0)) + return ipr, float(q) + + +def compute_trap_ipr_q(vec): + """Default paper-facing localization (Porter-Thomas-centered).""" + return compute_trap_ipr_q_porter_thomas(vec) + + +def compute_top_sector_overlap(overlaps, top_sector_l=1): + ell = int(top_sector_l) + if ell < 1: + raise ValueError("top_sector_l must be >= 1") + + overlap_vec = np.asarray(overlaps, dtype=float).ravel() + if overlap_vec.size == 0: + return float(np.nan), 0 + + ell_eff = min(ell, int(len(overlap_vec))) + return float(np.sum(overlap_vec[:ell_eff])), int(ell_eff) + + +def compute_trap_variance_burden(trap_delta, trap_q, trap_top_sector_overlap): + trap_delta = float(trap_delta) + trap_q = float(trap_q) + trap_top_sector_overlap = float(trap_top_sector_overlap) + if (not np.isfinite(trap_delta)) or (not np.isfinite(trap_q)) or (not np.isfinite(trap_top_sector_overlap)): + return float(np.nan) + return float((trap_delta ** 2) * trap_q * (trap_top_sector_overlap ** 2)) + + +def _trap_fft_result_columns(): + cols = [] + for side in ("right", "left"): + for space in ("perm", "trap"): + cols.extend( + [ + f"trap_fft_ipr_{side}_{space}", + f"trap_fft_q_uniform_{side}_{space}", + f"trap_fft_q_pt_{side}_{space}", + f"trap_fft_top_frequency_mass_{side}_{space}", + f"trap_fft_peak_frequency_{side}_{space}", + f"trap_fft_peak_mass_{side}_{space}", + ] + ) + cols.extend( + [ + "trap_fft_q_uniform_perm_lr_geom", + "trap_fft_q_uniform_perm_lr_min", + "trap_fft_q_pt_perm_lr_geom", + "trap_fft_q_pt_perm_lr_min", + "trap_fft_top_frequency_mass_perm_lr_geom", + "trap_fft_top_frequency_mass_perm_lr_min", + ] + ) + return cols + + +def analyze_single_trap(watcher, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0): + if params is None: 
+ params = wwcore.DEFAULT_PARAMS.copy() + if original_basis_cache is None: + original_basis_cache = compute_original_basis_for_traps(watcher, ww_layer, params=params) + + W_perm = ww_layer.Wmats[0].astype(float) + p_ids = ww_layer.permute_ids[0] + + U_perm, S_perm, Vh_perm = wwcore.svd_full(W_perm, method=params[wwcore.SVD_METHOD]) + V_perm = Vh_perm.T + + sigma_perm = float(S_perm[trap_mode_index]) + u_perm = U_perm[:, trap_mode_index] + v_perm = V_perm[:, trap_mode_index] + + T_perm = sigma_perm * np.outer(u_perm, v_perm) + T_orig = wwcore.unpermute_matrix(T_perm, p_ids) + + Ut, St, Vht = wwcore.svd_full(T_orig, method=params[wwcore.SVD_METHOD]) + u_trap = Ut[:, 0] + v_trap = Vht.T[:, 0] + + U0 = original_basis_cache["U0"] + V0 = original_basis_cache["V0"] + + left_overlaps = np.abs(U0.T @ u_trap) ** 2 + right_overlaps = np.abs(V0.T @ v_trap) ** 2 + + left_top_mode = int(np.argmax(left_overlaps)) + right_top_mode = int(np.argmax(right_overlaps)) + left_top_mass = float(np.max(left_overlaps)) + right_top_mass = float(np.max(right_overlaps)) + + eps = 1e-12 + left_overlap_entropy = float(-np.sum((left_overlaps + eps) * np.log(left_overlaps + eps))) + right_overlap_entropy = float(-np.sum((right_overlaps + eps) * np.log(right_overlaps + eps))) + left_overlap_ipr = float(np.sum(left_overlaps ** 2)) + right_overlap_ipr = float(np.sum(right_overlaps ** 2)) + + st_sq = St * St + rank1_mass_after_unpermute = float(st_sq[0] / (np.sum(st_sq) + eps)) + + u_metrics = watcher._trap_vector_metrics(u_trap) + v_metrics = watcher._trap_vector_metrics(v_trap) + u_oi = watcher._trap_vector_order_invariant_stats(u_trap) + v_oi = watcher._trap_vector_order_invariant_stats(v_trap) + + eval_perm = sigma_perm ** 2 + top_sector_l = int(params.get("top_sector_l", 1)) + burden_variants = params.get("burden_variants", None) + return_burden_components = bool(params.get("return_burden_components", False)) + return_burden_raw = bool(params.get("return_burden_raw", False)) + trap_delta = compute_trap_delta(eval_perm=eval_perm, mp_bulk_max=ww_layer.bulk_max) + trap_ipr, trap_q = compute_trap_ipr_q_porter_thomas(v_perm) + _, trap_q_uniform = compute_trap_ipr_q_uniform(v_perm) + trap_top_sector_overlap, top_sector_l_effective = compute_top_sector_overlap( + right_overlaps, + top_sector_l=top_sector_l, + ) + trap_variance_burden = compute_trap_variance_burden( + trap_delta=trap_delta, + trap_q=trap_q, + trap_top_sector_overlap=trap_top_sector_overlap, + ) + trap_seed = params.get("seed", None) + perm_ids = ww_layer.permute_ids[0] if len(ww_layer.permute_ids) > 0 else np.array([], dtype=int) + perm_signature = trap_identity.permutation_signature(perm_ids) + n_traps = int(params.get("_layer_n_traps", np.nan)) if params.get("_layer_n_traps", None) is not None else np.nan + trap_identity_key = trap_identity.make_trap_identity_key( + layer_id=ww_layer.layer_id, + seed=trap_seed, + trap_index=int(trap_index), + n_traps=n_traps, + perm_signature=perm_signature, + ) + trap_fft = bool(params.get("trap_fft", False)) + trap_fft_config = params.get("trap_fft_config", None) + fft_metrics = {} + if trap_fft: + from .trap_burden_variants import combine_lr + + def add_fft(vec, side, space): + if not trap_fourier.length_matches_modulus(len(np.asarray(vec).ravel()), trap_fft_config): + fft_metrics[f"trap_fft_ipr_{side}_{space}"] = np.nan + fft_metrics[f"trap_fft_q_uniform_{side}_{space}"] = np.nan + fft_metrics[f"trap_fft_q_pt_{side}_{space}"] = np.nan + fft_metrics[f"trap_fft_top_frequency_mass_{side}_{space}"] = np.nan + 
fft_metrics[f"trap_fft_peak_frequency_{side}_{space}"] = np.nan + fft_metrics[f"trap_fft_peak_mass_{side}_{space}"] = np.nan + return + sm = trap_fourier.fourier_component_summary(vec, prefix="trap", trap_fft_config=trap_fft_config) + fft_metrics[f"trap_fft_ipr_{side}_{space}"] = sm.get("trap_fft_ipr", np.nan) + fft_metrics[f"trap_fft_q_uniform_{side}_{space}"] = sm.get("trap_fft_q_uniform", np.nan) + fft_metrics[f"trap_fft_q_pt_{side}_{space}"] = sm.get("trap_fft_q_pt", np.nan) + fft_metrics[f"trap_fft_top_frequency_mass_{side}_{space}"] = sm.get("trap_fft_top_frequency_mass", np.nan) + fft_metrics[f"trap_fft_peak_frequency_{side}_{space}"] = sm.get("trap_fft_peak_frequency", np.nan) + fft_metrics[f"trap_fft_peak_mass_{side}_{space}"] = sm.get("trap_fft_peak_mass", np.nan) + + add_fft(v_perm, "right", "perm") + add_fft(u_perm, "left", "perm") + add_fft(v_trap, "right", "trap") + add_fft(u_trap, "left", "trap") + + fft_metrics["trap_fft_q_uniform_perm_lr_geom"] = combine_lr( + fft_metrics.get("trap_fft_q_uniform_left_perm", np.nan), + fft_metrics.get("trap_fft_q_uniform_right_perm", np.nan), + "geom", + ) + fft_metrics["trap_fft_q_uniform_perm_lr_min"] = combine_lr( + fft_metrics.get("trap_fft_q_uniform_left_perm", np.nan), + fft_metrics.get("trap_fft_q_uniform_right_perm", np.nan), + "min", + ) + fft_metrics["trap_fft_q_pt_perm_lr_geom"] = combine_lr( + fft_metrics.get("trap_fft_q_pt_left_perm", np.nan), + fft_metrics.get("trap_fft_q_pt_right_perm", np.nan), + "geom", + ) + fft_metrics["trap_fft_q_pt_perm_lr_min"] = combine_lr( + fft_metrics.get("trap_fft_q_pt_left_perm", np.nan), + fft_metrics.get("trap_fft_q_pt_right_perm", np.nan), + "min", + ) + fft_metrics["trap_fft_top_frequency_mass_perm_lr_geom"] = combine_lr( + fft_metrics.get("trap_fft_top_frequency_mass_left_perm", np.nan), + fft_metrics.get("trap_fft_top_frequency_mass_right_perm", np.nan), + "geom", + ) + fft_metrics["trap_fft_top_frequency_mass_perm_lr_min"] = combine_lr( + fft_metrics.get("trap_fft_top_frequency_mass_left_perm", np.nan), + fft_metrics.get("trap_fft_top_frequency_mass_right_perm", np.nan), + "min", + ) + trap_result = { + "layer_id": ww_layer.layer_id, + "name": ww_layer.name, + "longname": ww_layer.longname, + "layer_type": str(ww_layer.the_type), + "N": ww_layer.N, + "M": ww_layer.M, + "rf": ww_layer.rf, + "Q": ww_layer.N / ww_layer.M if ww_layer.M > 0 else np.nan, + "trap_index": int(trap_index), + "perm_mode_index": int(trap_mode_index), + "sigma_perm": sigma_perm, + "eval_perm": float(eval_perm), + "mp_bulk_max": float(ww_layer.bulk_max), + "mp_bulk_min": float(ww_layer.bulk_min), + "sigma_mp": float(ww_layer.sigma_mp), + "num_spikes": int(ww_layer.num_spikes), + "rank1_mass_after_unpermute": rank1_mass_after_unpermute, + "sigma_trap_top": float(St[0]), + "left_top_mode": left_top_mode, + "right_top_mode": right_top_mode, + "left_top_mass": left_top_mass, + "right_top_mass": right_top_mass, + "left_overlap_entropy": left_overlap_entropy, + "right_overlap_entropy": right_overlap_entropy, + "left_overlap_ipr": left_overlap_ipr, + "right_overlap_ipr": right_overlap_ipr, + "trap_detected": True, + "trap_seed": trap_seed, + "trap_eval_minus_bulk": float(eval_perm - ww_layer.bulk_max), + "n_traps": n_traps, + "perm_signature": perm_signature, + "permutation_n": int(len(np.asarray(perm_ids).ravel())), + "permutation_mode": "index_permutation", + "trap_identity_key": trap_identity_key, + # Paper-aligned trap metrics (NeurIPS trap paper definitions). 
+ "trap_delta": trap_delta, + "trap_ipr": trap_ipr, + "trap_q": trap_q, + # paper-facing diffuseness complement of PT-centered localization + "trap_diffuseness": float(1.0 - trap_q) if np.isfinite(trap_q) else np.nan, + # explicit legacy/uniform-centered localization retained for comparison + "trap_q_uniform": trap_q_uniform, + "trap_diffuseness_uniform": float(1.0 - trap_q_uniform) if np.isfinite(trap_q_uniform) else np.nan, + "top_sector_l": top_sector_l, + "top_sector_l_effective": top_sector_l_effective, + "trap_top_sector_overlap": trap_top_sector_overlap, + "trap_variance_burden": trap_variance_burden, + } + if trap_fft: + trap_result.update(fft_metrics) + + for k, v in u_metrics.items(): + trap_result[f"u_{k}"] = v + for k, v in v_metrics.items(): + trap_result[f"v_{k}"] = v + for k, v in u_oi.items(): + trap_result[f"u_{k}"] = v + for k, v in v_oi.items(): + trap_result[f"v_{k}"] = v + + trap_result["trap_balance_ratio"] = float( + trap_result["u_effective_support"] / (trap_result["v_effective_support"] + 1e-12) + ) + trap_result.update(watcher.assess_trap_diffuseness(trap_result)) + + variant_configs = None + components = None + if (burden_variants is not None) or return_burden_components: + from .trap_burden_variants import ( + compute_burden_components, + compute_burden_variants, + resolve_burden_variant_configs, + ) + + perm_total_variance = float(np.sum(S_perm ** 2)) + components = compute_burden_components( + eval_perm=eval_perm, + mp_bulk_max=ww_layer.bulk_max, + perm_total_variance=perm_total_variance, + u_perm=u_perm, + v_perm=v_perm, + u_trap=u_trap, + v_trap=v_trap, + left_overlaps=left_overlaps, + right_overlaps=right_overlaps, + top_sector_l=top_sector_l, + ) + if trap_fft: + components.update(fft_metrics) + variant_configs = resolve_burden_variant_configs(burden_variants, trap_fft=trap_fft) + if return_burden_components: + trap_result.update(components) + if variant_configs is not None: + trap_result.update(compute_burden_variants(components, variant_configs)) + + trap_result["left_overlaps"] = left_overlaps + trap_result["right_overlaps"] = right_overlaps + trap_result["u_trap"] = u_trap + trap_result["v_trap"] = v_trap + trap_result["u_perm"] = u_perm + trap_result["v_perm"] = v_perm + trap_result["T_orig"] = T_orig + trap_result["perm_evals_sorted"] = np.array(ww_layer.evals).copy() + + if bool(params.get(wwcore.PLOT, False)): + watcher.plot_trap_analysis(ww_layer, trap_result, params=params) + + if not return_burden_raw: + trap_result.pop("left_overlaps", None) + trap_result.pop("right_overlaps", None) + trap_result.pop("u_trap", None) + trap_result.pop("v_trap", None) + trap_result.pop("u_perm", None) + trap_result.pop("v_perm", None) + if not params.get("_keep_trap_matrix", False): + trap_result.pop("T_orig", None) + if not return_burden_raw: + trap_result.pop("perm_evals_sorted", None) + + return trap_result + + def _top_trap_component_row(row, weight_matrix, top_k=10): trap_matrix = np.asarray(row.get("T_orig", np.array([])), dtype=float) weight_matrix = np.asarray(weight_matrix, dtype=float) diff --git a/weightwatcher/trap_burden_variants.py b/weightwatcher/trap_burden_variants.py new file mode 100644 index 0000000..1bdf05c --- /dev/null +++ b/weightwatcher/trap_burden_variants.py @@ -0,0 +1,633 @@ +import copy +import math +from typing import Dict, Iterable, List, Optional + +import numpy as np +import pandas as pd + + +def safe_float(x): + try: + val = float(x) + except Exception: + return float(np.nan) + if not np.isfinite(val): + return float(np.nan) + 
return float(val) + + +def vector_ipr(vec): + v = np.asarray(vec, dtype=float).ravel() + if v.size == 0: + return float(np.nan) + norm = np.linalg.norm(v) + if (not np.isfinite(norm)) or norm <= 0.0: + return float(np.nan) + v = v / norm + return float(np.sum(np.abs(v) ** 4)) + + +def localization_uniform_centered(vec, clip=True): + ipr = vector_ipr(vec) + if not np.isfinite(ipr): + return float(np.nan), float(np.nan) + n = len(np.asarray(vec).ravel()) + if n <= 1: + q = 1.0 + else: + q = (n * ipr - 1.0) / (n - 1.0) + if clip: + q = float(np.clip(q, 0.0, 1.0)) + return float(ipr), float(q) + + +def localization_porter_thomas_centered(vec, beta="real", clip=True): + ipr = vector_ipr(vec) + if not np.isfinite(ipr): + return float(np.nan), float(np.nan) + n = len(np.asarray(vec).ravel()) + if n <= 1: + return float(ipr), 1.0 + + if beta == "complex": + expected_ipr = 2.0 / (n + 1.0) + else: + expected_ipr = 3.0 / (n + 2.0) + + denom = 1.0 - expected_ipr + if denom <= 0: + return float(ipr), float(np.nan) + + q_pt = (ipr - expected_ipr) / denom + if clip: + q_pt = float(np.clip(q_pt, 0.0, 1.0)) + return float(ipr), float(q_pt) + + +def spectral_excess(eval_perm, mp_bulk_max, perm_total_variance, mode="edge_ratio_current"): + eval_perm = safe_float(eval_perm) + mp_bulk_max = safe_float(mp_bulk_max) + total_var = safe_float(perm_total_variance) + if not np.isfinite(eval_perm): + return float(np.nan) + + raw_excess = float(np.nan) + if np.isfinite(mp_bulk_max): + raw_excess = float(max(eval_perm - mp_bulk_max, 0.0)) + + if mode in {"edge_ratio_current", "edge_ratio_linear"}: + if (not np.isfinite(mp_bulk_max)) or mp_bulk_max <= 0.0: + return float(np.nan) + return float(raw_excess / mp_bulk_max) + if mode == "total_excess": + if (not np.isfinite(total_var)) or total_var <= 0.0 or (not np.isfinite(raw_excess)): + return float(np.nan) + return float(raw_excess / total_var) + if mode == "total_fraction": + if (not np.isfinite(total_var)) or total_var <= 0.0: + return float(np.nan) + return float(eval_perm / total_var) + if mode == "raw_excess": + return float(raw_excess) + if mode == "log_edge_ratio": + if (not np.isfinite(mp_bulk_max)) or mp_bulk_max <= 0.0 or eval_perm <= 0.0: + return float(np.nan) + return float(max(math.log(eval_perm / mp_bulk_max), 0.0)) + raise ValueError(f"Unknown spectral mode: {mode}") + + +def combine_lr(left, right, method): + l = safe_float(left) + r = safe_float(right) + if method == "left": + return l + if method == "right": + return r + if (not np.isfinite(l)) or (not np.isfinite(r)): + return float(np.nan) + if method == "mean": + return float(0.5 * (l + r)) + if method == "geom": + return float(np.sqrt(max(l, 0.0) * max(r, 0.0))) + if method == "min": + return float(min(l, r)) + if method == "max": + return float(max(l, r)) + if method == "product": + return float(l * r) + raise ValueError(f"Unknown left/right combine method: {method}") + + +def compute_top_sector_overlap_pair(left_overlaps, right_overlaps, top_sector_l): + ell = int(top_sector_l) + if ell < 1: + raise ValueError("top_sector_l must be >= 1") + left = np.asarray(left_overlaps, dtype=float).ravel() + right = np.asarray(right_overlaps, dtype=float).ravel() + + if left.size == 0: + left_sum = float(np.nan) + ell_left = 0 + else: + ell_left = min(ell, int(left.size)) + left_sum = float(np.sum(left[:ell_left])) + + if right.size == 0: + right_sum = float(np.nan) + ell_right = 0 + else: + ell_right = min(ell, int(right.size)) + right_sum = float(np.sum(right[:ell_right])) + + return left_sum, right_sum, 
int(ell_left), int(ell_right) + + +def compute_burden_components( + eval_perm, + mp_bulk_max, + perm_total_variance, + u_perm, + v_perm, + u_trap, + v_trap, + left_overlaps, + right_overlaps, + top_sector_l=1, + beta="real", +): + out = {} + out["trap_perm_total_variance"] = safe_float(perm_total_variance) + out["trap_spectral_edge_ratio_current"] = spectral_excess( + eval_perm, mp_bulk_max, perm_total_variance, mode="edge_ratio_current" + ) + out["trap_spectral_total_excess"] = spectral_excess( + eval_perm, mp_bulk_max, perm_total_variance, mode="total_excess" + ) + out["trap_spectral_total_fraction"] = spectral_excess( + eval_perm, mp_bulk_max, perm_total_variance, mode="total_fraction" + ) + out["trap_spectral_raw_excess"] = spectral_excess( + eval_perm, mp_bulk_max, perm_total_variance, mode="raw_excess" + ) + out["trap_spectral_log_edge_ratio"] = spectral_excess( + eval_perm, mp_bulk_max, perm_total_variance, mode="log_edge_ratio" + ) + + ipr_l_perm, q_l_perm = localization_uniform_centered(u_perm, clip=True) + ipr_r_perm, q_r_perm = localization_uniform_centered(v_perm, clip=True) + ipr_l_trap, q_l_trap = localization_uniform_centered(u_trap, clip=True) + ipr_r_trap, q_r_trap = localization_uniform_centered(v_trap, clip=True) + + out["trap_ipr_left_perm"] = ipr_l_perm + out["trap_ipr_right_perm"] = ipr_r_perm + out["trap_ipr_left_trap"] = ipr_l_trap + out["trap_ipr_right_trap"] = ipr_r_trap + + out["trap_q_uniform_left_perm"] = q_l_perm + out["trap_q_uniform_right_perm"] = q_r_perm + out["trap_q_uniform_left_trap"] = q_l_trap + out["trap_q_uniform_right_trap"] = q_r_trap + + _, q_pt_l_perm = localization_porter_thomas_centered(u_perm, beta=beta, clip=True) + _, q_pt_r_perm = localization_porter_thomas_centered(v_perm, beta=beta, clip=True) + _, q_pt_l_trap = localization_porter_thomas_centered(u_trap, beta=beta, clip=True) + _, q_pt_r_trap = localization_porter_thomas_centered(v_trap, beta=beta, clip=True) + + out["trap_q_pt_left_perm"] = q_pt_l_perm + out["trap_q_pt_right_perm"] = q_pt_r_perm + out["trap_q_pt_left_trap"] = q_pt_l_trap + out["trap_q_pt_right_trap"] = q_pt_r_trap + + out["trap_q_pt_perm_lr_geom"] = combine_lr(q_pt_l_perm, q_pt_r_perm, "geom") + out["trap_q_pt_perm_lr_min"] = combine_lr(q_pt_l_perm, q_pt_r_perm, "min") + out["trap_q_pt_trap_lr_geom"] = combine_lr(q_pt_l_trap, q_pt_r_trap, "geom") + out["trap_q_pt_trap_lr_min"] = combine_lr(q_pt_l_trap, q_pt_r_trap, "min") + + left_ov, right_ov, ell_l, ell_r = compute_top_sector_overlap_pair( + left_overlaps, right_overlaps, top_sector_l=top_sector_l + ) + out["top_sector_l_effective_left"] = ell_l + out["top_sector_l_effective_right"] = ell_r + out["trap_top_sector_overlap_left"] = left_ov + out["trap_top_sector_overlap_right"] = right_ov + out["trap_top_sector_overlap_lr_geom"] = combine_lr(left_ov, right_ov, "geom") + out["trap_top_sector_overlap_lr_min"] = combine_lr(left_ov, right_ov, "min") + out["trap_top_sector_overlap_lr_mean"] = combine_lr(left_ov, right_ov, "mean") + out["trap_top_sector_overlap_lr_max"] = combine_lr(left_ov, right_ov, "max") + out["trap_top_sector_overlap_lr_product"] = combine_lr(left_ov, right_ov, "product") + return out + + +def _localization_key(family: str, vectors: str, side: str): + fam = "pt" if family == "porter_thomas" else "uniform" + if side in {"left", "right"}: + return f"trap_q_{fam}_{side}_{vectors}" + if side in {"mean", "geom", "min", "max", "product"}: + return None + raise ValueError(f"Unknown localization side: {side}") + + +def _get_localization_value(components, 
family="uniform", vectors="perm", side="right", domain="standard"): + if domain == "fft": + fam = "q_pt" if family == "porter_thomas" else "q_uniform" + if side in {"left", "right"}: + return safe_float(components.get(f"trap_fft_{fam}_{side}_{vectors}", np.nan)) + left = safe_float(components.get(f"trap_fft_{fam}_left_{vectors}", np.nan)) + right = safe_float(components.get(f"trap_fft_{fam}_right_{vectors}", np.nan)) + return combine_lr(left, right, side) + + key = _localization_key(family, vectors, side) + if key is not None: + return safe_float(components.get(key, np.nan)) + left = safe_float(components.get(f"trap_q_{'pt' if family == 'porter_thomas' else 'uniform'}_left_{vectors}", np.nan)) + right = safe_float(components.get(f"trap_q_{'pt' if family == 'porter_thomas' else 'uniform'}_right_{vectors}", np.nan)) + return combine_lr(left, right, side) + + +def _get_overlap_value(components, side="right", domain="standard", fft_overlap_measure="top_frequency_mass", vectors="perm"): + if domain == "fft": + base = "trap_fft_top_frequency_mass" if fft_overlap_measure == "top_frequency_mass" else "trap_fft_selected_frequency_mass" + if side in {"left", "right"}: + return safe_float(components.get(f"{base}_{side}_{vectors}", np.nan)) + left = safe_float(components.get(f"{base}_left_{vectors}", np.nan)) + right = safe_float(components.get(f"{base}_right_{vectors}", np.nan)) + return combine_lr(left, right, side) + if side in {"left", "right"}: + return safe_float(components.get(f"trap_top_sector_overlap_{side}", np.nan)) + if side in {"mean", "geom", "min", "max", "product"}: + return safe_float(components.get(f"trap_top_sector_overlap_lr_{side}", np.nan)) + raise ValueError(f"Unknown overlap side: {side}") + + +def compute_burden_variant(components, config): + spectral_mode = config.get("spectral_mode", "edge_ratio_current") + spectral_power = float(config.get("spectral_power", 1.0)) + localization_family = config.get("localization_family", "uniform") + localization_vectors = config.get("localization_vectors", "perm") + localization_side = config.get("localization_side", "right") + localization_power = float(config.get("localization_power", 1.0)) + overlap_side = config.get("overlap_side", "right") + overlap_power = float(config.get("overlap_power", 1.0)) + localization_domain = config.get("localization_domain", "standard") + overlap_domain = config.get("overlap_domain", "standard") + fft_localization_family = config.get("fft_localization_family", "uniform") + fft_overlap_measure = config.get("fft_overlap_measure", "top_frequency_mass") + overlap_vectors = config.get("overlap_vectors", localization_vectors) + + spectral_key = { + "edge_ratio_current": "trap_spectral_edge_ratio_current", + "edge_ratio_linear": "trap_spectral_edge_ratio_current", + "total_excess": "trap_spectral_total_excess", + "total_fraction": "trap_spectral_total_fraction", + "raw_excess": "trap_spectral_raw_excess", + "log_edge_ratio": "trap_spectral_log_edge_ratio", + }.get(spectral_mode) + if spectral_key is None: + raise ValueError(f"Unknown spectral mode: {spectral_mode}") + spectral_value = safe_float(components.get(spectral_key, np.nan)) + localization_value = _get_localization_value( + components, + family=(fft_localization_family if localization_domain == "fft" else localization_family), + vectors=localization_vectors, + side=localization_side, + domain=localization_domain, + ) + overlap_value = _get_overlap_value( + components, + side=overlap_side, + domain=overlap_domain, + fft_overlap_measure=fft_overlap_measure, 
+        vectors=overlap_vectors,
+    )
+
+    # A factor raised to power 0.0 is identically 1, so only factors that are
+    # actually used must be finite (e.g. "*_no_overlap" variants set overlap_power=0).
+    factors = [
+        (spectral_value, spectral_power),
+        (localization_value, localization_power),
+        (overlap_value, overlap_power),
+    ]
+    if any(not np.isfinite(v) for v, p in factors if p != 0.0):
+        return float(np.nan)
+    return float((spectral_value ** spectral_power) * (localization_value ** localization_power) * (overlap_value ** overlap_power))
+
+
+DEFAULT_BURDEN_VARIANTS: List[Dict] = [
+    dict(
+        name="current_pr358",
+        spectral_mode="edge_ratio_current",
+        spectral_power=2,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="right",
+        localization_power=1,
+        overlap_side="right",
+        overlap_power=2,
+    ),
+    dict(
+        name="edge_linear_uniform_right",
+        spectral_mode="edge_ratio_current",
+        spectral_power=1,
+        localization_family="uniform",
+        localization_vectors="perm",
+        localization_side="right",
+        localization_power=1,
+        overlap_side="right",
+        overlap_power=2,
+    ),
+    # NOTE: configured identically to "current_pr358" above, kept under a descriptive name
+    dict(
+        name="edge_squared_pt_right_perm",
+        spectral_mode="edge_ratio_current",
+        spectral_power=2,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="right",
+        localization_power=1,
+        overlap_side="right",
+        overlap_power=2,
+    ),
+    dict(
+        name="total_excess_pt_right_perm",
+        spectral_mode="total_excess",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="right",
+        localization_power=1,
+        overlap_side="right",
+        overlap_power=2,
+    ),
+    dict(
+        name="total_excess_pt_lr_geom_perm_overlap_lr_geom",
+        spectral_mode="total_excess",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="geom",
+        localization_power=1,
+        overlap_side="geom",
+        overlap_power=1,
+    ),
+    dict(
+        name="total_excess_pt_lr_min_perm_overlap_lr_min",
+        spectral_mode="total_excess",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="min",
+        localization_power=1,
+        overlap_side="min",
+        overlap_power=1,
+    ),
+    dict(
+        name="total_excess_pt_lr_geom_trap_overlap_lr_geom",
+        spectral_mode="total_excess",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="trap",
+        localization_side="geom",
+        localization_power=1,
+        overlap_side="geom",
+        overlap_power=1,
+    ),
+    dict(
+        name="total_fraction_pt_lr_geom_perm_overlap_lr_geom",
+        spectral_mode="total_fraction",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="geom",
+        localization_power=1,
+        overlap_side="geom",
+        overlap_power=1,
+    ),
+    dict(
+        name="log_edge_pt_lr_geom_perm_overlap_lr_geom",
+        spectral_mode="log_edge_ratio",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="geom",
+        localization_power=1,
+        overlap_side="geom",
+        overlap_power=1,
+    ),
+    dict(
+        name="total_excess_pt_lr_geom_no_overlap",
+        spectral_mode="total_excess",
+        spectral_power=1,
+        localization_family="porter_thomas",
+        localization_vectors="perm",
+        localization_side="geom",
+        localization_power=1,
+        overlap_side="right",
+        overlap_power=0,
+    ),
+]
+
+FFT_DEFAULT_BURDEN_VARIANTS: List[Dict] = [
+    dict(
+        name="fft_uniform_right_current_spectral",
+        spectral_mode="edge_ratio_current",
+        spectral_power=2,
+        localization_domain="fft",
+        fft_localization_family="uniform",
+        localization_vectors="perm",
+        localization_side="right",
+        overlap_domain="standard",
+        overlap_side="right",
+        overlap_power=2,
+    ),
+    dict(
+        name="fft_uniform_lr_geom_current_spectral",
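+        # as the previous variant, but combining left/right localization geometrically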
spectral_mode="edge_ratio_current", + spectral_power=2, + localization_domain="fft", + fft_localization_family="uniform", + localization_vectors="perm", + localization_side="geom", + overlap_domain="standard", + overlap_side="right", + overlap_power=2, + ), + dict( + name="fft_uniform_lr_geom_fft_topmass", + spectral_mode="edge_ratio_current", + spectral_power=2, + localization_domain="fft", + fft_localization_family="uniform", + localization_vectors="perm", + localization_side="geom", + overlap_domain="fft", + fft_overlap_measure="top_frequency_mass", + overlap_vectors="perm", + overlap_side="geom", + overlap_power=1, + ), + dict( + name="fft_pt_lr_geom_fft_topmass", + spectral_mode="edge_ratio_current", + spectral_power=2, + localization_domain="fft", + fft_localization_family="porter_thomas", + localization_vectors="perm", + localization_side="geom", + overlap_domain="fft", + fft_overlap_measure="top_frequency_mass", + overlap_vectors="perm", + overlap_side="geom", + overlap_power=1, + ), + dict( + name="fft_uniform_total_fraction_lr_geom_fft_topmass", + spectral_mode="total_fraction", + spectral_power=1, + localization_domain="fft", + fft_localization_family="uniform", + localization_vectors="perm", + localization_side="geom", + overlap_domain="fft", + fft_overlap_measure="top_frequency_mass", + overlap_vectors="perm", + overlap_side="geom", + overlap_power=1, + ), + dict( + name="fft_pt_total_fraction_lr_geom_fft_topmass", + spectral_mode="total_fraction", + spectral_power=1, + localization_domain="fft", + fft_localization_family="porter_thomas", + localization_vectors="perm", + localization_side="geom", + overlap_domain="fft", + fft_overlap_measure="top_frequency_mass", + overlap_vectors="perm", + overlap_side="geom", + overlap_power=1, + ), +] + + +def resolve_burden_variant_configs(burden_variants, trap_fft=False): + if burden_variants is None: + return None + if burden_variants == "default": + out = copy.deepcopy(DEFAULT_BURDEN_VARIANTS) + if bool(trap_fft): + out.extend(copy.deepcopy(FFT_DEFAULT_BURDEN_VARIANTS)) + return out + if isinstance(burden_variants, list): + return copy.deepcopy(burden_variants) + raise ValueError("burden_variants must be None, 'default', or list[dict]") + + +def compute_burden_variants(components, variant_configs): + out = {} + if variant_configs is None: + return out + for cfg in variant_configs: + name = cfg.get("name") + if not name: + raise ValueError("Each burden variant config must include a non-empty 'name'") + out[f"trap_variance_burden__{name}"] = compute_burden_variant(components, cfg) + return out + + +def _confusion_from_binary(y_true, y_pred): + y_true = np.asarray(y_true, dtype=int) + y_pred = np.asarray(y_pred, dtype=int) + tp = int(np.sum((y_true == 1) & (y_pred == 1))) + fp = int(np.sum((y_true == 0) & (y_pred == 1))) + tn = int(np.sum((y_true == 0) & (y_pred == 0))) + fn = int(np.sum((y_true == 1) & (y_pred == 0))) + return tp, fp, tn, fn + + +def evaluate_burden_variants( + trap_df, + label_col, + variant_cols=None, + positive_label=1, + thresholds=None, + top_k=None, + group_col=None, +): + df = trap_df.copy() + if variant_cols is None: + variant_cols = [c for c in df.columns if c.startswith("trap_variance_burden__")] + if thresholds is None: + thresholds = [0.0] + if top_k is None: + top_k = [] + + y_true = (df[label_col] == positive_label).astype(int).to_numpy() + rows = [] + for col in variant_cols: + scores = pd.to_numeric(df[col], errors="coerce").fillna(0.0).to_numpy() + + auroc = np.nan + auprc = np.nan + try: + from 
+            from sklearn.metrics import average_precision_score, roc_auc_score
+            if len(np.unique(y_true)) > 1:
+                auroc = float(roc_auc_score(y_true, scores))
+                auprc = float(average_precision_score(y_true, scores))
+        except Exception:
+            pass
+
+        for th in thresholds:
+            y_pred = (scores >= float(th)).astype(int)
+            tp, fp, tn, fn = _confusion_from_binary(y_true, y_pred)
+            n = int(len(y_true))
+            precision = float(tp / (tp + fp)) if (tp + fp) > 0 else np.nan
+            recall = float(tp / (tp + fn)) if (tp + fn) > 0 else np.nan
+            specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else np.nan
+            fpr = float(fp / (fp + tn)) if (fp + tn) > 0 else np.nan
+            rows.append(
+                dict(
+                    variant=col,
+                    threshold=float(th),
+                    top_k=np.nan,
+                    n=n,
+                    tp=tp,
+                    fp=fp,
+                    tn=tn,
+                    fn=fn,
+                    false_positive_rate=fpr,
+                    precision=precision,
+                    recall=recall,
+                    specificity=specificity,
+                    false_positives=fp,
+                    auroc=auroc,
+                    auprc=auprc,
+                )
+            )
+
+        ranked = np.argsort(-scores)
+        for k in top_k:
+            kk = int(min(max(int(k), 0), len(scores)))
+            y_pred = np.zeros_like(y_true)
+            if kk > 0:
+                y_pred[ranked[:kk]] = 1
+            tp, fp, tn, fn = _confusion_from_binary(y_true, y_pred)
+            n = int(len(y_true))
+            precision = float(tp / (tp + fp)) if (tp + fp) > 0 else np.nan
+            recall = float(tp / (tp + fn)) if (tp + fn) > 0 else np.nan
+            specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else np.nan
+            fpr = float(fp / (fp + tn)) if (fp + tn) > 0 else np.nan
+            rows.append(
+                dict(
+                    variant=col,
+                    threshold=np.nan,
+                    top_k=kk,
+                    n=n,
+                    tp=tp,
+                    fp=fp,
+                    tn=tn,
+                    fn=fn,
+                    false_positive_rate=fpr,
+                    precision=precision,
+                    recall=recall,
+                    specificity=specificity,
+                    false_positives=fp,
+                    auroc=auroc,
+                    auprc=auprc,
+                )
+            )
+
+    return pd.DataFrame(rows)
diff --git a/weightwatcher/trap_fourier.py b/weightwatcher/trap_fourier.py
new file mode 100644
index 0000000..58f7343
--- /dev/null
+++ b/weightwatcher/trap_fourier.py
@@ -0,0 +1,248 @@
+import numpy as np
+
+DEFAULT_TRAP_FFT_CONFIG = {
+    "sides": "both",
+    "vectors": "both",
+    "fold_conjugates": True,
+    "exclude_dc": False,
+    "top_frequency_l": 1,
+    "selected_frequencies": None,
+    "normalization": "ortho",
+    "baseline": "uniform",  # uniform | pt_real_mc | pt_complex
+    "mc_samples": 2048,
+    "mc_seed": 123,
+    "modulus": None,
+    "apply_only_if_length_matches_modulus": False,
+    "layer_fft_map": None,
+}
+
+_PT_CACHE = {}
+
+
+def resolve_trap_fft_config(cfg=None):
+    out = dict(DEFAULT_TRAP_FFT_CONFIG)
+    if cfg:
+        out.update(dict(cfg))
+    out["top_frequency_l"] = int(out.get("top_frequency_l", 1))
+    if out["top_frequency_l"] < 1:
+        raise ValueError("trap_fft_config['top_frequency_l'] must be >= 1")
+    return out
+
+
+def normalize_vector(vec):
+    v = np.asarray(vec).ravel()
+    if v.size == 0:
+        return None
+    if not np.iscomplexobj(v):
+        v = v.astype(float)
+    norm = np.linalg.norm(v)
+    if (not np.isfinite(norm)) or norm <= 0.0:
+        return None
+    v = v / norm
+    if not np.all(np.isfinite(np.real(v))) or not np.all(np.isfinite(np.imag(v))):
+        return None
+    return v
+
+
+def _fold_conjugates(mass):
+    # combine the +k and -k Fourier bins of a real signal; bin 0 (DC) and,
+    # for even n, the Nyquist bin have no conjugate partner
+    n = len(mass)
+    if n == 0:
+        return np.array([], dtype=float)
+    out = [float(mass[0])]
+    for k in range(1, (n - 1) // 2 + 1):
+        out.append(float(mass[k] + mass[n - k]))
+    if n % 2 == 0:
+        out.append(float(mass[n // 2]))
+    out = np.asarray(out, dtype=float)
+    s = np.sum(out)
+    if np.isfinite(s) and s > 0:
+        out = out / s
+    return out
+
+
+def fourier_mass(vec, fold_conjugates=True, exclude_dc=False, normalization="ortho"):
+    v = normalize_vector(vec)
+    if v is None:
+        return {
+            "fft_values": None,
+            "mass": np.array([], dtype=float),
"folded_mass": None, + "effective_mass": np.array([], dtype=float), + "n": 0, + "n_effective": 0, + } + + fv = np.fft.fft(v, norm=normalization) + mass = np.abs(fv) ** 2 + s = np.sum(mass) + if not np.isfinite(s) or s <= 0: + return { + "fft_values": fv, + "mass": np.array([], dtype=float), + "folded_mass": None, + "effective_mass": np.array([], dtype=float), + "n": len(v), + "n_effective": 0, + } + mass = mass / s + folded = _fold_conjugates(mass) if fold_conjugates else None + effective = folded if fold_conjugates else mass.copy() + if exclude_dc and len(effective) > 1: + effective = effective.copy() + effective[0] = 0.0 + e = np.sum(effective) + if np.isfinite(e) and e > 0: + effective = effective / e + + return { + "fft_values": fv, + "mass": np.asarray(mass, dtype=float), + "folded_mass": None if folded is None else np.asarray(folded, dtype=float), + "effective_mass": np.asarray(effective, dtype=float), + "n": int(len(v)), + "n_effective": int(len(effective)), + } + + +def fourier_ipr(vec, **kwargs): + fm = fourier_mass(vec, **kwargs) + p = fm["effective_mass"] + if p.size == 0: + return float(np.nan) + return float(np.sum(p ** 2)) + + +def fourier_uniform_centered_q(vec, **kwargs): + fm = fourier_mass(vec, **kwargs) + p = fm["effective_mass"] + n_eff = fm["n_effective"] + if p.size == 0: + return float(np.nan), float(np.nan) + ipr = float(np.sum(p ** 2)) + if n_eff <= 1: + return ipr, 1.0 + q = (n_eff * ipr - 1.0) / (n_eff - 1.0) + return ipr, float(np.clip(q, 0.0, 1.0)) + + +def _expected_ipr_real_mc(n, fold_conjugates, exclude_dc, normalization, mc_samples, mc_seed): + key = (int(n), bool(fold_conjugates), bool(exclude_dc), str(normalization), int(mc_samples), int(mc_seed)) + if key in _PT_CACHE: + return _PT_CACHE[key] + rng = np.random.RandomState(int(mc_seed)) + vals = [] + for _ in range(int(mc_samples)): + v = rng.normal(size=int(n)) + vals.append(fourier_ipr(v, fold_conjugates=fold_conjugates, exclude_dc=exclude_dc, normalization=normalization)) + expected = float(np.nanmean(vals)) if len(vals) else float(np.nan) + _PT_CACHE[key] = expected + return expected + + +def fourier_pt_centered_q(vec, baseline="pt_real_mc", mc_samples=2048, mc_seed=123, **kwargs): + fm = fourier_mass(vec, **kwargs) + p = fm["effective_mass"] + n_eff = fm["n_effective"] + if p.size == 0: + return float(np.nan), float(np.nan) + ipr = float(np.sum(p ** 2)) + if n_eff <= 1: + return ipr, 1.0 + + if baseline == "uniform": + expected = 1.0 / n_eff + elif baseline == "pt_complex": + expected = 2.0 / (n_eff + 1.0) + elif baseline == "pt_real_mc": + expected = _expected_ipr_real_mc( + fm["n"], + kwargs.get("fold_conjugates", True), + kwargs.get("exclude_dc", False), + kwargs.get("normalization", "ortho"), + mc_samples, + mc_seed, + ) + else: + raise ValueError(f"Unknown baseline: {baseline}") + + if (not np.isfinite(expected)) or expected >= 1.0: + return ipr, float(np.nan) + q = (ipr - expected) / (1.0 - expected) + return ipr, float(np.clip(q, 0.0, 1.0)) + + +def fourier_top_frequency_mass(vec, top_frequency_l=1, **kwargs): + fm = fourier_mass(vec, **kwargs) + p = fm["effective_mass"] + if p.size == 0: + return float(np.nan), [], 0 + l = int(top_frequency_l) + if l < 1: + raise ValueError("top_frequency_l must be >= 1") + l_eff = min(l, len(p)) + order = np.argsort(p)[::-1] + idx = order[:l_eff] + return float(np.sum(p[idx])), [int(i) for i in idx], int(l_eff) + + +def fourier_selected_frequency_mass(vec, selected_frequencies=None, **kwargs): + if selected_frequencies is None: + return float(np.nan) + fm = 
+    fm = fourier_mass(vec, **kwargs)
+    p = fm["effective_mass"]
+    if p.size == 0:
+        return float(np.nan)
+    idx = [int(i) for i in selected_frequencies if 0 <= int(i) < len(p)]
+    return float(np.sum(p[idx])) if len(idx) else 0.0
+
+
+def fourier_component_summary(vec, prefix, trap_fft_config):
+    cfg = resolve_trap_fft_config(trap_fft_config)
+    shared = dict(
+        fold_conjugates=bool(cfg.get("fold_conjugates", True)),
+        exclude_dc=bool(cfg.get("exclude_dc", False)),
+        normalization=cfg.get("normalization", "ortho"),
+    )
+    ipr, q_uniform = fourier_uniform_centered_q(vec, **shared)
+    _, q_pt = fourier_pt_centered_q(
+        vec,
+        baseline=cfg.get("baseline", "uniform"),
+        mc_samples=int(cfg.get("mc_samples", 2048)),
+        mc_seed=int(cfg.get("mc_seed", 123)),
+        **shared,
+    )
+    top_mass, top_idx, l_eff = fourier_top_frequency_mass(
+        vec,
+        top_frequency_l=int(cfg.get("top_frequency_l", 1)),
+        **shared,
+    )
+    selected_mass = fourier_selected_frequency_mass(
+        vec,
+        selected_frequencies=cfg.get("selected_frequencies", None),
+        **shared,
+    )
+
+    # peak_* fields are only well-defined in the single-peak case (top_frequency_l == 1)
+    peak_freq = top_idx[0] if len(top_idx) else np.nan
+    peak_mass = top_mass if l_eff == 1 else np.nan
+
+    return {
+        f"{prefix}_fft_ipr": ipr,
+        f"{prefix}_fft_q_uniform": q_uniform,
+        f"{prefix}_fft_q_pt": q_pt,
+        f"{prefix}_fft_top_frequency_mass": top_mass,
+        f"{prefix}_fft_peak_frequency": peak_freq,
+        f"{prefix}_fft_peak_mass": peak_mass,
+        f"{prefix}_fft_selected_frequency_mass": selected_mass,
+    }
+
+
+def length_matches_modulus(vec_len, trap_fft_config):
+    cfg = resolve_trap_fft_config(trap_fft_config)
+    if not bool(cfg.get("apply_only_if_length_matches_modulus", False)):
+        return True
+    modulus = cfg.get("modulus", None)
+    if modulus is None:
+        return False
+    m = int(modulus)
+    # accept exact modulus length, or modulus + 1
+    allowed = {m, m + 1}
+    return int(vec_len) in allowed
diff --git a/weightwatcher/trap_identity.py b/weightwatcher/trap_identity.py
new file mode 100644
index 0000000..3c35047
--- /dev/null
+++ b/weightwatcher/trap_identity.py
@@ -0,0 +1,115 @@
+import hashlib
+import numpy as np
+import pandas as pd
+
+
+def permutation_signature(indices):
+    arr = np.asarray(indices, dtype=np.int64).ravel()
+    arr = np.ascontiguousarray(arr)
+    return hashlib.sha256(arr.tobytes()).hexdigest()
+
+
+def make_trap_identity_key(layer_id, seed, trap_index, n_traps, perm_signature):
+    seed_str = "none" if seed is None else str(seed)
+    perm_short = (perm_signature or "")[:16]
+    return f"layer={layer_id}|seed={seed_str}|trap_index={trap_index}|n_traps={n_traps}|perm={perm_short}"
+
+
+def abs_cosine(vec_a, vec_b):
+    if vec_a is None or vec_b is None:
+        return float(np.nan)
+    a = np.asarray(vec_a, dtype=float).ravel()
+    b = np.asarray(vec_b, dtype=float).ravel()
+    if a.size == 0 or b.size == 0:
+        return float(np.nan)
+    # defensively compare over the common prefix if the lengths differ
+    n = min(a.size, b.size)
+    a = a[:n]
+    b = b[:n]
+    na = np.linalg.norm(a)
+    nb = np.linalg.norm(b)
+    if (not np.isfinite(na)) or (not np.isfinite(nb)) or na <= 0 or nb <= 0:
+        return float(np.nan)
+    return float(abs(np.dot(a, b) / (na * nb)))
+
+
+def compare_numeric(a, b, rtol=1e-4, atol=1e-6):
+    try:
+        fa = float(a)
+        fb = float(b)
+    except Exception:
+        return False
+    if not np.isfinite(fa) or not np.isfinite(fb):
+        return False
+    return bool(np.isclose(fa, fb, rtol=rtol, atol=atol))
+
+
+def verify_trap_rows(analyze_row, remove_row, rtol=1e-4, atol=1e-6, min_vector_cosine=0.999):
+    perm_match = str(analyze_row.get("perm_signature", "")) == str(remove_row.get("perm_signature", ""))
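+    # scalar metrics must agree within (rtol, atol); vectors via |cos| >= min_vector_cosine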
+    eval_close = compare_numeric(analyze_row.get("eval_perm", np.nan), remove_row.get("eval_perm", np.nan), rtol=rtol, atol=atol)
+    bulk_close = compare_numeric(analyze_row.get("mp_bulk_max", np.nan), remove_row.get("mp_bulk_max", np.nan), rtol=rtol, atol=atol)
+    delta_close = compare_numeric(analyze_row.get("trap_delta", np.nan), remove_row.get("trap_delta", np.nan), rtol=rtol, atol=atol)
+    q_close = compare_numeric(analyze_row.get("trap_q", np.nan), remove_row.get("trap_q", np.nan), rtol=rtol, atol=atol)
+    overlap_close = compare_numeric(
+        analyze_row.get("trap_top_sector_overlap", np.nan),
+        remove_row.get("trap_top_sector_overlap", np.nan),
+        rtol=rtol,
+        atol=atol,
+    )
+
+    v_cos = abs_cosine(analyze_row.get("v_trap", None), remove_row.get("v_trap", None))
+    # if the vectors are unavailable (v_cos is NaN), the vector check is skipped
+    vec_ok = True if not np.isfinite(v_cos) else (v_cos >= float(min_vector_cosine))
+
+    return {
+        "perm_match": perm_match,
+        "eval_perm_close": eval_close,
+        "mp_bulk_max_close": bulk_close,
+        "trap_delta_close": delta_close,
+        "trap_q_close": q_close,
+        "trap_top_sector_overlap_close": overlap_close,
+        "v_abs_cosine": v_cos,
+        "vector_close": vec_ok,
+        "trap_verified": bool(perm_match and eval_close and bulk_close and delta_close and q_close and overlap_close and vec_ok),
+    }
+
+
+def build_trap_verification_row(analyze_row, remove_row, verify_dict, removed=False, removal_error=None):
+    row = {
+        "layer_id": analyze_row.get("layer_id", remove_row.get("layer_id", np.nan)),
+        "trap_index": analyze_row.get("trap_index", remove_row.get("trap_index", np.nan)),
+        "trap_seed": analyze_row.get("trap_seed", remove_row.get("trap_seed", np.nan)),
+        "n_traps_analyze": analyze_row.get("n_traps", np.nan),
+        "n_traps_remove": remove_row.get("n_traps", np.nan),
+        "perm_signature_analyze": analyze_row.get("perm_signature", None),
+        "perm_signature_remove": remove_row.get("perm_signature", None),
+        "perm_match": verify_dict.get("perm_match", False),
+        "eval_perm_analyze": analyze_row.get("eval_perm", np.nan),
+        "eval_perm_remove": remove_row.get("eval_perm", np.nan),
+        "eval_perm_close": verify_dict.get("eval_perm_close", False),
+        "mp_bulk_max_analyze": analyze_row.get("mp_bulk_max", np.nan),
+        "mp_bulk_max_remove": remove_row.get("mp_bulk_max", np.nan),
+        "mp_bulk_max_close": verify_dict.get("mp_bulk_max_close", False),
+        "trap_delta_analyze": analyze_row.get("trap_delta", np.nan),
+        "trap_delta_remove": remove_row.get("trap_delta", np.nan),
+        "trap_delta_close": verify_dict.get("trap_delta_close", False),
+        "trap_q_analyze": analyze_row.get("trap_q", np.nan),
+        "trap_q_remove": remove_row.get("trap_q", np.nan),
+        "trap_q_close": verify_dict.get("trap_q_close", False),
+        "trap_top_sector_overlap_analyze": analyze_row.get("trap_top_sector_overlap", np.nan),
+        "trap_top_sector_overlap_remove": remove_row.get("trap_top_sector_overlap", np.nan),
+        "trap_top_sector_overlap_close": verify_dict.get("trap_top_sector_overlap_close", False),
+        "v_abs_cosine": verify_dict.get("v_abs_cosine", np.nan),
+        "trap_verified": verify_dict.get("trap_verified", False),
+        "removed": bool(removed),
+        "removal_error": removal_error,
+    }
+    return row
+
+
+def coerce_traps_dataframe(traps):
+    if traps is None:
+        return None
+    if isinstance(traps, pd.DataFrame):
+        return traps.copy()
+    if isinstance(traps, pd.Series):
+        return pd.DataFrame([traps.to_dict()])
+    return pd.DataFrame(traps)
diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py
index 0da873b..32a0ab0 100644
--- a/weightwatcher/weightwatcher.py
+++ b/weightwatcher/weightwatcher.py
@@ -2915,7 +2915,7 @@ def apply_FFT(self, ww_layer, params=None):
         layer_id = ww_layer.layer_id
         name = ww_layer.name
 
-        if not ww_layer.skippe:
+        if not ww_layer.skipped:
             logger.info("applying 2D FFT on to {} {} ".format(layer_id, name))
 
             Wmats = ww_layer.Wmats
@@ -3701,11 +3701,17 @@ def analyze_traps(self, model=None, layers=[],
                      conv2d_norm=True, ww2x=DEFAULT_WW2X, pool=DEFAULT_POOL,
                      conv2d_fft=False, fft=False, channels=None,
+                     trap_fft=False, trap_fft_config=None,
                      svd_method=FAST_SVD, start_ids=DEFAULT_START_ID,
                      base_model=None, peft=DEFAULT_PEFT,
-                     rng=None):
+                     seed=None,
+                     rng=None,
+                     top_sector_l=1,
+                     burden_variants=None,
+                     return_burden_components=False,
+                     return_burden_raw=False):
         """Analyze randomized correlation traps and return one row per trap.
 
         This method follows the randomized/permuted trap workflow:
@@ -3744,12 +3750,19 @@ def analyze_traps(self, model=None, layers=[],
             pool=pool,
             conv2d_fft=conv2d_fft,
             fft=fft,
+            trap_fft=trap_fft,
+            trap_fft_config=trap_fft_config,
             channels=channels,
             svd_method=svd_method,
             start_ids=start_ids,
             base_model=base_model,
             peft=peft,
+            seed=seed,
             rng=rng,
+            top_sector_l=top_sector_l,
+            burden_variants=burden_variants,
+            return_burden_components=return_burden_components,
+            return_burden_raw=return_burden_raw,
         )
 
     def _trap_result_columns(self):
@@ -3767,6 +3780,11 @@ def _trap_result_columns(self):
             "v_l2_fourth_moment", "v_l2_sixth_moment", "v_effective_support",
             "v_gini_abs", "v_top1_mass", "v_top5_mass", "v_top10_mass",
             "v_squared_amp_entropy", "v_stable_rank_surrogate",
             "trap_balance_ratio", "trap_detected", "trap_eval_minus_bulk",
+            "trap_seed", "n_traps", "perm_signature", "permutation_n", "permutation_mode", "trap_identity_key",
+            "trap_delta", "trap_ipr", "trap_q", "trap_diffuseness",
+            "trap_q_uniform", "trap_diffuseness_uniform",
+            "top_sector_l", "top_sector_l_effective", "trap_top_sector_overlap",
+            "trap_variance_burden", "layer_trap_variance_burden",
             "trap_diffuseness_score", "trap_risk_score", "trap_assessment",
         ]
@@ -3785,6 +3803,7 @@ def apply_analyze_traps(self, ww_layer, params=None):
         self.apply_permute_W(ww_layer, params)
         self.apply_trap_mp_fit(ww_layer, params=params)
         trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params)
+        params["_layer_n_traps"] = int(len(trap_mode_indices))
 
         trap_rows = []
         for trap_index, mode_index in enumerate(trap_mode_indices):
@@ -3817,133 +3836,44 @@ def identify_trap_mode_indices(self, ww_layer, params=None):
 
     def compute_original_basis_for_traps(self, ww_layer, params=None):
-        if params is None: params = DEFAULT_PARAMS.copy()
-        if len(ww_layer.Wmats) != 1:
-            return None
-
-        W_true = ww_layer.Wmats[0].astype(float)
-        U0, S0, V0h = svd_full(W_true, method=params[SVD_METHOD])
-        return {
-            "W_true": W_true,
-            "U0": U0,
-            "S0": S0,
-            "V0": V0h.T,
-        }
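+        # implementation extracted to the trap_analysis module (weightwatcher/trap_analysis.py)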
+        from . import trap_analysis
+        return trap_analysis.compute_original_basis_for_traps(self, ww_layer, params=params)
 
     def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0):
-        if params is None: params = DEFAULT_PARAMS.copy()
-        if original_basis_cache is None:
-            original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
-
-        W_perm = ww_layer.Wmats[0].astype(float)
-        p_ids = ww_layer.permute_ids[0]
-
-        U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
-        V_perm = Vh_perm.T
-
-        sigma_perm = float(S_perm[trap_mode_index])
-        u_perm = U_perm[:, trap_mode_index]
-        v_perm = V_perm[:, trap_mode_index]
-
-        T_perm = sigma_perm * np.outer(u_perm, v_perm)
-        T_orig = unpermute_matrix(T_perm, p_ids)
-
-        Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD])
-        u_trap = Ut[:, 0]
-        v_trap = Vht.T[:, 0]
-
-        U0 = original_basis_cache["U0"]
-        V0 = original_basis_cache["V0"]
-
-        left_overlaps = np.abs(U0.T @ u_trap) ** 2
-        right_overlaps = np.abs(V0.T @ v_trap) ** 2
-
-        left_top_mode = int(np.argmax(left_overlaps))
-        right_top_mode = int(np.argmax(right_overlaps))
-        left_top_mass = float(np.max(left_overlaps))
-        right_top_mass = float(np.max(right_overlaps))
-
-        eps = 1e-12
-        left_overlap_entropy = float(-np.sum((left_overlaps + eps) * np.log(left_overlaps + eps)))
-        right_overlap_entropy = float(-np.sum((right_overlaps + eps) * np.log(right_overlaps + eps)))
-        left_overlap_ipr = float(np.sum(left_overlaps ** 2))
-        right_overlap_ipr = float(np.sum(right_overlaps ** 2))
-
-        st_sq = St * St
-        rank1_mass_after_unpermute = float(st_sq[0] / (np.sum(st_sq) + eps))
-
-        u_metrics = self._trap_vector_metrics(u_trap)
-        v_metrics = self._trap_vector_metrics(v_trap)
-        u_oi = self._trap_vector_order_invariant_stats(u_trap)
-        v_oi = self._trap_vector_order_invariant_stats(v_trap)
-
-        eval_perm = sigma_perm ** 2
-        trap_result = {
-            "layer_id": ww_layer.layer_id,
-            "name": ww_layer.name,
-            "longname": ww_layer.longname,
-            "layer_type": str(ww_layer.the_type),
-            "N": ww_layer.N,
-            "M": ww_layer.M,
-            "rf": ww_layer.rf,
-            "Q": ww_layer.N / ww_layer.M if ww_layer.M > 0 else np.nan,
-            "trap_index": int(trap_index),
-            "perm_mode_index": int(trap_mode_index),
-            "sigma_perm": sigma_perm,
-            "eval_perm": float(eval_perm),
-            "mp_bulk_max": float(ww_layer.bulk_max),
-            "mp_bulk_min": float(ww_layer.bulk_min),
-            "sigma_mp": float(ww_layer.sigma_mp),
-            "num_spikes": int(ww_layer.num_spikes),
-            "rank1_mass_after_unpermute": rank1_mass_after_unpermute,
-            "sigma_trap_top": float(St[0]),
-            "left_top_mode": left_top_mode,
-            "right_top_mode": right_top_mode,
-            "left_top_mass": left_top_mass,
-            "right_top_mass": right_top_mass,
-            "left_overlap_entropy": left_overlap_entropy,
-            "right_overlap_entropy": right_overlap_entropy,
-            "left_overlap_ipr": left_overlap_ipr,
-            "right_overlap_ipr": right_overlap_ipr,
-            "trap_detected": True,
-            "trap_eval_minus_bulk": float(eval_perm - ww_layer.bulk_max),
-        }
-
-        for k, v in u_metrics.items():
-            trap_result[f"u_{k}"] = v
-        for k, v in v_metrics.items():
-            trap_result[f"v_{k}"] = v
-        for k, v in u_oi.items():
-            trap_result[f"u_{k}"] = v
-        for k, v in v_oi.items():
-            trap_result[f"v_{k}"] = v
-
-        trap_result["trap_balance_ratio"] = float(
-            trap_result["u_effective_support"] / (trap_result["v_effective_support"] + 1e-12)
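+        # delegates to the same trap_analysis module as compute_original_basis_for_traps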
+        from . import trap_analysis
+        return trap_analysis.analyze_single_trap(
+            self,
+            ww_layer,
+            trap_mode_index=trap_mode_index,
+            original_basis_cache=original_basis_cache,
+            params=params,
+            trap_index=trap_index,
         )
-        trap_result.update(self.assess_trap_diffuseness(trap_result))
-        trap_result["left_overlaps"] = left_overlaps
-        trap_result["right_overlaps"] = right_overlaps
-        trap_result["u_trap"] = u_trap
-        trap_result["v_trap"] = v_trap
-        trap_result["T_orig"] = T_orig
-        trap_result["perm_evals_sorted"] = np.array(ww_layer.evals).copy()
+    def compute_trap_delta(self, eval_perm, mp_bulk_max):
+        from . import trap_analysis
+        return trap_analysis.compute_trap_delta(eval_perm=eval_perm, mp_bulk_max=mp_bulk_max)
 
-        if params[PLOT]:
-            self.plot_trap_analysis(ww_layer, trap_result, params=params)
+    def compute_trap_ipr_q(self, vec):
+        from . import trap_analysis
+        return trap_analysis.compute_trap_ipr_q(vec)
 
-        trap_result.pop("left_overlaps", None)
-        trap_result.pop("right_overlaps", None)
-        trap_result.pop("u_trap", None)
-        trap_result.pop("v_trap", None)
-        if not params.get("_keep_trap_matrix", False):
-            trap_result.pop("T_orig", None)
-            trap_result.pop("perm_evals_sorted", None)
+    def compute_trap_ipr_q_uniform(self, vec):
+        from . import trap_analysis
+        return trap_analysis.compute_trap_ipr_q_uniform(vec)
 
-        return trap_result
+    def compute_top_sector_overlap(self, overlaps, top_sector_l=1):
+        from . import trap_analysis
+        return trap_analysis.compute_top_sector_overlap(overlaps, top_sector_l=top_sector_l)
 
+    def compute_trap_variance_burden(self, trap_delta, trap_q, trap_top_sector_overlap):
+        from . import trap_analysis
+        return trap_analysis.compute_trap_variance_burden(
+            trap_delta=trap_delta,
+            trap_q=trap_q,
+            trap_top_sector_overlap=trap_top_sector_overlap,
+        )
 
     def assess_trap_diffuseness(self, trap_result):
         """Heuristic classifier for trap severity in original weight space.
@@ -3951,6 +3881,11 @@ def assess_trap_diffuseness(self, trap_result):
         Trap risk is computed from normalized trap strength and then
         explicitly downweighted by diffuseness. This is intentionally a separate
         function so it can be unit-tested and adjusted independently.
+
+        NOTE: trap_q/trap_diffuseness are the Porter-Thomas-centered, paper-facing
+        localization metrics. trap_q_uniform/trap_diffuseness_uniform preserve the
+        older uniform-centered localization. trap_diffuseness_score/trap_risk_score/
+        trap_assessment remain heuristic diagnostics retained for backward compatibility.
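+
+        For an m-component localization vector with squared-amplitude IPR kappa,
+        the uniform-centered q is (m*kappa - 1)/(m - 1); the Porter-Thomas-centered
+        q instead centers on the random-vector baseline E[IPR] ~= 3/(m + 2).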
""" eps = 1e-12 @@ -5659,11 +5594,15 @@ def apply_remove_traps(self, ww_layer, trap_indices, params=None, seed=None, rng return remove_traps_ops.apply_remove_traps(self, ww_layer, trap_indices, params=params, seed=seed, rng=rng) def remove_traps(self, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True, - start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT): + start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT, + verify_traps=False, return_analyze=False, traps=None, + rtol=1e-4, atol=1e-6, min_vector_cosine=0.999): """Remove selected randomized MP/TW traps from dense layers.""" return remove_traps_ops.remove_traps( self, model=model, layers=layers, trap_indices=trap_indices, seed=seed, rng=rng, - pool=pool, plot=plot, start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft + pool=pool, plot=plot, start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft, + verify_traps=verify_traps, return_analyze=return_analyze, traps=traps, + rtol=rtol, atol=atol, min_vector_cosine=min_vector_cosine, )