diff --git a/mothernet/prediction/mothernet_additive.py b/mothernet/prediction/mothernet_additive.py index 51e2d658..1b870593 100644 --- a/mothernet/prediction/mothernet_additive.py +++ b/mothernet/prediction/mothernet_additive.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple import numpy as np import pandas as pd @@ -12,7 +12,7 @@ from mothernet.models.utils import bin_data from mothernet.utils import normalize_data -from interpret.glassbox._ebm._ebm import EBMExplanation +from interpret.glassbox._ebm._ebm import EBMExplanation, ExplainableBoostingClassifier from interpret.utils._explanation import gen_global_selector @@ -155,7 +155,7 @@ def predict_with_additive_model(X_train, X_test, weights, biases, bin_edges, nan return torch.nn.functional.softmax(out / .8, dim=1).cpu().numpy() else: raise ValueError(f"Unknown inference_device: {inference_device}") - + class ExplainableAdditivePredictor: def explain_global(self): @@ -173,9 +173,9 @@ def explain_global(self): padded_scores = np.pad(class_one_scores, (1, 1), 'constant', constant_values=(0, 0)) else: raise Exception("Need to implement explanations for multiclass") - + self.term_scores_.append(padded_scores) - + lower_bound, upper_bound = np.inf, -np.inf for scores in self.term_scores_: lower_bound = min(lower_bound, np.min(scores)) @@ -236,7 +236,7 @@ def explain_global(self): overall_dict = { "type": "univariate", "names": term_names, - "scores": [1 for i in range(len(term_names))], # TODO: Stop hard coding + "scores": [1 for i in range(len(term_names))], # TODO: Stop hard coding } internal_obj = { "overall": overall_dict, @@ -257,7 +257,7 @@ def explain_global(self): None, ), ) - + def _extract_feature_bounds(self, X): if torch.is_tensor(X): X_numpy = X.cpu().detach().numpy() @@ -266,10 +266,10 @@ def _extract_feature_bounds(self, X): mins = X_numpy.min(axis=0).tolist() maxs = X_numpy.max(axis=0).tolist() feature_bounds = [(float(min_), float(max_)) for min_, max_ in zip(mins, maxs)] 
def make_X_with_pair_effect(X, pairs) -> np.ndarray:
    """
    Creates a new feature matrix by appending pairwise product columns to the original features.

    Parameters:
    X (numpy array): The original feature matrix, indexed as ``X[:, i]`` (2-D).
    pairs (list of tuples): A list of pairs ``(i, j)`` of column indices into ``X``.

    Returns:
    numpy array: ``X`` followed by one column per pair holding ``X[:, i] * X[:, j]``,
    in the order given by ``pairs``. With an empty ``pairs`` this is a copy of ``X``.
    """
    # Slices (i:i + 1) keep each product two-dimensional so it can be concatenated column-wise.
    interaction_columns = [X[:, i:i + 1] * X[:, j:j + 1] for i, j in pairs]
    return np.concatenate([X] + interaction_columns, axis=1)


def compute_top_pairs(X, y, pair_strategy: str, n_pair_feature_max_ratio: float = 0.9) -> List[Tuple[int, int]]:
    """
    Select the feature pairs judged most predictive of ``y`` by the chosen strategy.

    :param X: (n_samples, n_features)
    :param y: (n_samples)
    :param pair_strategy: either "sum_importance" or "fast".
      "sum_importance" ranks every pair ``(i, j)`` with ``i < j`` by the sum of the two
      features' random-forest importances; "fast" takes the two-feature terms found by an
      ``ExplainableBoostingClassifier``.
    :param n_pair_feature_max_ratio: value in [0, 1]; at most
      ``int(n_pair_feature_max_ratio * n_features)`` pairs are returned
    :return: list of ``(i, j)`` tuples of feature (column) indices into ``X``, best pair first.
      Empty when ``int(n_pair_feature_max_ratio * n_features)`` is 0.
    :raises ValueError: if ``pair_strategy`` is unknown or the ratio is outside [0, 1]
    """
    # Explicit exceptions instead of `assert`: assertions are removed when Python runs with -O.
    if pair_strategy not in ("sum_importance", "fast"):
        raise ValueError(f"Unknown pair_strategy: {pair_strategy!r}, expected 'sum_importance' or 'fast'")
    if not 0 <= n_pair_feature_max_ratio <= 1:
        raise ValueError(f"n_pair_feature_max_ratio must be in [0, 1], got {n_pair_feature_max_ratio!r}")

    n_features = X.shape[1]
    n_pair_feature_max = int(n_features * n_pair_feature_max_ratio)
    if n_pair_feature_max == 0:
        return []

    if pair_strategy == "sum_importance":
        # to select pairs, we compute feature importance and then pick pairs that have the largest sum of importance
        from sklearn.ensemble import RandomForestClassifier
        importances = RandomForestClassifier(random_state=0).fit(X, y).feature_importances_

        candidates = [(i, j) for i in range(n_features) for j in range(i + 1, n_features)]
        # sorted() is stable, so pairs with equal importance sums keep their (i, j) enumeration order.
        pairs = sorted(candidates, key=lambda pair: importances[pair[0]] + importances[pair[1]], reverse=True)
    else:
        # Pass the number of interactions as an integer count. A float is read by EBM as a
        # ratio of the feature count, and 1.0 is easily confused with the count 1, so the
        # integer form avoids special-casing a ratio of exactly 1.0.
        clf = ExplainableBoostingClassifier(interactions=n_pair_feature_max)
        clf.fit(X, y)
        # Terms of length 2 are the pairwise interactions; single-feature terms are skipped.
        pairs = [(term[0], term[1]) for term in clf.term_features_ if len(term) == 2]

    return pairs[:n_pair_feature_max]
class MotherNetAdditiveClassifierPairEffects(MotherNetAdditiveClassifier):
    """Additive MotherNet classifier trained on the features plus selected pairwise products.

    ``fit`` selects feature pairs with ``compute_top_pairs``, appends their element-wise
    products as extra columns with ``make_X_with_pair_effect`` and then delegates to the
    parent ``fit``. The same columns are appended again before prediction.
    """

    def __init__(self, path=None, device="cpu", inference_device="cpu", model=None, config=None,
                 cat_features: List[int] = None, n_pair_feature_max_ratio: float = 0.9,
                 pair_strategy: str = "sum_importance"):
        """
        :param n_pair_feature_max_ratio: forwarded to ``compute_top_pairs``; a value of 0
          (or less) disables pair features entirely
        :param pair_strategy: either "sum_importance" or "fast"
        :raises ValueError: if ``pair_strategy`` is not one of the two supported values

        The remaining arguments are passed unchanged to the parent constructor.
        """
        super().__init__(
            path=path, device=device, inference_device=inference_device, model=model, config=config,
            cat_features=cat_features,
        )
        # Raise instead of `assert` so the check survives running Python with -O.
        if pair_strategy not in ("sum_importance", "fast"):
            raise ValueError(f"Unknown pair_strategy: {pair_strategy!r}, expected 'sum_importance' or 'fast'")
        self.n_pair_feature_max_ratio = n_pair_feature_max_ratio
        self.pair_strategy = pair_strategy

    def fit(self, X, y):
        """Select pairs, append their product columns to ``X`` and fit the parent model.

        Sets ``self.pairs_`` to the selected pairs (an empty list when pair features are
        disabled) and returns the result of the parent ``fit``.
        """
        if self.n_pair_feature_max_ratio > 0:
            # compute pairs according the selected strategy
            self.pairs_ = compute_top_pairs(
                X,
                y,
                pair_strategy=self.pair_strategy,
                n_pair_feature_max_ratio=self.n_pair_feature_max_ratio
            )
            print(f"Going to use {len(self.pairs_)} pairs ({X.shape[1]} features present)")

            # and generate features for those pairs
            X = make_X_with_pair_effect(X, self.pairs_)
        else:
            # Always define the fitted attribute so prediction never depends on which branch ran.
            self.pairs_ = []
        return super().fit(X, y)

    def predict_proba_with_additive_components(self, X):
        """Append the fitted pair columns to ``X`` and run the additive model on it.

        Returns whatever ``predict_with_additive_model`` returns for the fitted weights.
        """
        # Key on the pairs chosen during fit rather than on the hyperparameter: the latter can
        # be changed after fitting (e.g. via set_params) and would then yield a feature count
        # that no longer matches the fitted model.
        if self.pairs_:
            X = make_X_with_pair_effect(X, self.pairs_)

        return predict_with_additive_model(self.X_train_, X, self.w_, self.b_, self.bin_edges_, nan_bin=self.nan_bin,
                                           inference_device=self.inference_device, regression=False)
import numpy as np
import pytest

from mothernet.prediction.mothernet_additive import compute_top_pairs, MotherNetAdditiveClassifierPairEffects
from mothernet.utils import get_mn_model

# Shared synthetic dataset: 200 samples with 10 uniform features, seeded for reproducibility.
np.random.seed(0)
n = 200
d = 10
X = np.random.rand(n, d)
# continuous xor with X0 and X3: x0 + x3 - 2 * x0 * x3
# => requires pairwise effect and top pair should be (0, 3)
y = X[:, 0] + X[:, 3] - 2 * X[:, 0] * X[:, 3]
y = y > 0.5


@pytest.mark.parametrize("n_pair_feature_max_ratio", [0, 0.5, 1.0])
@pytest.mark.parametrize("pair_strategy", ["sum_importance", "fast"])
def test_pair(n_pair_feature_max_ratio: float, pair_strategy: str):
    """The requested number of pairs is returned and, if any, the best one is (0, 3)."""
    selected = compute_top_pairs(
        X,
        y,
        pair_strategy=pair_strategy,
        n_pair_feature_max_ratio=n_pair_feature_max_ratio,
    )
    expected_count = int(n_pair_feature_max_ratio * d)

    assert len(selected) == expected_count
    if selected:
        best_pair = selected[0]
        assert best_pair == (0, 3)


@pytest.mark.parametrize("n_pair_feature_max_ratio", [0, 0.1])
@pytest.mark.parametrize("pair_strategy", ["sum_importance", "fast"])
def test_estimator(n_pair_feature_max_ratio: float, pair_strategy: str):
    """The estimator fits and predicts; with pair features enabled it must exceed 90% train accuracy."""
    checkpoint_name = "baam_nsamples500_numfeatures10_04_07_2024_17_04_53_epoch_1780.cpkt"
    estimator = MotherNetAdditiveClassifierPairEffects(
        path=get_mn_model(checkpoint_name),
        pair_strategy=pair_strategy,
        n_pair_feature_max_ratio=n_pair_feature_max_ratio,
    )
    predictions = estimator.fit(X, y).predict(X)
    # Without pair features this only checks that fit/predict run; no accuracy is required.
    if not n_pair_feature_max_ratio > 0:
        return
    train_accuracy = (predictions == y).mean()
    assert train_accuracy > 0.9