Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tests/test_analyze_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def test_analyze_traps_columns(self):
df = self.watcher.analyze_traps(plot=False, savefig=False)
expected_cols = {
"layer_id", "name", "trap_index", "perm_mode_index",
"sigma_perm", "mp_bulk_max", "left_top_mass", "right_top_mass"
"sigma_perm", "mp_bulk_max", "left_top_mass", "right_top_mass",
"top_5_mass", "top_10_mass",
"bulk_localization_mean", "bulk_localization_std",
"bulk_top_5_mass_mean", "bulk_top_5_mass_std",
"bulk_top_10_mass_mean", "bulk_top_10_mass_std",
}
self.assertTrue(expected_cols.issubset(set(df.columns)))

Expand Down
27 changes: 27 additions & 0 deletions tests/test_remove_traps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pytest

from weightwatcher.constants import CHANNELS, FRAMEWORK, LAYER_TYPE, DEFAULT_PARAMS
Expand Down Expand Up @@ -214,6 +215,32 @@ def test_remove_traps_public_api_direct_call(monkeypatch):
assert len(post_artifacts) == 0


def test_remove_traps_accepts_traps_dataframe_and_returns_verify_df(monkeypatch):
    # PR359 compatibility: remove_traps() should accept traps=<DataFrame>
    # (instead of trap_indices=[...]) and, when return_analyze=True, hand back
    # a (model, verify_df) pair whose DataFrame carries a verify_passed column.
    weights, _, _, _ = _single_trap_setup(seed=505)
    layer = make_ww_layer(weights)
    watcher = WeightWatcher(model={"dummy_weight": np.array([1.0])})

    # Force the watcher to iterate over exactly our one synthetic layer.
    monkeypatch.setattr(
        watcher,
        "make_layer_iterator",
        lambda model=None, layers=None, params=None, base_model=None: [layer],
    )

    result_model, verify_df = watcher.remove_traps(
        model={"dummy_weight": np.array([1.0])},
        layers=[],
        traps=pd.DataFrame([{"layer_id": 0, "trap_index": 1}]),
        seed=88,
        pool=True,
        plot=False,
        verify_traps=True,
        return_analyze=True,
    )
    assert isinstance(result_model, dict)
    assert isinstance(verify_df, pd.DataFrame)
    assert "verify_passed" in verify_df.columns


def test_remove_traps_invalid_indices_warns_and_skips(monkeypatch, caplog):
W, _, _, _ = _single_trap_setup(seed=404)
ww_layer = make_ww_layer(W)
Expand Down
67 changes: 63 additions & 4 deletions weightwatcher/remove_traps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import numbers
import numpy as np
import pandas as pd

from .RMT_Util import svd_full, unpermute_matrix
from .constants import DEFAULT_PARAMS, DEFAULT_START_ID, FAST_SVD, LAYER_TYPE, PEFT, PLOT, POOL, START_IDS, SVD_METHOD, DEFAULT_PEFT
Expand Down Expand Up @@ -65,6 +66,18 @@ def identify_trap_mode_indices(ww, ww_layer):


def analyze_single_trap(ww, ww_layer, trap_mode_index):
def _top_percent_abs_mass(mat, percent):
flat = np.abs(np.asarray(mat, dtype=float)).ravel()
if flat.size == 0:
return 0.0
total = float(np.sum(flat))
if total <= 0.0:
return 0.0
k = int(np.ceil((float(percent) / 100.0) * flat.size))
k = max(1, min(k, flat.size))
top_sum = float(np.sum(np.partition(flat, -k)[-k:]))
return top_sum / total

W_perm = ww_layer.Wmats[0]
U_perm, S_perm, Vh_perm = svd_full(W_perm)

Expand All @@ -74,6 +87,8 @@ def analyze_single_trap(ww, ww_layer, trap_mode_index):
T_perm = sigma_perm * np.outer(u_trap, v_trap)
T_orig_norm = unpermute_matrix(T_perm, ww_layer.permute_ids[0])
U_orig, _, Vh_orig = svd_full(T_orig_norm)
top_5_mass = _top_percent_abs_mass(T_orig_norm, 5.0)
top_10_mass = _top_percent_abs_mass(T_orig_norm, 10.0)

return {
"trap_mode_index": trap_mode_index,
Expand All @@ -84,6 +99,8 @@ def analyze_single_trap(ww, ww_layer, trap_mode_index):
"v_trap": Vh_orig[0, :],
"T_perm": T_perm,
"T_orig_norm": T_orig_norm,
"top_5_mass": float(top_5_mass),
"top_10_mass": float(top_10_mass),
}


Expand Down Expand Up @@ -209,10 +226,32 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N
return ww_layer


def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True,
start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT):
def _trap_indices_from_traps_df(traps):
"""Extract unique 1-based trap indices from a traps DataFrame-like input."""
if traps is None:
return None
if isinstance(traps, pd.DataFrame):
trap_df = traps
else:
trap_df = pd.DataFrame(traps)
if "trap_index" not in trap_df.columns:
raise ValueError("traps must include a 'trap_index' column")
indices = trap_df["trap_index"].dropna().astype(int).tolist()
indices = sorted(set(indices))
if len(indices) == 0:
raise ValueError("traps did not contain any valid trap_index values")
return indices


def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True,
verify_traps=False, return_analyze=False, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD,
base_model=None, peft=DEFAULT_PEFT):
# PR359 compatibility path: passing traps=<DataFrame> instead of trap_indices=[...]
if trap_indices is None and traps is not None:
trap_indices = _trap_indices_from_traps_df(traps)

if trap_indices is None or len(trap_indices) == 0:
raise ValueError("trap_indices must be provided and non-empty")
raise ValueError("trap_indices must be provided and non-empty (or pass traps with trap_index column)")

ww.set_model_(model)
params = DEFAULT_PARAMS.copy()
Expand All @@ -230,8 +269,28 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=No
params = ww.normalize_params(params)

layer_iterator = ww.make_layer_iterator(model=ww.model, layers=layers, params=params, base_model=base_model)
verify_rows = []
for ww_layer in layer_iterator:
if not ww_layer.skipped and ww_layer.has_weights:
apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"])

if verify_traps:
remaining = collect_trap_artifacts(
ww,
ww_layer,
params=params,
seed=None if params["rng"] is not None else seed,
rng=params["rng"],
)
verify_rows.append(
{
"layer_id": int(ww_layer.layer_id),
"requested_trap_indices": list(trap_indices),
"remaining_traps": len(remaining),
"verify_passed": len(remaining) == 0,
}
)

if return_analyze:
verify_df = pd.DataFrame.from_records(verify_rows)
return model, verify_df
return model
94 changes: 90 additions & 4 deletions weightwatcher/weightwatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3760,6 +3760,11 @@ def _trap_result_columns(self):
"rank1_mass_after_unpermute", "sigma_trap_top",
"left_top_mode", "right_top_mode", "left_top_mass", "right_top_mass",
"left_overlap_entropy", "right_overlap_entropy", "left_overlap_ipr", "right_overlap_ipr",
"top_5_mass", "top_10_mass",
"bulk_mode_count",
"bulk_localization_mean", "bulk_localization_std",
"bulk_top_5_mass_mean", "bulk_top_5_mass_std",
"bulk_top_10_mass_mean", "bulk_top_10_mass_std",
"u_length", "u_entropy", "u_discrete_entropy", "u_localization_ratio", "u_participation_ratio",
"v_length", "v_entropy", "v_discrete_entropy", "v_localization_ratio", "v_participation_ratio",
"u_l2_fourth_moment", "u_l2_sixth_moment", "u_effective_support", "u_gini_abs",
Expand All @@ -3785,6 +3790,7 @@ def apply_analyze_traps(self, ww_layer, params=None):
self.apply_permute_W(ww_layer, params)
self.apply_trap_mp_fit(ww_layer, params=params)
trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params)
bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params)

trap_rows = []
for trap_index, mode_index in enumerate(trap_mode_indices):
Expand All @@ -3795,6 +3801,7 @@ def apply_analyze_traps(self, ww_layer, params=None):
params=params,
trap_index=trap_index,
)
trap_row.update(bulk_stats)
trap_rows.append(trap_row)

self.apply_unpermute_W(ww_layer, params)
Expand All @@ -3815,6 +3822,78 @@ def identify_trap_mode_indices(self, ww_layer, params=None):
# counts/indices for a given layer and seed.
return remove_traps_ops.identify_trap_mode_indices(self, ww_layer)

def _top_percent_abs_mass(self, mat, percent):
flat = np.abs(np.asarray(mat, dtype=float)).ravel()
if flat.size == 0:
return 0.0
total = float(np.sum(flat))
if total <= 0.0:
return 0.0
k = int(np.ceil((float(percent) / 100.0) * flat.size))
k = max(1, min(k, flat.size))
top_sum = float(np.sum(np.partition(flat, -k)[-k:]))
return top_sum / total

def compute_bulk_trap_reference_metrics(self, ww_layer, trap_mode_indices, params=None):
if params is None: params = DEFAULT_PARAMS.copy()
if len(ww_layer.Wmats) != 1 or len(ww_layer.permute_ids) == 0:
return {
"bulk_mode_count": 0,
"bulk_localization_mean": np.nan,
"bulk_localization_std": np.nan,
"bulk_top_5_mass_mean": np.nan,
"bulk_top_5_mass_std": np.nan,
"bulk_top_10_mass_mean": np.nan,
"bulk_top_10_mass_std": np.nan,
}

W_perm = ww_layer.Wmats[0].astype(float)
p_ids = ww_layer.permute_ids[0]
U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
trap_set = set(int(i) for i in trap_mode_indices)
bulk_indices = [i for i in range(len(S_perm)) if i not in trap_set]

if len(bulk_indices) == 0:
return {
"bulk_mode_count": 0,
"bulk_localization_mean": np.nan,
"bulk_localization_std": np.nan,
"bulk_top_5_mass_mean": np.nan,
"bulk_top_5_mass_std": np.nan,
"bulk_top_10_mass_mean": np.nan,
"bulk_top_10_mass_std": np.nan,
}

bulk_localization = []
bulk_top_5_mass = []
bulk_top_10_mass = []

for mode_idx in bulk_indices:
u_mode = U_perm[:, mode_idx]
v_mode = Vh_perm[mode_idx, :]
u_metrics = self._trap_vector_metrics(u_mode)
v_metrics = self._trap_vector_metrics(v_mode)
loc = 0.5 * (
float(u_metrics.get("localization_ratio", np.nan)) +
float(v_metrics.get("localization_ratio", np.nan))
)

T_perm = float(S_perm[mode_idx]) * np.outer(u_mode, v_mode)
T_orig = unpermute_matrix(T_perm, p_ids)
bulk_localization.append(loc)
bulk_top_5_mass.append(self._top_percent_abs_mass(T_orig, 5.0))
bulk_top_10_mass.append(self._top_percent_abs_mass(T_orig, 10.0))

return {
"bulk_mode_count": int(len(bulk_indices)),
"bulk_localization_mean": float(np.nanmean(bulk_localization)),
"bulk_localization_std": float(np.nanstd(bulk_localization)),
"bulk_top_5_mass_mean": float(np.nanmean(bulk_top_5_mass)),
"bulk_top_5_mass_std": float(np.nanstd(bulk_top_5_mass)),
"bulk_top_10_mass_mean": float(np.nanmean(bulk_top_10_mass)),
"bulk_top_10_mass_std": float(np.nanstd(bulk_top_10_mass)),
}


def compute_original_basis_for_traps(self, ww_layer, params=None):
if params is None: params = DEFAULT_PARAMS.copy()
Expand Down Expand Up @@ -3849,6 +3928,9 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
T_perm = sigma_perm * np.outer(u_perm, v_perm)
T_orig = unpermute_matrix(T_perm, p_ids)

top_5_mass = self._top_percent_abs_mass(T_orig, 5.0)
top_10_mass = self._top_percent_abs_mass(T_orig, 10.0)

Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD])
u_trap = Ut[:, 0]
v_trap = Vht.T[:, 0]
Expand Down Expand Up @@ -3906,6 +3988,8 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
"right_overlap_entropy": right_overlap_entropy,
"left_overlap_ipr": left_overlap_ipr,
"right_overlap_ipr": right_overlap_ipr,
"top_5_mass": float(top_5_mass),
"top_10_mass": float(top_10_mass),
"trap_detected": True,
"trap_eval_minus_bulk": float(eval_perm - ww_layer.bulk_max),
}
Expand Down Expand Up @@ -5658,12 +5742,14 @@ def apply_remove_traps(self, ww_layer, trap_indices, params=None, seed=None, rng
"""Remove selected traps from one dense WWLayer and replace with matched random matrices."""
return remove_traps_ops.apply_remove_traps(self, ww_layer, trap_indices, params=params, seed=seed, rng=rng)

def remove_traps(self, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True,
start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT):
def remove_traps(self, model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True,
verify_traps=False, return_analyze=False, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD,
base_model=None, peft=DEFAULT_PEFT):
"""Remove selected randomized MP/TW traps from dense layers."""
return remove_traps_ops.remove_traps(
self, model=model, layers=layers, trap_indices=trap_indices, seed=seed, rng=rng,
pool=pool, plot=plot, start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft
self, model=model, layers=layers, trap_indices=trap_indices, traps=traps, seed=seed, rng=rng,
pool=pool, plot=plot, verify_traps=verify_traps, return_analyze=return_analyze,
start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft
)


Expand Down