Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions mothernet/prediction/mothernet_additive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

import numpy as np
import pandas as pd
Expand All @@ -12,7 +12,7 @@
from mothernet.models.utils import bin_data
from mothernet.utils import normalize_data

from interpret.glassbox._ebm._ebm import EBMExplanation
from interpret.glassbox._ebm._ebm import EBMExplanation, ExplainableBoostingClassifier
from interpret.utils._explanation import gen_global_selector


Expand Down Expand Up @@ -155,7 +155,7 @@ def predict_with_additive_model(X_train, X_test, weights, biases, bin_edges, nan
return torch.nn.functional.softmax(out / .8, dim=1).cpu().numpy()
else:
raise ValueError(f"Unknown inference_device: {inference_device}")


class ExplainableAdditivePredictor:
def explain_global(self):
Expand All @@ -173,9 +173,9 @@ def explain_global(self):
padded_scores = np.pad(class_one_scores, (1, 1), 'constant', constant_values=(0, 0))
else:
raise Exception("Need to implement explanations for multiclass")

self.term_scores_.append(padded_scores)

lower_bound, upper_bound = np.inf, -np.inf
for scores in self.term_scores_:
lower_bound = min(lower_bound, np.min(scores))
Expand Down Expand Up @@ -236,7 +236,7 @@ def explain_global(self):
overall_dict = {
"type": "univariate",
"names": term_names,
"scores": [1 for i in range(len(term_names))], # TODO: Stop hard coding
"scores": [1 for i in range(len(term_names))], # TODO: Stop hard coding
}
internal_obj = {
"overall": overall_dict,
Expand All @@ -257,7 +257,7 @@ def explain_global(self):
None,
),
)

def _extract_feature_bounds(self, X):
if torch.is_tensor(X):
X_numpy = X.cpu().detach().numpy()
Expand All @@ -266,10 +266,10 @@ def _extract_feature_bounds(self, X):
mins = X_numpy.min(axis=0).tolist()
maxs = X_numpy.max(axis=0).tolist()
feature_bounds = [(float(min_), float(max_)) for min_, max_ in zip(mins, maxs)]

self.feature_bounds_ = feature_bounds
return self


class MotherNetAdditiveClassifier(ClassifierMixin, BaseEstimator, ExplainableAdditivePredictor):
def __init__(self, path=None, device="cpu", inference_device="cpu", model=None, config=None, cat_features: List[int] = None):
Expand Down Expand Up @@ -308,7 +308,7 @@ def fit(self, X, y):
if config['model_type'] not in ["additive", "baam"]:
raise ValueError(f"Incompatible model_type: {config['model_type']}")
model.to(self.device)

try:
self.nan_bin = config['additive']['nan_bin']
except KeyError:
Expand Down Expand Up @@ -337,6 +337,95 @@ def predict(self, X):
return self.classes_[self.predict_proba(X).argmax(axis=1)]


def make_X_with_pair_effect(X, pairs) -> np.ndarray:
    """
    Append pairwise interaction columns to a feature matrix.

    For every pair ``(i, j)`` one extra column is appended that holds the
    element-wise product of column ``i`` and column ``j`` of ``X``. The
    original columns come first and keep their order; the interaction columns
    follow in the order given by ``pairs``.

    Parameters:
    X (array-like of shape (n_samples, n_features)): The original feature matrix.
        It is converted with ``np.asarray`` so nested lists are accepted too.
    pairs (iterable of tuples): Pairs of column indices into ``X``.

    Returns:
    numpy.ndarray: Matrix of shape (n_samples, n_features + number of pairs).
        When ``pairs`` is empty the result holds just the original columns.
    """
    # Coerce once up front: 2-D slicing such as X[:, i:i + 1] only works on arrays.
    X = np.asarray(X)
    # Slicing with i:i + 1 keeps each product 2-D so it can be concatenated column-wise.
    interaction_columns = [X[:, i:i + 1] * X[:, j:j + 1] for i, j in pairs]
    return np.concatenate([X, *interaction_columns], axis=1)


def compute_top_pairs(X, y, pair_strategy: str, n_pair_feature_max_ratio: float = 0.9) -> List[Tuple[int, int]]:
    """
    Select feature pairs to be used as pairwise interaction terms for predicting ``y``.

    :param X: feature matrix of shape (n_samples, n_features)
    :param y: targets of shape (n_samples,)
    :param pair_strategy: either "sum_importance" or "fast".
        "sum_importance" fits a random forest and ranks every pair (i, j), i < j, by the sum of the two feature
        importances (largest first). "fast" fits an ExplainableBoostingClassifier and takes the two-feature terms
        it selected, in the order they appear in ``term_features_``.
    :param n_pair_feature_max_ratio: value in [0, 1]; at most `int(n_pair_feature_max_ratio * n_features)` pairs
        are returned
    :return: list of pairs of feature indices (column indices into X) chosen by the strategy. The list is empty
        when `int(n_pair_feature_max_ratio * n_features)` is 0, in which case no model is fitted.
    :raises ValueError: if ``pair_strategy`` is unknown or ``n_pair_feature_max_ratio`` is outside [0, 1]
    """
    # Explicit raises instead of asserts: asserts are removed under `python -O`, which would let an
    # unknown strategy fall through to an unbound `pairs` variable.
    if pair_strategy not in ("sum_importance", "fast"):
        raise ValueError(f"Unknown pair_strategy: {pair_strategy}")
    if not 0 <= n_pair_feature_max_ratio <= 1:
        raise ValueError(f"n_pair_feature_max_ratio must be in [0, 1], got: {n_pair_feature_max_ratio}")

    n_features = X.shape[1]
    n_pair_feature_max = int(n_features * n_pair_feature_max_ratio)
    if n_pair_feature_max == 0:
        return []

    if pair_strategy == "sum_importance":
        # to select pairs, we compute feature importance and then pick pairs that have the largest sum of importance
        from sklearn.ensemble import RandomForestClassifier
        importances = RandomForestClassifier(random_state=0).fit(X, y).feature_importances_

        scored_pairs = [
            (i, j, importances[i] + importances[j])
            for i in range(n_features)
            for j in range(i + 1, n_features)
        ]
        # stable sort: pairs with equal importance sums keep their (i, j) enumeration order
        scored_pairs.sort(key=lambda scored: scored[2], reverse=True)
        pairs = [(i, j) for i, j, _ in scored_pairs]
    else:  # pair_strategy == "fast"
        # Per the original author's note, a ratio of exactly 1.0 does not work for EBM because it is
        # confounded with the integer 1, so a value just below 1 is passed instead.
        interactions = 0.999999 if n_pair_feature_max_ratio == 1.0 else n_pair_feature_max_ratio
        clf = ExplainableBoostingClassifier(interactions=interactions)
        clf.fit(X, y)
        # keep only the two-feature (interaction) terms
        pairs = [(term[0], term[1]) for term in clf.term_features_ if len(term) == 2]

    return pairs[:n_pair_feature_max]


class MotherNetAdditiveClassifierPairEffects(MotherNetAdditiveClassifier):
    """
    MotherNetAdditiveClassifier that augments the inputs with pairwise interaction features.

    During ``fit`` a set of feature pairs is chosen with ``compute_top_pairs`` and, for every pair, the product
    of the two columns is appended to ``X`` (see ``make_X_with_pair_effect``) before delegating to the parent
    classifier. The same pairs are re-applied to the inputs at prediction time.

    Parameters beyond those of the parent class:
    n_pair_feature_max_ratio: forwarded to ``compute_top_pairs``; pair features are only used when it is > 0.
    pair_strategy: "sum_importance" or "fast", forwarded to ``compute_top_pairs``.

    Fitted attributes:
    pairs_: list of (i, j) feature-index pairs used for the interaction columns, or None when pair features
        were disabled at fit time.
    """

    def __init__(self, path=None, device="cpu", inference_device="cpu", model=None, config=None,
                 cat_features: List[int] = None, n_pair_feature_max_ratio: float = 0.9,
                 pair_strategy: str = "sum_importance"):
        super().__init__(
            path=path, device=device, inference_device=inference_device, model=model, config=config,
            cat_features=cat_features,
        )
        # Raise instead of assert so the check survives `python -O`.
        if pair_strategy not in ("sum_importance", "fast"):
            raise ValueError(f"Unknown pair_strategy: {pair_strategy}")
        self.n_pair_feature_max_ratio = n_pair_feature_max_ratio
        self.pair_strategy = pair_strategy

    def fit(self, X, y):
        """Choose the interaction pairs (if enabled), append their columns to X and fit the parent model."""
        if self.n_pair_feature_max_ratio > 0:
            # compute pairs according to the selected strategy
            self.pairs_ = compute_top_pairs(
                X,
                y,
                pair_strategy=self.pair_strategy,
                n_pair_feature_max_ratio=self.n_pair_feature_max_ratio,
            )
            print(f"Going to use {len(self.pairs_)} pairs ({X.shape[1]} features present)")

            # and generate features for those pairs
            X = make_X_with_pair_effect(X, self.pairs_)
        else:
            # Record that this fit used no pair features, so prediction does not depend on the
            # current value of the hyperparameter (it may be changed after fitting).
            self.pairs_ = None
        return super().fit(X, y)

    def predict_proba_with_additive_components(self, X):
        """Apply the fitted pair features to X and run the additive model on the result."""
        # Decide from fitted state, not from n_pair_feature_max_ratio: the model weights were produced
        # for exactly the columns built during fit.
        if self.pairs_ is not None:
            X = make_X_with_pair_effect(X, self.pairs_)

        return predict_with_additive_model(self.X_train_, X, self.w_, self.b_, self.bin_edges_, nan_bin=self.nan_bin,
                                           inference_device=self.inference_device, regression=False)


class MotherNetAdditiveRegressor(RegressorMixin, BaseEstimator, ExplainableAdditivePredictor):
def __init__(self, path=None, device="cpu", inference_device="cpu", model=None, config=None, cat_features: List[int] = None):
self.path = path
Expand Down
38 changes: 38 additions & 0 deletions mothernet/tests/models/test_pair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pytest

from mothernet.prediction.mothernet_additive import compute_top_pairs, MotherNetAdditiveClassifierPairEffects
from mothernet.utils import get_mn_model

# Shared synthetic dataset for the tests below: n samples with d uniform random features.
np.random.seed(0)
n, d = 200, 10
X = np.random.rand(n, d)
# continuous xor with X0 and X3: x0 + x3 - 2 * x0 * x3
# => requires pairwise effect and top pair should be (0, 3)
_xor_score = X[:, 0] + X[:, 3] - 2 * X[:, 0] * X[:, 3]
# binarize the score to obtain boolean class labels
y = _xor_score > 0.5


@pytest.mark.parametrize("n_pair_feature_max_ratio", [0, 0.5, 1.0])
@pytest.mark.parametrize("pair_strategy", ["sum_importance", "fast"])
def test_pair(n_pair_feature_max_ratio: float, pair_strategy: str):
    """Check the number of selected pairs and that the xor pair (0, 3) is ranked first."""
    expected_count = int(n_pair_feature_max_ratio * d)
    selected = compute_top_pairs(
        X,
        y,
        pair_strategy=pair_strategy,
        n_pair_feature_max_ratio=n_pair_feature_max_ratio,
    )

    assert len(selected) == expected_count
    if not selected:
        # nothing was selected, so there is no top pair to check
        return
    assert selected[0] == (0, 3)


@pytest.mark.parametrize("n_pair_feature_max_ratio", [0, 0.1])
@pytest.mark.parametrize("pair_strategy", ["sum_importance", "fast"])
def test_estimator(n_pair_feature_max_ratio: float, pair_strategy: str):
    """Fit the pair-effect classifier on the xor data; with pairs enabled it must reach > 90% train accuracy."""
    checkpoint_name = "baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780.cpkt"
    estimator = MotherNetAdditiveClassifierPairEffects(
        path=get_mn_model(checkpoint_name),
        pair_strategy=pair_strategy,
        n_pair_feature_max_ratio=n_pair_feature_max_ratio,
    )
    predictions = estimator.fit(X, y).predict(X)
    if not n_pair_feature_max_ratio > 0:
        # without pair features this is only a smoke test: fit and predict must run
        return
    train_accuracy = (predictions == y).mean()
    assert train_accuracy > 0.9