Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mothernet/cli_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def argparser_from_config(parser, description="Train Mothernet"):
classification_prior.add_argument('--multiclass-type', help="Which multiclass prior to use ['steps', 'rank'].", type=str)
classification_prior.add_argument('--multiclass-max-steps', help="Maximum number of steps in multiclass step prior", type=int)
classification_prior.add_argument('--pad-zeros', help="Whether to pad data with zeros for consistent size", type=str2bool)
classification_prior.add_argument('--per-dataset-categorical', help="Per dataset categorical features", type=str2bool)
classification_prior.add_argument('--feature-curriculum', help="Whether to use a curriculum for number of features", type=str2bool)
classification_prior.set_defaults(**config['prior']['classification'])

Expand Down
2 changes: 0 additions & 2 deletions mothernet/datasets/eval_model_on_step_function_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,3 @@ def eval_step_function():
assert (prob.argmax(axis=1) == classifier.predict(X_test)).all()
assert classifier.score(X_test, y_test) > 0.9


eval_step_function()
46 changes: 46 additions & 0 deletions mothernet/evaluation/benchmark_classification_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Benchmark ClassificationAdapterPrior batch generation.

Measures wall-clock time of `get_batch` for per-dataset vs. per-feature
categorical sampling across a range of batch sizes, on CPU and (when
available) CUDA. One JSON result file is written per device, then the
timings are plotted with seaborn.
"""
import json
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from tqdm import tqdm

from mothernet import priors
from mothernet.model_configs import get_prior_config
from mothernet.priors import ClassificationAdapterPrior

prior_config = get_prior_config()
prior_config['prior']['classification']['num_classes'] = 10
prior_config['prior']['classification']['categorical_feature_p'] = 1.0

# Skip CUDA gracefully on machines without a GPU instead of crashing.
devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])

for device in devices:
    # BUG FIX: reset the accumulator for every device. Previously a single
    # defaultdict was shared across the device loop, so the 'cuda' JSON file
    # also contained every 'cpu' measurement.
    results = defaultdict(list)
    for per_dataset_categorical in [True, False]:
        prior_config['prior']['classification']['per_dataset_categorical'] = per_dataset_categorical
        for batch_size in tqdm(np.linspace(2, 128, 10)):
            batch_size = int(batch_size)
            for _ in range(10):  # 10 repetitions per configuration
                # perf_counter is monotonic and higher-resolution than time.time().
                start = time.perf_counter()
                ClassificationAdapterPrior(
                    priors.MLPPrior(prior_config['prior']['mlp']),
                    num_features=prior_config['prior']['num_features'], device=device,
                    **prior_config['prior']['classification']
                ).get_batch(batch_size=batch_size, n_samples=500, num_features=64,
                            device=device, epoch=None, single_eval_pos=None)
                if device == 'cuda':
                    # CUDA kernel launches are asynchronous; synchronize so the
                    # measurement covers the actual work, not just the launches.
                    torch.cuda.synchronize()
                end = time.perf_counter()
                results['per_dataset_categorical'].append(per_dataset_categorical)
                results['batch_size'].append(batch_size)
                results['time'].append(end - start)

    with open(f'benchmark_classification_adapter_{device}.json', 'w') as f:
        json.dump(results, f)

# Plot only the devices that were actually benchmarked above, so we never
# try to read a result file that was not written.
for device in devices:
    with open(f'benchmark_classification_adapter_{device}.json', 'r') as f:
        results = json.load(f)
    sns.lineplot(x='batch_size', y='time', hue='per_dataset_categorical', data=results)
    plt.title('Device ' + device)
    plt.show()
1 change: 1 addition & 0 deletions mothernet/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def get_prior_config(max_features=100, n_samples=1024+128):
'output_multiclass_ordered_p': 0.,
'multiclass_max_steps': 10,
"multiclass_type": 'rank',
"per_dataset_categorical": False,
'categorical_feature_p': .2, # diff: .0
'nan_prob_no_reason': 0.0,
'nan_prob_unknown_reason': 0.0,
Expand Down
41 changes: 30 additions & 11 deletions mothernet/priors/classification_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, base_prior, config):
self.h = sample_distributions(parse_distributions(config))

self.base_prior = base_prior
self.per_feature_categorical = config['per_dataset_categorical']
if self.h['num_classes'] == 0:
self.class_assigner = RegressionNormalized()
else:
Expand Down Expand Up @@ -138,17 +139,35 @@ def __call__(self, batch_size, n_samples, num_features, device, epoch=None, sing
else:
x = self.drop_for_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))

# Categorical features
categorical_features = []
if random.random() < self.h['categorical_feature_p']:
p = random.random()
for col in range(x.shape[2]):
num_unique_features = max(round(random.gammavariate(1, 10)), 2)
m = MulticlassRank(num_unique_features, ordered_p=0.3)
if random.random() < p:
categorical_features.append(col)
x[:, :, col] = m(x[:, :, col])

if self.per_feature_categorical:
# Categorical features (random.gammavariate equivalent is torch Gamma but with inverse scale)
per_dataset_cat_features = torch.distributions.Gamma(
1, 1.0 / 10).sample(sample_shape=(x.shape[1], x.shape[2])).round().clamp(min=2).to(x.device)
# Mask out categorical features randomly as before
per_dataset_cat_features *= (
torch.rand(size=(x.shape[1], x.shape[2]), device=x.device) < self.h['categorical_feature_p']
).to(torch.float)

class_boundaries = torch.randint(
0, x.shape[0], (x.shape[1], x.shape[2], int(per_dataset_cat_features.max())), device=x.device)
classes = torch.searchsorted(class_boundaries.contiguous(), x.permute(1, 2, 0).contiguous())
class_assignment = torch.remainder(
classes, torch.where(per_dataset_cat_features == 0, torch.inf, per_dataset_cat_features).unsqueeze(-1))

# Only overwrite the categorical features.
x[:, per_dataset_cat_features > 0] = class_assignment.permute(2, 0, 1)[:, per_dataset_cat_features > 0]
categorical_features = per_dataset_cat_features > 0
else:
# Categorical features
categorical_features = []
if random.random() < self.h['categorical_feature_p']:
p = random.random()
for col in range(x.shape[2]):
num_unique_features = max(round(random.gammavariate(1, 10)), 2)
m = MulticlassRank(num_unique_features, ordered_p=0.3)
if random.random() < p:
categorical_features.append(col)
x[:, :, col] = m(x[:, :, col])
x = remove_outliers(x, categorical_features=categorical_features)
x, y = normalize_data(x), normalize_data(y)

Expand Down
11 changes: 7 additions & 4 deletions mothernet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,12 @@ def remove_outliers(X, n_sigma=4, normalize_positions=-1, categorical_features=N
# Expects T, B, H
assert len(X.shape) == 3, "X must be T,B,H"

if categorical_features:
categorical_mask = torch.zeros(X.shape[2], dtype=torch.bool, device=X.device)
categorical_mask.scatter_(0, torch.tensor(categorical_features, device=X.device, dtype=int), 1.)
if categorical_features is not None:
if isinstance(categorical_features, list):
categorical_mask = torch.zeros(X.shape[2], dtype=torch.bool, device=X.device)
categorical_mask.scatter_(0, torch.tensor(categorical_features, device=X.device, dtype=int), 1.)
elif isinstance(categorical_features, torch.Tensor):
categorical_mask = categorical_features

data = X if normalize_positions == -1 else X[:normalize_positions]

Expand All @@ -141,7 +144,7 @@ def remove_outliers(X, n_sigma=4, normalize_positions=-1, categorical_features=N
cut_off = data_std * n_sigma
lower, upper = data_mean - cut_off, data_mean + cut_off

if categorical_features:
if categorical_features is not None:
X = torch.where(categorical_mask, X, torch.maximum(-torch.log(1+torch.abs(X)) + lower, X))
X = torch.where(categorical_mask, X, torch.minimum(torch.log(1+torch.abs(X)) + upper, X))
else:
Expand Down