Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,73 @@ New in v0.8.0: trap-level randomized diagnostics:
trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images")
```

`analyze_traps()` now reports paper-aligned trap metrics from the NeurIPS trap workflow,
including normalized spectral excess (`trap_delta`), localization (`trap_q` / `trap_ipr`),
top-sector overlap (`trap_top_sector_overlap` with configurable `top_sector_l`), and
the primary paper scalar `trap_variance_burden`.

`trap_q` is now Porter-Thomas-centered relative to a random-vector baseline
(`E[IPR] ≈ 3/(m+2)` for real vectors), and `trap_variance_burden` uses this
Porter-Thomas-centered `trap_q`.

For comparison with earlier behavior, uniform-centered localization is still available as
`trap_q_uniform` (and `trap_diffuseness_uniform`), while
`trap_diffuseness_score` remains the separate heuristic diagnostic.

The legacy heuristic diagnostics are still returned for backward compatibility
(`trap_diffuseness_score`, `trap_risk_score`, `trap_assessment`, plus legacy overlap/excess
fields), but the paper-facing metrics above are the primary outputs for trap interpretation.

For burden-variant experiments (e.g., comparing Porter-Thomas vs uniform localization or
different spectral normalizations), `analyze_traps` also supports:

- `burden_variants=None` (default): keep current PR-358 burden behavior.
- `burden_variants="default"`: add a standard sweep of variant columns
(`trap_variance_burden__*`).
- `return_burden_components=True`: expose scalar burden components for notebook analysis.
- `return_burden_raw=True`: expose raw vectors/overlaps (`u_perm`, `v_perm`, `u_trap`,
`v_trap`, `left_overlaps`, `right_overlaps`, `perm_evals_sorted`).

Fourier/FFT diagnostics are also available:

- `fft=False` (default): keeps standard-basis behavior unchanged.
- `trap_fft=True`: adds Fourier-space localization and Fourier-frequency mass diagnostics for trap vectors.
- `trap_fft_config={...}`: optional FFT diagnostics config (partial dicts are supported).

The legacy `fft=True` argument is still reserved for the older layer-matrix FFT path; use
`trap_fft=True` for trap-vector Fourier diagnostics.

Recommended `trap_fft_config` fields:

```python
{
"sides": "both", # "left" | "right" | "both"
"vectors": "both", # "perm" | "trap" | "both"
"fold_conjugates": True, # combine +k/-k bins for real vectors
"exclude_dc": False, # exclude freq-0 from effective mass diagnostics
"top_frequency_l": 1, # top-k frequencies for top-frequency mass
"selected_frequencies": None, # optional explicit frequency bins
"normalization": "ortho", # np.fft.fft(..., norm="ortho")
"baseline": "uniform", # "uniform" | "pt_real_mc" | "pt_complex"
"mc_samples": 2048, # used by baseline="pt_real_mc"
"mc_seed": 123,
"modulus": None, # optional modular-arithmetic length hint
"apply_only_if_length_matches_modulus": False,
"layer_fft_map": None,
}
```

When `burden_variants="default"` and `trap_fft=True`, FFT burden variants are added
(`trap_variance_burden__fft_*`), for example:

- `trap_variance_burden__fft_uniform_right_current_spectral`
- `trap_variance_burden__fft_uniform_lr_geom_fft_topmass`
- `trap_variance_burden__fft_pt_lr_geom_fft_topmass`

Note: Fourier overlap is intentionally based on Fourier-frequency mass concentration
(`trap_fft_top_frequency_mass_*` / `trap_fft_selected_frequency_mass_*`), not
unitary-transformed vector overlap.

See the new usage guide: [Correlation Trap Workflow (`analyze_traps` + `remove_traps`)](./docs_trap_features.md)

## PEFT / LORA models (experimental)
Expand Down Expand Up @@ -289,6 +356,18 @@ watcher = ww.WeightWatcher(model=my_model)
trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images")
```

The implementation is designed to follow the trap definitions in the NeurIPS paper first,
while preserving older diagnostic fields for compatibility with existing scripts/notebooks.
If you want the paper metrics directly, use:

- `trap_delta`
- `trap_ipr`
- `trap_q`
- `trap_diffuseness` (`1 - trap_q`)
- `trap_top_sector_overlap` (cumulative overlap over first `top_sector_l` modes)
- `trap_variance_burden`
- `layer_trap_variance_burden`

For a complete walkthrough (including `remove_traps`), see: [Correlation Trap Workflow (`analyze_traps` + `remove_traps`)](./docs_trap_features.md)

Fig (a) is well trained; Fig (b) may be over-fit.
Expand Down
128 changes: 128 additions & 0 deletions tests/test_analyze_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,76 @@
TORCH_AVAILABLE = False

import weightwatcher as ww
import weightwatcher.trap_analysis as trap_analysis


class TestTrapMetricHelpers(unittest.TestCase):
    """Unit tests for the scalar trap-metric helpers exposed on WeightWatcher."""

    def setUp(self):
        # The helpers are pure functions on their arguments, so a bare
        # watcher with no model attached is sufficient.
        self.watcher = ww.WeightWatcher()

    def test_compute_trap_delta(self):
        watcher = self.watcher
        # 12 vs 10 is a 20% excess; 8 vs 10 clips to 0; a zero baseline is NaN.
        self.assertAlmostEqual(watcher.compute_trap_delta(12.0, 10.0), 0.2)
        self.assertAlmostEqual(watcher.compute_trap_delta(8.0, 10.0), 0.0)
        self.assertTrue(np.isnan(watcher.compute_trap_delta(12.0, 0.0)))

    def test_compute_trap_ipr_q_porter_thomas_baseline_vector(self):
        # Four components of 0.5 in a length-10 vector give IPR = 4 * 0.5**4 = 0.25,
        # which equals the Porter-Thomas expectation 3/(m+2) for m=10, so the
        # centered q should vanish.
        vec = np.concatenate([np.full(4, 0.5), np.zeros(6)])
        ipr, q = self.watcher.compute_trap_ipr_q(vec)
        self.assertAlmostEqual(ipr, 0.25)
        self.assertAlmostEqual(q, 0.0, places=7)

    def test_compute_trap_ipr_q_uniform_vector_clips_to_zero_under_pt(self):
        # A uniform unit vector has IPR = 1/m = 0.1, below the PT baseline,
        # so the PT-centered q must clip to zero rather than go negative.
        vec = np.ones(10) / np.sqrt(10)
        ipr, q = self.watcher.compute_trap_ipr_q(vec)
        self.assertAlmostEqual(ipr, 0.1)
        self.assertAlmostEqual(q, 0.0)

    def test_compute_trap_ipr_q_one_hot_vector(self):
        # A fully localized (one-hot) vector saturates both IPR and q at 1.
        vec = np.eye(10)[0]
        ipr, q = self.watcher.compute_trap_ipr_q(vec)
        self.assertAlmostEqual(ipr, 1.0)
        self.assertAlmostEqual(q, 1.0)

    def test_compute_trap_ipr_q_uniform_legacy_field_behavior(self):
        # Legacy uniform-centered variant: uniform vector -> q == 0.
        uniform = np.ones(10) / np.sqrt(10)
        ipr, q = self.watcher.compute_trap_ipr_q_uniform(uniform)
        self.assertAlmostEqual(ipr, 0.1)
        self.assertAlmostEqual(q, 0.0)

        # ...and a one-hot vector -> q == 1.
        localized = np.eye(10)[0]
        ipr, q = self.watcher.compute_trap_ipr_q_uniform(localized)
        self.assertAlmostEqual(ipr, 1.0)
        self.assertAlmostEqual(q, 1.0)

    def test_compute_top_sector_overlap(self):
        overlaps = np.array([0.25, 0.10, 0.05, 0.60])

        # Cumulative overlap over the first ell modes, plus the effective ell used.
        for ell, expected in ((1, 0.25), (2, 0.35)):
            total, ell_eff = self.watcher.compute_top_sector_overlap(overlaps, ell)
            self.assertAlmostEqual(total, expected)
            self.assertEqual(ell_eff, ell)

    def test_compute_trap_variance_burden(self):
        delta, q, overlap = 0.2, 0.5, 0.3
        burden = self.watcher.compute_trap_variance_burden(
            trap_delta=delta,
            trap_q=q,
            trap_top_sector_overlap=overlap,
        )
        # burden = delta^2 * q * overlap^2
        self.assertAlmostEqual(burden, delta**2 * q * overlap**2)

    def test_plot_flag_coercion(self):
        # String forms of booleans must coerce like their boolean counterparts.
        for value in (False, "False", "0"):
            self.assertFalse(trap_analysis._coerce_plot_flag(value))
        for value in (True, "True"):
            self.assertTrue(trap_analysis._coerce_plot_flag(value))


if TORCH_AVAILABLE:
Expand Down Expand Up @@ -137,6 +207,64 @@ def test_order_invariant_stats_are_finite(self):
]:
self.assertTrue(np.isfinite(row[col]))

def test_analyze_traps_contains_paper_metric_columns(self):
df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337)
required = {
"top_sector_l",
"top_sector_l_effective",
"trap_seed",
"n_traps",
"perm_signature",
"permutation_n",
"trap_identity_key",
"trap_delta",
"trap_ipr",
"trap_q",
"trap_diffuseness",
"trap_q_uniform",
"trap_diffuseness_uniform",
"trap_top_sector_overlap",
"trap_variance_burden",
"layer_trap_variance_burden",
}
self.assertTrue(required.issubset(set(df.columns)))

def test_analyze_traps_respects_top_sector_l_argument(self):
df1 = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337, top_sector_l=1)
df2 = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337, top_sector_l=2)

if len(df1) == 0 or len(df2) == 0:
self.skipTest("No traps detected in this environment")

self.assertTrue((df1["top_sector_l"] == 1).all())
self.assertTrue((df2["top_sector_l"] == 2).all())
self.assertTrue((df1["top_sector_l_effective"] >= 1).all())
self.assertTrue((df1["top_sector_l_effective"] <= df1["top_sector_l"]).all())
self.assertTrue((df2["top_sector_l_effective"] >= 1).all())
self.assertTrue((df2["top_sector_l_effective"] <= df2["top_sector_l"]).all())

def test_trap_variance_burden_formula_rowwise(self):
df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337)
if len(df) == 0:
self.skipTest("No traps detected in this environment")

for _, row in df.iterrows():
components = [row["trap_delta"], row["trap_q"], row["trap_top_sector_overlap"]]
if not np.all(np.isfinite(components)):
continue
expected = (row["trap_delta"] ** 2) * row["trap_q"] * (row["trap_top_sector_overlap"] ** 2)
self.assertAlmostEqual(row["trap_variance_burden"], expected)

def test_layer_trap_variance_burden_aggregate(self):
df = self.watcher.analyze_traps(plot=False, savefig=False, rng=1337)
if len(df) == 0:
self.skipTest("No traps detected in this environment")

for layer_id, subdf in df.groupby("layer_id"):
expected = subdf["trap_variance_burden"].sum()
observed = subdf["layer_trap_variance_burden"].iloc[0]
self.assertAlmostEqual(observed, expected)


# Allow running this test module directly (e.g. `python tests/test_analyze_traps.py`).
if __name__ == "__main__":
    unittest.main()
142 changes: 142 additions & 0 deletions tests/test_remove_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,148 @@ def test_remove_traps_invalid_indices_warns_and_skips(monkeypatch, caplog):
assert len(post_artifacts) == 0


def test_remove_traps_return_analyze_returns_verification_dataframe():
    """remove_traps(return_analyze=True) returns (model, verification DataFrame)."""
    watcher = WeightWatcher(model=None)
    weight_matrix, _, _, _ = _single_trap_setup(seed=505)
    layer = make_ww_layer(weight_matrix)

    # Swap in a fixed iterator so remove_traps sees exactly our synthetic layer.
    watcher.make_layer_iterator = (
        lambda model=None, layers=None, params=None, base_model=None: [layer]
    )

    model_out, verify_df = watcher.remove_traps(
        model={"dummy_weight": np.array([1.0])},
        layers=[],
        trap_indices=[1],
        seed=44,
        pool=True,
        plot=False,
        return_analyze=True,
    )

    assert isinstance(model_out, dict)
    assert len(verify_df) == 1
    row = verify_df.iloc[0]
    assert bool(row["trap_verified"])
    assert bool(row["removed"])


def test_remove_traps_verify_traps_true_raises_on_mismatch(monkeypatch):
    """A tampered perm_signature must make verify_traps=True fail loudly."""
    watcher = WeightWatcher(model=None)
    weight_matrix, _, _, _ = _single_trap_setup(seed=606)
    layer = make_ww_layer(weight_matrix)

    monkeypatch.setattr(
        watcher,
        "make_layer_iterator",
        lambda model=None, layers=None, params=None, base_model=None: [layer],
    )

    # Corrupt the signature so verification against regenerated traps cannot match.
    traps_df = watcher.analyze_traps(
        model=None, layers=[], seed=11, plot=False, savefig=False, pool=True
    ).copy()
    traps_df.loc[:, "perm_signature"] = "deadbeef"

    with pytest.raises(RuntimeError, match="Trap verification failed"):
        watcher.remove_traps(
            model={"dummy_weight": np.array([1.0])},
            layers=[],
            trap_indices=[1],
            seed=11,
            pool=True,
            plot=False,
            verify_traps=True,
            traps=traps_df,
        )


def test_remove_traps_accepts_analyze_dataframe_without_explicit_indices(monkeypatch):
    """A traps DataFrame alone (trap_indices=None) is enough to drive removal."""
    watcher = WeightWatcher(model=None)
    weight_matrix, _, _, _ = _single_trap_setup(seed=707)
    layer = make_ww_layer(weight_matrix)

    monkeypatch.setattr(
        watcher,
        "make_layer_iterator",
        lambda model=None, layers=None, params=None, base_model=None: [layer],
    )

    traps_df = watcher.analyze_traps(
        model=None, layers=[], seed=22, plot=False, savefig=False, pool=True
    )
    model_out, verify_df = watcher.remove_traps(
        model={"dummy_weight": np.array([1.0])},
        layers=[],
        trap_indices=None,
        seed=22,
        pool=True,
        plot=False,
        return_analyze=True,
        traps=traps_df,
    )

    assert isinstance(model_out, dict)
    assert len(verify_df) == 1
    assert bool(verify_df.iloc[0]["trap_verified"])


def test_remove_traps_verify_true_with_single_trap_row_subset(monkeypatch):
    """Passing a one-row slice of the analyze DataFrame still verifies cleanly."""
    watcher = WeightWatcher(model=None)
    seed = 24
    weight_matrix, _, _, _ = _single_trap_setup(seed=909)
    layer = make_ww_layer(weight_matrix)

    monkeypatch.setattr(
        watcher,
        "make_layer_iterator",
        lambda model=None, layers=None, params=None, base_model=None: [layer],
    )

    trap_df = watcher.analyze_traps(
        model=None, layers=[], seed=seed, plot=False, savefig=False, pool=True
    )
    # Hand remove_traps only the first trap row, as a DataFrame (not a Series).
    model_out, verify_df = watcher.remove_traps(
        model={"dummy_weight": np.array([1.0])},
        layers=[],
        traps=trap_df.iloc[[0]],
        seed=seed,
        pool=True,
        plot=False,
        verify_traps=True,
        return_analyze=True,
    )

    assert isinstance(model_out, dict)
    assert len(verify_df) == 1
    first = verify_df.iloc[0]
    assert bool(first["perm_match"])
    assert bool(first["trap_verified"])


def test_remove_traps_verification_regenerates_layer_local_trap_df(monkeypatch):
    """Verification must re-run analyze_traps scoped to the affected layer."""
    watcher = WeightWatcher(model=None)
    weight_matrix, _, _, _ = _single_trap_setup(seed=808)
    layer = make_ww_layer(weight_matrix)

    monkeypatch.setattr(
        watcher,
        "make_layer_iterator",
        lambda model=None, layers=None, params=None, base_model=None: [layer],
    )

    # Wrap analyze_traps to record every `layers` argument it is invoked with.
    original_analyze = watcher.analyze_traps
    recorded_layers = []

    def _recording_analyze(*args, **kwargs):
        recorded_layers.append(tuple(kwargs.get("layers", [])))
        return original_analyze(*args, **kwargs)

    monkeypatch.setattr(watcher, "analyze_traps", _recording_analyze)
    watcher.remove_traps(
        model={"dummy_weight": np.array([1.0])},
        layers=[],
        trap_indices=[1],
        seed=33,
        pool=True,
        plot=False,
        return_analyze=True,
    )

    # At least one analyze_traps call must have been scoped to our layer's id.
    assert (int(layer.layer_id),) in recorded_layers


@pytest.mark.skipif(torch is None, reason="PyTorch not installed")
def test_trap_rng_consistency_analyze_vs_collect_single_and_multi_layer():
model = torch.nn.Sequential(
Expand Down
Loading