From 831e8094711c4508e1821818955dc80983fc8bec Mon Sep 17 00:00:00 2001 From: juliensiems Date: Sat, 20 Apr 2024 20:51:31 +0100 Subject: [PATCH 1/6] Current state of the per dataset categorical sampling. --- .../eval_model_on_step_function_prior.py | 11 +++++++--- mothernet/priors/classification_adapter.py | 20 +++++++++++++++---- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/mothernet/datasets/eval_model_on_step_function_prior.py b/mothernet/datasets/eval_model_on_step_function_prior.py index 90109dcc..492ea596 100644 --- a/mothernet/datasets/eval_model_on_step_function_prior.py +++ b/mothernet/datasets/eval_model_on_step_function_prior.py @@ -1,7 +1,9 @@ import matplotlib.pyplot as plt import numpy as np import torch +from lightning import seed_everything from sklearn.model_selection import train_test_split +from sklearn.metrics import log_loss from mothernet.evaluation.concurvity import pairwise from mothernet.prediction import MotherNetAdditiveClassifier @@ -27,15 +29,16 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray): def eval_step_function(): + seed_everything(42) step_function_prior = StepFunctionPrior({'max_steps': 1, 'sampling': 'uniform'}) - X, y, step_function = step_function_prior._get_batch(batch_size=1, n_samples=500, num_features=2) - X = X.squeeze().numpy() + X, y, step_function, step, mask = step_function_prior._get_batch(batch_size=1, n_samples=500, num_features=2) + X = X.squeeze(1).numpy() y = y.squeeze().numpy() # Plot the shape function here fig, ax = plt.subplots(ncols=2, sharey=True) ax[0].plot(X[:, 0], step_function[0, :, 0], 'o') - ax[1].plot(X[:, 1], step_function[0, :, 1], 'o') + # ax[1].plot(X[:, 1], step_function[0, :, 1], 'o') ax[0].set_xlabel('Feature 0') ax[1].set_xlabel('Feature 1') ax[0].set_ylabel('y') @@ -57,7 +60,9 @@ def eval_step_function(): conc = pairwise(torch.from_numpy(np.stack([additive_comp[0][:, 1], additive_comp[1][:, 1]])), kind='corr', eps=1e-12) print(f'Concurvity: {conc:.3f}') + print('Model', model_string) assert (prob.argmax(axis=1) == classifier.predict(X_test)).all() + print('Cross Entropy:', log_loss(classifier.predict(X_test), y_test)) assert classifier.score(X_test, y_test) > 0.9 diff --git a/mothernet/priors/classification_adapter.py b/mothernet/priors/classification_adapter.py index 3928bcf0..e9505016 100644 --- a/mothernet/priors/classification_adapter.py +++ b/mothernet/priors/classification_adapter.py @@ -138,14 +138,26 @@ def __call__(self, batch_size, n_samples, num_features, device, epoch=None, sing else: x = self.drop_for_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan'])) - # Categorical features - categorical_features = [] - if random.random() < self.h['categorical_feature_p']: + # Categorical features (random.gammavariate equivalent is torch Gamma but with inverse scale) + per_dataset_cat_features = torch.distributions.Gamma( + 1, 1.0 / 10).sample(sample_shape=(x.shape[1], x.shape[2])).round().clamp(min=2).to(x.device) + # Mask out categorical features randomly as before + per_dataset_cat_features *= (torch.rand(size=(x.shape[1], x.shape[2]), device=x.device) < self.h['categorical_feature_p']).to(torch.float) + + class_boundaries = torch.randint( + 0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max()))) + classes = torch.searchsorted(class_boundaries, x.permute(1, 2, 0)) + class_assignment = torch.remainder( + classes, torch.where(per_dataset_cat_features == 0, torch.inf, per_dataset_cat_features).unsqueeze(-1)) + + # Only overwrite the categorical features. + x[:, per_dataset_cat_features > 0] = class_assignment.permute(2, 0, 1)[:, per_dataset_cat_features > 0] + if True: p = random.random() for col in range(x.shape[2]): num_unique_features = max(round(random.gammavariate(1, 10)), 2) m = MulticlassRank(num_unique_features, ordered_p=0.3) - if random.random() < p: + if True: categorical_features.append(col) x[:, :, col] = m(x[:, :, col]) From 0c9ec132e0f8d40f50542f74224d0ab1cb4ddaa9 Mon Sep 17 00:00:00 2001 From: juliensiems Date: Sun, 21 Apr 2024 14:36:37 +0100 Subject: [PATCH 2/6] Updated version of per dataset categorical features. --- mothernet/priors/classification_adapter.py | 15 +++++---------- mothernet/utils.py | 11 +++++++---- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/mothernet/priors/classification_adapter.py b/mothernet/priors/classification_adapter.py index e9505016..ccc5e6fe 100644 --- a/mothernet/priors/classification_adapter.py +++ b/mothernet/priors/classification_adapter.py @@ -142,24 +142,19 @@ def __call__(self, batch_size, n_samples, num_features, device, epoch=None, sing per_dataset_cat_features = torch.distributions.Gamma( 1, 1.0 / 10).sample(sample_shape=(x.shape[1], x.shape[2])).round().clamp(min=2).to(x.device) # Mask out categorical features randomly as before - per_dataset_cat_features *= (torch.rand(size=(x.shape[1], x.shape[2]), device=x.device) < self.h['categorical_feature_p']).to(torch.float) + per_dataset_cat_features *= ( + torch.rand(size=(x.shape[1], x.shape[2]), device=x.device) < self.h['categorical_feature_p'] + ).to(torch.float) class_boundaries = torch.randint( 0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max()))) - classes = torch.searchsorted(class_boundaries, x.permute(1, 2, 0)) + classes = torch.searchsorted(class_boundaries.contiguous(), x.permute(1, 2, 0).contiguous()) class_assignment = torch.remainder( classes, torch.where(per_dataset_cat_features == 0, torch.inf, per_dataset_cat_features).unsqueeze(-1)) # Only overwrite the categorical features. x[:, per_dataset_cat_features > 0] = class_assignment.permute(2, 0, 1)[:, per_dataset_cat_features > 0] - if True: - p = random.random() - for col in range(x.shape[2]): - num_unique_features = max(round(random.gammavariate(1, 10)), 2) - m = MulticlassRank(num_unique_features, ordered_p=0.3) - if True: - categorical_features.append(col) - x[:, :, col] = m(x[:, :, col]) + categorical_features = per_dataset_cat_features > 0 x = remove_outliers(x, categorical_features=categorical_features) x, y = normalize_data(x), normalize_data(y) diff --git a/mothernet/utils.py b/mothernet/utils.py index f1830598..b06d95ed 100644 --- a/mothernet/utils.py +++ b/mothernet/utils.py @@ -125,9 +125,12 @@ def remove_outliers(X, n_sigma=4, normalize_positions=-1, categorical_features=N # Expects T, B, H assert len(X.shape) == 3, "X must be T,B,H" - if categorical_features: - categorical_mask = torch.zeros(X.shape[2], dtype=torch.bool, device=X.device) - categorical_mask.scatter_(0, torch.tensor(categorical_features, device=X.device, dtype=int), 1.) + if categorical_features is not None: + if isinstance(categorical_features, list): + categorical_mask = torch.zeros(X.shape[2], dtype=torch.bool, device=X.device) + categorical_mask.scatter_(0, torch.tensor(categorical_features, device=X.device, dtype=int), 1.) + elif isinstance(categorical_features, torch.Tensor): + categorical_mask = categorical_features data = X if normalize_positions == -1 else X[:normalize_positions] @@ -141,7 +144,7 @@ def remove_outliers(X, n_sigma=4, normalize_positions=-1, categorical_features=N cut_off = data_std * n_sigma lower, upper = data_mean - cut_off, data_mean + cut_off - if categorical_features: + if categorical_features is not None: X = torch.where(categorical_mask, X, torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)) X = torch.where(categorical_mask, X, torch.minimum(torch.log(1+torch.abs(X)) + upper, X)) else: From 961170bd709dc7f227d50a1d67b1b739b13149d4 Mon Sep 17 00:00:00 2001 From: juliensiems Date: Sun, 21 Apr 2024 14:41:31 +0100 Subject: [PATCH 3/6] revert --- .../datasets/eval_model_on_step_function_prior.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/mothernet/datasets/eval_model_on_step_function_prior.py b/mothernet/datasets/eval_model_on_step_function_prior.py index 492ea596..643b9100 100644 --- a/mothernet/datasets/eval_model_on_step_function_prior.py +++ b/mothernet/datasets/eval_model_on_step_function_prior.py @@ -1,9 +1,7 @@ import matplotlib.pyplot as plt import numpy as np import torch -from lightning import seed_everything from sklearn.model_selection import train_test_split -from sklearn.metrics import log_loss from mothernet.evaluation.concurvity import pairwise from mothernet.prediction import MotherNetAdditiveClassifier @@ -29,16 +27,15 @@ def plot_shape_function(bin_edges: np.ndarray, w: np.ndarray): def eval_step_function(): - seed_everything(42) step_function_prior = StepFunctionPrior({'max_steps': 1, 'sampling': 'uniform'}) - X, y, step_function, step, mask = step_function_prior._get_batch(batch_size=1, n_samples=500, num_features=2) - X = X.squeeze(1).numpy() + X, y, step_function = step_function_prior._get_batch(batch_size=1, n_samples=500, num_features=2) + X = X.squeeze().numpy() y = y.squeeze().numpy() # Plot the shape function here fig, ax = plt.subplots(ncols=2, sharey=True) ax[0].plot(X[:, 0], step_function[0, :, 0], 'o') - # ax[1].plot(X[:, 1], step_function[0, :, 1], 'o') + ax[1].plot(X[:, 1], step_function[0, :, 1], 'o') ax[0].set_xlabel('Feature 0') ax[1].set_xlabel('Feature 1') ax[0].set_ylabel('y') @@ -60,10 +57,8 @@ def eval_step_function(): conc = pairwise(torch.from_numpy(np.stack([additive_comp[0][:, 1], additive_comp[1][:, 1]])), kind='corr', eps=1e-12) print(f'Concurvity: {conc:.3f}') - print('Model', model_string) assert (prob.argmax(axis=1) == classifier.predict(X_test)).all() - print('Cross Entropy:', log_loss(classifier.predict(X_test), y_test)) assert classifier.score(X_test, y_test) > 0.9 -eval_step_function() +eval_step_function() \ No newline at end of file From 04a872b1c77b070e18d26100c382996df992f9d7 Mon Sep 17 00:00:00 2001 From: juliensiems Date: Sun, 21 Apr 2024 14:41:54 +0100 Subject: [PATCH 4/6] revert --- mothernet/datasets/eval_model_on_step_function_prior.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mothernet/datasets/eval_model_on_step_function_prior.py b/mothernet/datasets/eval_model_on_step_function_prior.py index 643b9100..a1af151d 100644 --- a/mothernet/datasets/eval_model_on_step_function_prior.py +++ b/mothernet/datasets/eval_model_on_step_function_prior.py @@ -60,5 +60,3 @@ def eval_step_function(): assert (prob.argmax(axis=1) == classifier.predict(X_test)).all() assert classifier.score(X_test, y_test) > 0.9 - -eval_step_function() \ No newline at end of file From 5fe6c6ef4bb4e1bcac988614a621d213bd81d50f Mon Sep 17 00:00:00 2001 From: juliensiems Date: Thu, 2 May 2024 10:40:12 +0200 Subject: [PATCH 5/6] Add CLI option for per_dataset_categorical. --- mothernet/cli_parsing.py | 1 + mothernet/model_configs.py | 1 + mothernet/priors/classification_adapter.py | 48 ++++++++++++++-------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/mothernet/cli_parsing.py b/mothernet/cli_parsing.py index 33134c7c..efdbca0d 100644 --- a/mothernet/cli_parsing.py +++ b/mothernet/cli_parsing.py @@ -158,6 +158,7 @@ def argparser_from_config(parser, description="Train Mothernet"): classification_prior.add_argument('--multiclass-type', help="Which multiclass prior to use ['steps', 'rank'].", type=str) classification_prior.add_argument('--multiclass-max-steps', help="Maximum number of steps in multiclass step prior", type=int) classification_prior.add_argument('--pad-zeros', help="Whether to pad data with zeros for consistent size", type=str2bool) + classification_prior.add_argument('--per-dataset-categorical', help="Per dataset categorical features", type=str2bool) classification_prior.add_argument('--feature-curriculum', help="Whether to use a curriculum for number of features", type=str2bool) classification_prior.set_defaults(**config['prior']['classification']) diff --git a/mothernet/model_configs.py b/mothernet/model_configs.py index 7d13c867..6a02cf03 100644 --- a/mothernet/model_configs.py +++ b/mothernet/model_configs.py @@ -106,6 +106,7 @@ def get_prior_config(max_features=100, n_samples=1024+128): 'output_multiclass_ordered_p': 0., 'multiclass_max_steps': 10, "multiclass_type": 'rank', + "per_dataset_categorical": False, 'categorical_feature_p': .2, # diff: .0 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, diff --git a/mothernet/priors/classification_adapter.py b/mothernet/priors/classification_adapter.py index ccc5e6fe..322dfb9d 100644 --- a/mothernet/priors/classification_adapter.py +++ b/mothernet/priors/classification_adapter.py @@ -84,6 +84,7 @@ def __init__(self, base_prior, config): self.h = sample_distributions(parse_distributions(config)) self.base_prior = base_prior + self.per_feature_categorical = config['per_dataset_categorical'] if self.h['num_classes'] == 0: self.class_assigner = RegressionNormalized() else: @@ -138,24 +139,35 @@ def __call__(self, batch_size, n_samples, num_features, device, epoch=None, sing else: x = self.drop_for_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan'])) - # Categorical features (random.gammavariate equivalent is torch Gamma but with inverse scale) - per_dataset_cat_features = torch.distributions.Gamma( - 1, 1.0 / 10).sample(sample_shape=(x.shape[1], x.shape[2])).round().clamp(min=2).to(x.device) - # Mask out categorical features randomly as before - per_dataset_cat_features *= ( - torch.rand(size=(x.shape[1], x.shape[2]), device=x.device) < self.h['categorical_feature_p'] - ).to(torch.float) - - class_boundaries = torch.randint( - 0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max()))) - classes = torch.searchsorted(class_boundaries.contiguous(), x.permute(1, 2, 0).contiguous()) - class_assignment = torch.remainder( - classes, torch.where(per_dataset_cat_features == 0, torch.inf, per_dataset_cat_features).unsqueeze(-1)) - - # Only overwrite the categorical features. - x[:, per_dataset_cat_features > 0] = class_assignment.permute(2, 0, 1)[:, per_dataset_cat_features > 0] - categorical_features = per_dataset_cat_features > 0 - + if self.per_feature_categorical: + # Categorical features (random.gammavariate equivalent is torch Gamma but with inverse scale) + per_dataset_cat_features = torch.distributions.Gamma( + 1, 1.0 / 10).sample(sample_shape=(x.shape[1], x.shape[2])).round().clamp(min=2).to(x.device) + # Mask out categorical features randomly as before + per_dataset_cat_features *= ( + torch.rand(size=(x.shape[1], x.shape[2]), device=x.device) < self.h['categorical_feature_p'] + ).to(torch.float) + + class_boundaries = torch.randint( + 0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max()))) + classes = torch.searchsorted(class_boundaries.contiguous(), x.permute(1, 2, 0).contiguous()) + class_assignment = torch.remainder( + classes, torch.where(per_dataset_cat_features == 0, torch.inf, per_dataset_cat_features).unsqueeze(-1)) + + # Only overwrite the categorical features. + x[:, per_dataset_cat_features > 0] = class_assignment.permute(2, 0, 1)[:, per_dataset_cat_features > 0] + categorical_features = per_dataset_cat_features > 0 + else: + # Categorical features + categorical_features = [] + if random.random() < self.h['categorical_feature_p']: + p = random.random() + for col in range(x.shape[2]): + num_unique_features = max(round(random.gammavariate(1, 10)), 2) + m = MulticlassRank(num_unique_features, ordered_p=0.3) + if random.random() < p: + categorical_features.append(col) + x[:, :, col] = m(x[:, :, col]) x = remove_outliers(x, categorical_features=categorical_features) x, y = normalize_data(x), normalize_data(y) From 7fbbdd6965b1499fe5708e0ed3ba01d830a8c636 Mon Sep 17 00:00:00 2001 From: juliensiems Date: Fri, 3 May 2024 17:03:52 +0200 Subject: [PATCH 6/6] Add benchmarking script to compare the runtime of the previous categorical mapping and the new one. --- .../benchmark_classification_adapter.py | 46 +++++++++++++++++++ mothernet/priors/classification_adapter.py | 2 +- 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 mothernet/evaluation/benchmark_classification_adapter.py diff --git a/mothernet/evaluation/benchmark_classification_adapter.py b/mothernet/evaluation/benchmark_classification_adapter.py new file mode 100644 index 00000000..545dca8c --- /dev/null +++ b/mothernet/evaluation/benchmark_classification_adapter.py @@ -0,0 +1,46 @@ +import json +import time +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import torch +from tqdm import tqdm + +from mothernet import priors +from mothernet.model_configs import get_prior_config +from mothernet.priors import ClassificationAdapterPrior + +prior_config = get_prior_config() +prior_config['prior']['classification']['num_classes'] = 10 +prior_config['prior']['classification']['categorical_feature_p'] = 1.0 + +results = defaultdict(list) +for device in ['cpu', 'cuda']: + for per_dataset_categorical in [True, False]: + prior_config['prior']['classification']['per_dataset_categorical'] = per_dataset_categorical + for batch_size in tqdm(np.linspace(2, 128, 10)): + batch_size = int(batch_size) + for _ in range(10): + start = time.time() + ClassificationAdapterPrior( + priors.MLPPrior(prior_config['prior']['mlp']), + num_features=prior_config['prior']['num_features'], device=device, + **prior_config['prior']['classification'] + ).get_batch(batch_size=batch_size, n_samples=500, num_features=64, + device=device, epoch=None, single_eval_pos=None) + end = time.time() + results['per_dataset_categorical'].append(per_dataset_categorical) + results['batch_size'].append(batch_size) + results['time'].append(end - start) + + with open(f'benchmark_classification_adapter_{device}.json', 'w') as f: + json.dump(results, f) + +for device in ['cpu', 'cuda']: + with open(f'benchmark_classification_adapter_{device}.json', 'r') as f: + results = json.load(f) + sns.lineplot(x='batch_size', y='time', hue='per_dataset_categorical', data=results) + plt.title('Device ' + device) + plt.show() diff --git a/mothernet/priors/classification_adapter.py b/mothernet/priors/classification_adapter.py index 322dfb9d..912357b0 100644 --- a/mothernet/priors/classification_adapter.py +++ b/mothernet/priors/classification_adapter.py @@ -149,7 +149,7 @@ def __call__(self, batch_size, n_samples, num_features, device, epoch=None, sing ).to(torch.float) class_boundaries = torch.randint( - 0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max()))) + 0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max())), device=x.device) classes = torch.searchsorted(class_boundaries.contiguous(), x.permute(1, 2, 0).contiguous()) class_assignment = torch.remainder( classes, torch.where(per_dataset_cat_features == 0, torch.inf, per_dataset_cat_features).unsqueeze(-1))