diff --git a/act/optim.py b/act/optim.py
index 62f620a..c33f937 100644
--- a/act/optim.py
+++ b/act/optim.py
@@ -338,160 +338,84 @@ def update_param_vars(self) -> None:
         self.num_ampl = len(self.config["optimization_parameters"]["amps"])

-
-class GeneralACTOptimizer(ACTOptimizer):
-    def __init__(
-        self,
-        simulation_config: SimulationConfig,
-        logger: object = None,
-        set_passive_properties = True
-    ):
-        super().__init__(
-            simulation_config=simulation_config,
-            logger=logger,
-            set_passive_properties=set_passive_properties
-        )
-
-        self.model = None
-        self.model_pool = None
-        self.use_random_forest = False  # just for testing
-        self.reg = None  # regressor for random forest
-        self.init_random_forest()
-
-        self.voltage_data_scaler = TorchMinMaxScaler()
-        self.summary_feature_scaler = TorchMinMaxColScaler()
-
-        self.segregation_index = utils.get_segregation_index(simulation_config)
-        self.hto_block_channels = []
-
-    def init_random_forest(self):
-        params = {
-            "n_estimators": 5000,
-            # "max_depth": 32,
-            "min_samples_split": 2,
-            # "warm_start": True,
-            # "oob_score": True,
-            "random_state": 42,
-        }
-        self.reg = RandomForestRegressor(**params)
-
-    def train_random_forest(self, X_train, y_train, columns=[], evaluate=False) -> dict:
-        """
-        Returns the feature importances for stats storing
-        """
-        if evaluate:
-            print("Evaluating random forest")
-            # evaluate the model
-            cv = RepeatedKFold(n_splits=10, n_repeats=3, random_state=1)
-            n_scores = cross_val_score(
-                self.reg,
-                X_train,
-                y_train,
-                scoring="neg_mean_absolute_error",
-                cv=cv,
-                n_jobs=-1,
-                error_score="raise",
-            )
-            # report performance
-            print("MAE: %.6f (%.6f)" % (mean(n_scores), std(n_scores)))
-        print("fitting random forest")
-        self.reg.fit(X_train[:1000], y_train[:1000])
-        print("Done fitting random forest")
-        print("Feature importance")
-        if not columns:
-            columns = [f"feature_{i+1}" for i in range(X_train.shape[1])]
-        f = dict(zip(columns, np.around(self.reg.feature_importances_ * 100, 2)))
-        sf = {
-            k: v for k, v in sorted(f.items(), key=lambda item: item[1], reverse=True)
-        }
-        for k, v in sf.items():
-            print(k + " : " + str(v))
-        return sf
-
-    def predict_random_forest(self, X_test):
-        y_pred = self.reg.predict(X_test)
-        return y_pred
-
-    def optimize(self, target_V: torch.Tensor) -> torch.Tensor:
+    def load_params(self):
         # extract only traces that have spikes in them
-        spiking_only = True
-        nonsaturated_only = True
-
-        model_class = None
-        learning_rate = 0
-        weight_decay = 0
-        num_epochs = 0
-        use_spike_summary_stats = True
-        train_amplitude_frequency = False
-        train_mean_potential = False
-        segregation_arima_order = None
-        train_test_split = 0.85
-        summary_feature_columns = []
-        learned_variability = 0
-
-        inj_dur = self.config["simulation_parameters"].get("h_i_dur")
-        inj_start = self.config["simulation_parameters"].get("h_i_delay")
-        fs = (
+        self.spiking_only = True
+        self.nonsaturated_only = True
+
+        self.model_class = None
+        self.learning_rate = 0
+        self.weight_decay = 0
+        self.num_epochs = 0
+        self.use_spike_summary_stats = True
+        self.train_amplitude_frequency = False
+        self.train_mean_potential = False
+        self.segregation_arima_order = None
+        self.train_test_split = 0.85
+        self.summary_feature_columns = []
+        self.learned_variability = 0
+
+        self.inj_dur = self.config["simulation_parameters"].get("h_i_dur")
+        self.inj_start = self.config["simulation_parameters"].get("h_i_delay")
+        self.fs = (
             self.config["simulation_parameters"].get("h_dt")
             * self.config["optimization_parameters"].get("decimate_factor")
             * 1000
         )
         if self.config["run_mode"] == "segregated":
-            learning_rate = self.config["segregation"][self.segregation_index].get(
+            self.learning_rate = self.config["segregation"][self.segregation_index].get(
                 "learning_rate", 0
             )
-            weight_decay = self.config["segregation"][self.segregation_index].get(
+            self.weight_decay = self.config["segregation"][self.segregation_index].get(
                 "weight_decay", 0
             )
-            model_class = self.config["segregation"][self.segregation_index].get(
+            self.model_class = self.config["segregation"][self.segregation_index].get(
                 "model_class", None
             )
-            num_epochs = self.config["segregation"][self.segregation_index].get(
+            self.num_epochs = self.config["segregation"][self.segregation_index].get(
                 "num_epochs", 0
             )
-            spiking_only = self.config["segregation"][self.segregation_index].get(
+            self.spiking_only = self.config["segregation"][self.segregation_index].get(
                 "train_spiking_only", True
             )
-            nonsaturated_only = self.config["segregation"][self.segregation_index].get(
+            self.nonsaturated_only = self.config["segregation"][self.segregation_index].get(
                 "nonsaturated_only", True
             )
-            use_spike_summary_stats = self.config["segregation"][
+            self.use_spike_summary_stats = self.config["segregation"][
                 self.segregation_index
             ].get("use_spike_summary_stats", True)
-            train_amplitude_frequency = self.config["segregation"][
+            self.train_amplitude_frequency = self.config["segregation"][
                 self.segregation_index
             ].get("train_amplitude_frequency", False)
-            train_mean_potential = self.config["segregation"][
+            self.train_mean_potential = self.config["segregation"][
                 self.segregation_index
             ].get("train_mean_potential", False)
-            segregation_arima_order = self.config["segregation"][
+            self.segregation_arima_order = self.config["segregation"][
                 self.segregation_index
             ].get("arima_order", None)
-            train_test_split = self.config["segregation"][self.segregation_index].get(
+            self.train_test_split = self.config["segregation"][self.segregation_index].get(
                 "train_test_split", 0.99
             )
-            learned_variability = self.config["segregation"][
+            self.learned_variability = self.config["segregation"][
                 self.segregation_index
             ].get("learned_variability", 0)
-            inj_start = self.config["segregation"][self.segregation_index].get(
-                "h_i_delay", inj_start
+            self.inj_start = self.config["segregation"][self.segregation_index].get(
+                "h_i_delay", self.inj_start
             )
-            inj_dur = self.config["segregation"][self.segregation_index].get(
-                "h_i_dur", inj_dur
+            self.inj_dur = self.config["segregation"][self.segregation_index].get(
+                "h_i_dur", self.inj_dur
             )

-        if not num_epochs:
-            num_epochs = self.config["optimization_parameters"].get("num_epochs")
-
-        if self.config["optimization_parameters"].get("use_random_forest"):
-            model_class = "RandomForest"
+        if not self.num_epochs:
+            self.num_epochs = self.config["optimization_parameters"].get("num_epochs")

-        num_first_spikes = self.config["summary_features"].get("num_first_spikes", 20)
-        print(f"Extracting first {num_first_spikes} spikes for summary features")
+        self.num_first_spikes = self.config["summary_features"].get("num_first_spikes", 20)
+        print(f"Extracting first {self.num_first_spikes} spikes for summary features")

+    def load_voltage_traces(self, target_V):
         # Get voltage with characteristics similar to target
         if not self.config["optimization_parameters"]["skip_match_voltage"]:
             (
@@ -571,7 +495,7 @@ def optimize(self, target_V: torch.Tensor) -> torch.Tensor:
         else:
             print(f"Parametric distribution parameters not applied.")

-        if spiking_only:
+        if self.spiking_only:
             (
                 simulated_V_for_next_stage,
                 param_samples_for_next_stage,
@@ -582,7 +506,7 @@ def optimize(self, target_V: torch.Tensor) -> torch.Tensor:
                 param_samples_for_next_stage,
                 ampl_next_stage,
             )
-        if nonsaturated_only:
+        if self.nonsaturated_only:
             drop_dur = 200
             end_of_drop = 750
             start_of_drop = end_of_drop - drop_dur
@@ -605,13 +529,24 @@ def optimize(self, target_V: torch.Tensor) -> torch.Tensor:
             ]
             ampl_next_stage = ampl_next_stage[nonsaturated_ind]

+        # param_samples_for_next_stage is filtered above and consumed by
+        # optimize(), so it must be part of the return value
+        return (
+            simulated_V_for_next_stage,
+            param_samples_for_next_stage,
+            ampl_next_stage,
+            spiking_ind,
+            nonsaturated_ind,
+        )
+
+    def load_summary_features(
+        self, simulated_V_for_next_stage, spiking_ind, nonsaturated_ind
+    ):
         (
             num_spikes_simulated,
             simulated_interspike_times,
         ) = self.extract_summary_features(simulated_V_for_next_stage)
         # spike_stats
         (first_n_spikes, avg_spike_min, avg_spike_max) = utils.spike_stats(
-            simulated_V_for_next_stage, n_spikes=num_first_spikes
+            simulated_V_for_next_stage, n_spikes=self.num_first_spikes
         )
         coefs_loaded = False
         if os.path.exists("output/arima_stats.json"):
@@ -620,9 +555,9 @@ def optimize(self, target_V: torch.Tensor) -> torch.Tensor:
                 input_file="output/arima_stats.json"
             )  # [subset_target_ind] # TODO REMOVE for testing quickly
-            if spiking_only:
+            if self.spiking_only:
                 coefs = coefs[spiking_ind]
-            if nonsaturated_only:
+            if self.nonsaturated_only:
                 coefs = coefs[nonsaturated_ind]

             def generate_arima_columns(coefs):
@@ -644,7 +579,7 @@ def generate_arima_columns(coefs):
             summary_feature_columns.append("Avg Min Spike Height")
             summary_feature_columns.append("Avg Max Spike Height")

-            if use_spike_summary_stats:
+            if self.use_spike_summary_stats:
                 summary_features = torch.cat(
                     (summary_features.T, first_n_spikes, coefs), dim=1
                 )
@@ -657,7 +592,7 @@ def generate_arima_columns(coefs):
                 summary_features = coefs
                 summary_feature_columns = generate_arima_columns(coefs)
         else:
-            if use_spike_summary_stats:
+            if self.use_spike_summary_stats:
                 summary_features = torch.stack(
                     (
                         # ampl_next_stage,
@@ -677,9 +612,9 @@ def generate_arima_columns(coefs):
             for i in range(first_n_spikes.shape[1]):
                 summary_feature_columns.append(f"Spike {i+1} time")

-        if train_amplitude_frequency:
+        if self.train_amplitude_frequency:
             amplitude, frequency = utils.get_amplitude_frequency(
-                simulated_V_for_next_stage.float(), inj_dur, inj_start, fs=fs
+                simulated_V_for_next_stage.float(), self.inj_dur, self.inj_start, fs=self.fs
             )
             if summary_features is not None:
                 summary_features = torch.cat(
@@ -698,9 +633,9 @@ def generate_arima_columns(coefs):
                 "amplitude",
                 "frequency",
             ]
-        if train_mean_potential:
+        if self.train_mean_potential:
             mean_potential = utils.get_mean_potential(
-                simulated_V_for_next_stage.float(), inj_dur, inj_start
+                simulated_V_for_next_stage.float(), self.inj_dur, self.inj_start
             )
             if summary_features is not None:
                 summary_features = torch.cat(
@@ -723,53 +658,16 @@ def generate_arima_columns(coefs):
                 "You have to have some summary feature turned on (use_spike_summary_stats, train_amplitude_frequency, arima stats) or select a model that doesn't use them. Errors will occur"
             )

-        # make amp output a learned parameter
-        param_samples_for_next_stage = torch.cat(
-            (param_samples_for_next_stage, ampl_next_stage.reshape((-1, 1))), dim=1
-        )
-
-        self.model = self.init_nn_model(
-            in_channels=target_V.shape[1],
-            out_channels=self.num_params + 1,  # +1 to learn amp input
-            summary_features=summary_features,
-            model_class=model_class,
-        )
-
-        # Resample to match the length of target data
-        resampled_data = self.resample_voltage(
-            simulated_V_for_next_stage, target_V.shape[1]
-        )
-
-        # TODO THESE ARE NOT VALID WITH LEARNED PARAMS
-        lows = [p["low"] for p in self.config["optimization_parameters"]["params"]]
-        highs = [p["high"] for p in self.config["optimization_parameters"]["params"]]
-
-        lows.append(round(float(ampl_next_stage.min()), 4))
-        highs.append(round(float(ampl_next_stage.max()), 4))
-        # remove any remaining nan values
-        summary_features[torch.isnan(summary_features)] = 0
-
-        # Train model
-        train_stats = self.train_model(
-            resampled_data.float(),
-            param_samples_for_next_stage,
-            lows,
-            highs,
-            train_test_split=train_test_split,
-            summary_features=summary_features,
-            learning_rate=learning_rate,
-            weight_decay=weight_decay,
-            num_epochs=num_epochs,
-            summary_feature_columns=summary_feature_columns,
-        )
-
+        return summary_features, summary_feature_columns, coefs_loaded
+
+    def extract_target_v_summary_features(self, target_V):
         # Predict and take max across ci to prevent underestimating
         (
             num_spikes_simulated,
             simulated_interspike_times,
         ) = self.extract_summary_features(target_V.float())
         (first_n_spikes, avg_spike_min, avg_spike_max) = utils.spike_stats(
-            target_V.float(), n_spikes=num_first_spikes
+            target_V.float(), n_spikes=self.num_first_spikes
         )
         ampl_target = torch.tensor(self.config["optimization_parameters"]["amps"])
         target_summary_features = None
@@ -777,8 +675,8 @@ def generate_arima_columns(coefs):
         arima_order = (10, 0, 10)
         if self.config.get("summary_features", {}).get("arima_order"):
             arima_order = tuple(self.config["summary_features"]["arima_order"])
-        if segregation_arima_order:
-            arima_order = segregation_arima_order
+        if self.segregation_arima_order:
+            arima_order = self.segregation_arima_order
         print(f"ARIMA order set to {arima_order}")
         total_arima_vals = 2 + arima_order[0] + arima_order[1]
         coefs = []
@@ -800,14 +698,14 @@ def generate_arima_columns(coefs):
                     avg_spike_max.flatten().T,
                 )
             )
-            if use_spike_summary_stats:
+            if self.use_spike_summary_stats:
                 target_summary_features = torch.cat(
                     (target_summary_features.T, first_n_spikes, coefs), dim=1
                 )
             else:
                 target_summary_features = coefs
         else:
-            if use_spike_summary_stats:
+            if self.use_spike_summary_stats:
                 target_summary_features = torch.stack(
                     (
                         # ampl_target,
@@ -824,9 +722,9 @@ def generate_arima_columns(coefs):
                     ),
                     dim=1,
                 )
-        if train_amplitude_frequency:
+        if self.train_amplitude_frequency:
             target_amplitude, target_frequency = utils.get_amplitude_frequency(
-                target_V.float(), inj_dur, inj_start, fs=fs
+                target_V.float(), self.inj_dur, self.inj_start, fs=self.fs
             )
             if target_summary_features is not None:
                 target_summary_features = torch.cat(
@@ -842,9 +740,9 @@ def generate_arima_columns(coefs):
                 (target_amplitude.reshape(-1, 1), target_frequency.reshape(-1, 1)),
                 dim=1,
             )
-        if train_mean_potential:
+        if self.train_mean_potential:
             target_mean_potential = utils.get_mean_potential(
-                target_V.float(), inj_dur, inj_start
+                target_V.float(), self.inj_dur, self.inj_start
             )
             if target_summary_features is not None:
                 target_summary_features = torch.cat(
@@ -860,335 +758,77 @@ def generate_arima_columns(coefs):
         # remove any remaining nan values
         target_summary_features[torch.isnan(target_summary_features)] = 0

-        predictions = self.predict_with_model(
-            target_V.float(), lows, highs, target_summary_features.float()
-        )
-        # predictions = torch.max(predictions, dim=0).values
-
-        return predictions, train_stats
-
-    def init_nn_model(
-        self, in_channels: int, out_channels: int, summary_features, model_class=None
-    ) -> torch.nn.Sequential:
-        if model_class:
-            print(f"Overriding model class to {model_class}")
-            if model_class.lower() == "randomforest":
-                self.use_random_forest = True
-                return None
-            else:
-                ModelClass = eval(model_class)  # dangerous but ok
-        else:
-            print(f"Using ConvolutionEmbeddingNet for model class")
-            # ModelClass = SimpleNet
-            # ModelClass = BranchingNet
-            # ModelClass = EmbeddingNet
-            ModelClass = ConvolutionEmbeddingNet
-            # ModelClass = SummaryNet
-            # ModelClass = ConvolutionNet
-
-        model = ModelClass(in_channels, out_channels, summary_features)
-        return model
+        return target_summary_features
+
+    def get_parametric_distribution(self, n_slices, simulations_per_amp) -> tuple:
+        params = [
+            p["channel"] for p in self.config["optimization_parameters"]["params"]
+        ]
+        lows = [p["low"] for p in self.config["optimization_parameters"]["params"]]
+        highs = [p["high"] for p in self.config["optimization_parameters"]["params"]]
+        tstop_config = self.config["simulation_parameters"]["h_tstop"]
+        if (
+            self.config["run_mode"] == "segregated"
+        ):  # sometimes segregated modules have different params
+            tstop_config = self.config["segregation"][self.segregation_index].get(
+                "h_tstop", tstop_config
+            )

-    def train_model(
-        self,
-        voltage_data: torch.Tensor,
-        target_params: torch.Tensor,
-        lows,
-        highs,
-        summary_features,
-        summary_feature_columns=[],
-        train_test_split=0.85,
-        batch_size=8,
-        learning_rate=2e-5,
-        weight_decay=1e-4,
-        num_epochs=0,
-    ) -> None:
-        if not learning_rate:
-            learning_rate = 2e-5
-        if not weight_decay:
-            weight_decay = 1e-4
+        steps = int(tstop_config / self.config["simulation_parameters"]["h_dt"])

-        sigmoid_mins = torch.tensor(lows)
-        sigmoid_maxs = torch.tensor(highs)
+        param_samples_for_next_stage = torch.zeros(
+            (self.num_ampl, simulations_per_amp, self.num_params)
+        )
+        simulated_V = torch.zeros(
+            (
+                self.num_ampl,
+                simulations_per_amp,
+                steps,
+            )
+        )

-        stats = {
-            "train_loss_batches": [],
-            "train_loss": [],
-            "test_loss": [],
-            "train_size": 0,
-            "test_size": 0,
-            "feature_importance": {},
-        }
+        param_dist = np.array(
+            [
+                np.arange(low, high, (high - low) / n_slices)
+                for low, high in zip(lows, highs)
+            ]
+        ).T

-        # cut the target_params for segregation
-        if self.config["run_mode"] == "segregated":
-            if self.config["segregation"][self.segregation_index].get(
-                "use_hto_amps", False
-            ):
-                self.hto_block_channels = self.config["optimization_parameters"].get(
-                    "hto_block_channels", []
+        amps = self.config["optimization_parameters"]["amps"]
+        print(
+            f"Sampling parameter space... this may take a while. {len(amps)} amps * {simulations_per_amp} simulations per amp = {len(amps) * simulations_per_amp}"
+        )
+        s_amps = []
+        for amp_ind, amp in enumerate(amps):
+            print(
+                f"    Generating {simulations_per_amp} simulations per amp at {amp:.2f} amps."
+            )
+            for slice_ind in range(simulations_per_amp):
+                # For each current injection amplitude, sample random parameters
+                param_inds = np.random.randint(
+                    0, n_slices, len(params)
+                )  # get random indices for our params
+                param_sample = param_dist.T[
+                    range(len(params)), param_inds
+                ]  # select the params
+                simulated_V[amp_ind][slice_ind] = self.simulate(
+                    amp, params, param_sample
+                )
+                param_samples_for_next_stage[amp_ind][slice_ind] = torch.Tensor(
+                    param_sample
                 )
-            # get all the indicies that we want to keep
-            keep_ind = []
-            for i, param in enumerate(self.params):
-                if (
-                    param not in self.preset_params
-                    and param not in self.hto_block_channels
-                ):
-                    keep_ind.append(i)
-            print(f"Training target param indicies {keep_ind} only for segregation")
-            keep_ind.append(-1)  # we want to also keep the last element for amps
-            print(f"With amps {keep_ind}")
-            target_params = target_params[:, keep_ind]
-            sigmoid_mins = sigmoid_mins[keep_ind]
-            sigmoid_maxs = sigmoid_maxs[keep_ind]
-
-        # shuffle the training data
-        indexes = torch.randperm(voltage_data.shape[0])
-        split_point = int(voltage_data.shape[0] * train_test_split)
-        train_ind = indexes[:split_point]
-        test_ind = indexes[split_point:]
-        stats["train_size"] = len(train_ind)
-        stats["test_size"] = len(test_ind)
+                s_amps.append(amp)

-        voltage_data_train = voltage_data[train_ind]
-        voltage_data_test = voltage_data[test_ind]
+        s_v = torch.flatten(simulated_V).reshape(
+            [len(amps) * simulations_per_amp, steps]
+        )
+        s_param = torch.flatten(param_samples_for_next_stage).reshape(
+            [len(amps) * simulations_per_amp, len(params)]
+        )
+        s_amps = torch.tensor(s_amps)

-        summary_features_train = summary_features[train_ind]
-        summary_features_test = summary_features[test_ind]
-
-        target_params_train = target_params[train_ind]
-        target_params_test = target_params[test_ind]
-
-        # Fit the training data, transform both train and test.
-        # The fit is not applied to original dataset due to the possibility of data leakage
-        self.voltage_data_scaler.fit(voltage_data_train)
-        self.summary_feature_scaler.fit(summary_features_train)
-
-        voltage_data_train = self.voltage_data_scaler.transform(voltage_data_train)
-        voltage_data_test = self.voltage_data_scaler.transform(voltage_data_test)
-        summary_features_train = self.summary_feature_scaler.transform(
-            summary_features_train
-        )
-        summary_features_test = self.summary_feature_scaler.transform(
-            summary_features_test
-        )
-
-        if self.use_random_forest:  # use the random forest
-            stats["feature_importance"] = self.train_random_forest(
-                summary_features_train.cpu().numpy(),
-                target_params_train.cpu().numpy(),
-                columns=summary_feature_columns,
-            )
-        else:
-            optim = torch.optim.Adam(
-                self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
-            )
-            loss_fn = torch.nn.MSELoss()  # torch.nn.functional.l1_loss
-
-            self.logger.info(
-                f"Training a model with {optim} optimizer | lr = {learning_rate} | weight_decay = {weight_decay}."
-            )
-            self.logger.info(
-                f"Number of trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}"
-            )
-
-            batch_start = torch.arange(0, len(voltage_data_train), batch_size)
-
-            # Hold the best model
-            best_mse = np.inf  # init to infinity
-            best_weights = None
-
-            for epoch in range(num_epochs):
-                self.model.train()
-                with tqdm.tqdm(
-                    batch_start, unit="batch", mininterval=0, disable=False
-                ) as bar:
-                    bar.set_description(f"Epoch {epoch}/{num_epochs}")
-                    for start in bar:
-                        voltage_data_batch = voltage_data_train[
-                            start : start + batch_size
-                        ]
-                        summary_features_batch = summary_features_train[
-                            start : start + batch_size
-                        ]
-                        target_params_batch = target_params_train[
-                            start : start + batch_size
-                        ]
-                        # forward pass
-                        pred = (
-                            self.model(voltage_data_batch, summary_features_batch)
-                            * (sigmoid_maxs - sigmoid_mins)
-                            + sigmoid_mins
-                        )
-                        loss = loss_fn(pred, target_params_batch)
-                        stats["train_loss_batches"].append(
-                            float(loss.cpu().detach().numpy())
-                        )
-
-                        # backward pass
-                        optim.zero_grad()  # this line is new, wasn't in last round
-                        loss.backward()
-
-                        # update weights
-                        optim.step()
-
-                        # print process
-                        bar.set_postfix(mse=float(loss))
-                # evaluate accuracy at end of each epoch
-                self.model.eval()
-                y_out = self.model(voltage_data_train, summary_features_train)
-                y_pred = y_out * (sigmoid_maxs - sigmoid_mins) + sigmoid_mins
-                target_params_train_norm = (target_params_train - sigmoid_mins) / (
-                    sigmoid_maxs - sigmoid_mins
-                )
-                mse = loss_fn(y_pred, target_params_train)
-                # mse = loss_fn(y_out, target_params_train_norm)
-                mse = float(mse)
-                stats["train_loss"].append(mse)
-
-                y_out = self.model(voltage_data_test, summary_features_test)
-                y_pred = y_out * (sigmoid_maxs - sigmoid_mins) + sigmoid_mins
-                target_params_test_norm = (target_params_test - sigmoid_mins) / (
-                    sigmoid_maxs - sigmoid_mins
-                )
-                mse = loss_fn(y_pred, target_params_test)
-                # mse = loss_fn(y_out, target_params_test_norm)
-                mse = float(mse)
-                stats["test_loss"].append(mse)
-                if mse < best_mse:
-                    best_mse = mse
-                    best_weights = copy.deepcopy(self.model.state_dict())
-
-            # restore model and return best accuracy
-            self.model.load_state_dict(best_weights)
-
-        return stats
-
-    def predict_with_model(
-        self, target_V: torch.Tensor, lows, highs, summary_features
-    ) -> torch.Tensor:
-        sigmoid_mins = torch.tensor(lows)
-        sigmoid_maxs = torch.tensor(highs)
-
-        if self.config["run_mode"] == "segregated":
-            output_ind = []  # these are the indices that the network returned
-            for i, param in enumerate(self.params):
-                if (
-                    param not in self.preset_params
-                    and param not in self.hto_block_channels
-                ):
-                    output_ind.append(i)
-            output_ind.append(-1)
-            sigmoid_mins = sigmoid_mins[output_ind]
-            sigmoid_maxs = sigmoid_maxs[output_ind]
-
-        ret = None
-        if self.use_random_forest:  # use random forest
-            ret = torch.tensor(
-                self.predict_random_forest(summary_features.cpu().numpy())
-            ).float()
-        else:
-            self.model.eval()
-            outs = []
-            target_V_fit = self.voltage_data_scaler.transform(target_V)
-            summary_features_fit = self.summary_feature_scaler.transform(
-                summary_features
-            )
-            for i in range(target_V.shape[0]):
-                out = (
-                    self.model(
-                        target_V_fit[i].reshape(1, -1),
-                        summary_features_fit[i].reshape(1, -1),
-                    )
-                    * (sigmoid_maxs - sigmoid_mins)
-                    + sigmoid_mins
-                )
-                outs.append(out.reshape(1, -1))
-
-            ret = torch.cat(outs, dim=0)
-
-        # return with preset params
-        if self.config["run_mode"] == "segregated":
-            seg_ret = torch.zeros((ret.shape[0], len(self.params) + 1))
-            seg_ret[:, output_ind] = ret
-            for param_ind, param in enumerate(self.params):
-                if param in self.preset_params:
-                    seg_ret[:, param_ind] = self.preset_params[param]
-            return seg_ret
-        else:
-            return ret
-
-    def get_parametric_distribution(self, n_slices, simulations_per_amp) -> tuple:
-        params = [
-            p["channel"] for p in self.config["optimization_parameters"]["params"]
-        ]
-        lows = [p["low"] for p in self.config["optimization_parameters"]["params"]]
-        highs = [p["high"] for p in self.config["optimization_parameters"]["params"]]
-        tstop_config = self.config["simulation_parameters"]["h_tstop"]
-        if (
-            self.config["run_mode"] == "segregated"
-        ):  # sometimes segregated modules have different params
-            tstop_config = self.config["segregation"][self.segregation_index].get(
-                "h_tstop", tstop_config
-            )
-
-        steps = int(tstop_config / self.config["simulation_parameters"]["h_dt"])
-
-        param_samples_for_next_stage = torch.zeros(
-            (self.num_ampl, simulations_per_amp, self.num_params)
-        )
-        simulated_V = torch.zeros(
-            (
-                self.num_ampl,
-                simulations_per_amp,
-                steps,
-            )
-        )
-
-        param_dist = np.array(
-            [
-                np.arange(low, high, (high - low) / n_slices)
-                for low, high in zip(lows, highs)
-            ]
-        ).T
-
-        amps = self.config["optimization_parameters"]["amps"]
-        print(
-            f"Sampling parameter space... this may take a while. {len(amps)} amps * {simulations_per_amp} simulations per amp = {len(amps) * simulations_per_amp}"
-        )
-        s_amps = []
-        for amp_ind, amp in enumerate(amps):
-            print(
-                f"    Generating {simulations_per_amp} simulations_per_amp at {amp:.2f} amps."
-            )
-            for slice_ind in range(simulations_per_amp):
-                # For each current injection amplitude, sample random parameters
-                param_inds = np.random.randint(
-                    0, n_slices, len(params)
-                )  # get random indices for our params
-                param_sample = param_dist.T[
-                    range(len(params)), param_inds
-                ]  # select the params
-                simulated_V[amp_ind][slice_ind] = self.simulate(
-                    amp, params, param_sample
-                )
-                param_samples_for_next_stage[amp_ind][slice_ind] = torch.Tensor(
-                    param_sample
-                )
-
-                s_amps.append(amp)
-
-        s_v = torch.flatten(simulated_V).reshape(
-            [len(amps) * simulations_per_amp, steps]
-        )
-        s_param = torch.flatten(param_samples_for_next_stage).reshape(
-            [len(amps) * simulations_per_amp, len(params)]
-        )
-        s_amps = torch.tensor(s_amps)
-
-        return s_v, s_param, s_amps
+        return s_v, s_param, s_amps

     def match_voltage(self, target_V: torch.Tensor) -> tuple:
         # Get target voltage summary features
@@ -1317,3 +957,584 @@ def extract_summary_features(self, V: torch.Tensor) -> tuple:
         interspike_times[torch.isnan(interspike_times)] = 0

         return num_spikes, interspike_times
+
+class RandomForestOptimizer(ACTOptimizer):
+    def __init__(
+        self,
+        simulation_config: SimulationConfig,
+        logger: object = None,
+        set_passive_properties=True,
+    ):
+        super().__init__(
+            simulation_config=simulation_config,
+            logger=logger,
+            set_passive_properties=set_passive_properties,
+        )
+
+        self.model = None
+        self.model_pool = None
+        self.reg = None  # regressor for random forest
+        self.init_random_forest()
+
+        self.voltage_data_scaler = TorchMinMaxScaler()
+        self.summary_feature_scaler = TorchMinMaxColScaler()
+
+        self.segregation_index = utils.get_segregation_index(simulation_config)
+        self.hto_block_channels = []
+
+    def init_random_forest(self):
+        params = {
+            "n_estimators": 5000,
+            # "max_depth": 32,
+            "min_samples_split": 2,
+            # "warm_start": True,
+            # "oob_score": True,
+            "random_state": 42,
+        }
+        self.reg = RandomForestRegressor(**params)
+
+    def train_random_forest(self, X_train, y_train, columns=[], evaluate=False) -> dict:
+        """
+        Returns the feature importances for stats storage
+        """
+        if evaluate:
+            print("Evaluating random forest")
+            # evaluate the model
+            cv = RepeatedKFold(n_splits=10, n_repeats=3, random_state=1)
+            n_scores = cross_val_score(
+                self.reg,
+                X_train,
+                y_train,
+                scoring="neg_mean_absolute_error",
+                cv=cv,
+                n_jobs=-1,
+                error_score="raise",
+            )
+            # report performance
+            print("MAE: %.6f (%.6f)" % (mean(n_scores), std(n_scores)))
+        print("Fitting random forest")
+        # note: fits on at most the first 1000 samples
+        self.reg.fit(X_train[:1000], y_train[:1000])
+        print("Done fitting random forest")
+        print("Feature importance")
+        if not columns:
+            columns = [f"feature_{i+1}" for i in range(X_train.shape[1])]
+        f = dict(zip(columns, np.around(self.reg.feature_importances_ * 100, 2)))
+        sf = {
+            k: v for k, v in sorted(f.items(), key=lambda item: item[1], reverse=True)
+        }
+        for k, v in sf.items():
+            print(k + " : " + str(v))
+        return sf
+
+    def predict_random_forest(self, X_test):
+        y_pred = self.reg.predict(X_test)
+        return y_pred
+
+    def optimize(self, target_V: torch.Tensor) -> tuple:
+        self.load_params()
+
+        (
+            simulated_V_for_next_stage,
+            param_samples_for_next_stage,
+            ampl_next_stage,
+            spiking_ind,
+            nonsaturated_ind,
+        ) = self.load_voltage_traces(target_V)
+
+        summary_features, summary_feature_columns, coefs_loaded = self.load_summary_features(
+            simulated_V_for_next_stage,
+            spiking_ind,
+            nonsaturated_ind,
+        )
+
+        # make amp output a learned parameter (target params)
+        param_samples_for_next_stage = torch.cat(
+            (param_samples_for_next_stage, ampl_next_stage.reshape((-1, 1))), dim=1
+        )
+
+        # Resample to match the length of target data
+        resampled_data = self.resample_voltage(
+            simulated_V_for_next_stage, target_V.shape[1]
+        )
+
+        lows = [p["low"] for p in self.config["optimization_parameters"]["params"]]
+        highs = [p["high"] for p in self.config["optimization_parameters"]["params"]]
+
+        lows.append(round(float(ampl_next_stage.min()), 4))
+        highs.append(round(float(ampl_next_stage.max()), 4))
+        # remove any remaining nan values
+        summary_features[torch.isnan(summary_features)] = 0
+
+        # Train model
+        train_stats = self.train_model(
+            resampled_data.float(),
+            param_samples_for_next_stage,
+            lows,
+            highs,
+            train_test_split=self.train_test_split,
+            summary_features=summary_features,
+            learning_rate=self.learning_rate,
+            weight_decay=self.weight_decay,
+            num_epochs=self.num_epochs,
+            summary_feature_columns=summary_feature_columns,
+        )
+
+        target_summary_features = self.extract_target_v_summary_features(target_V)
+
+        predictions = self.predict_with_model(
+            target_V.float(), lows, highs, target_summary_features.float()
+        )
+        # predictions = torch.max(predictions, dim=0).values
+
+        return predictions, train_stats
+
+    def train_model(
+        self,
+        voltage_data: torch.Tensor,
+        target_params: torch.Tensor,
+        lows,
+        highs,
+        summary_features,
+        summary_feature_columns=[],
+        train_test_split=0.85,
+        batch_size=8,
+        learning_rate=2e-5,
+        weight_decay=1e-4,
+        num_epochs=0,
+    ) -> dict:
+        if not learning_rate:
+            learning_rate = 2e-5
+        if not weight_decay:
+            weight_decay = 1e-4
+
+        sigmoid_mins = torch.tensor(lows)
+        sigmoid_maxs = torch.tensor(highs)
+
+        stats = {
+            "train_loss_batches": [],
+            "train_loss": [],
+            "test_loss": [],
+            "train_size": 0,
+            "test_size": 0,
+            "feature_importance": {},
+        }
+
+        # cut the target_params for segregation
+        if self.config["run_mode"] == "segregated":
+            if self.config["segregation"][self.segregation_index].get(
+                "use_hto_amps", False
+            ):
+                self.hto_block_channels = self.config["optimization_parameters"].get(
+                    "hto_block_channels", []
+                )
+            # get all the indices that we want to keep
+            keep_ind = []
+            for i, param in enumerate(self.params):
+                if (
+                    param not in self.preset_params
+                    and param not in self.hto_block_channels
+                ):
+                    keep_ind.append(i)
+            print(f"Training target param indices {keep_ind} only for segregation")
+            keep_ind.append(-1)  # we want to also keep the last element for amps
+            print(f"With amps {keep_ind}")
+            target_params = target_params[:, keep_ind]
+            sigmoid_mins = sigmoid_mins[keep_ind]
+            sigmoid_maxs = sigmoid_maxs[keep_ind]
+
+        # shuffle the training data
+        indexes = torch.randperm(voltage_data.shape[0])
+        split_point = int(voltage_data.shape[0] * train_test_split)
+
+        train_ind = indexes[:split_point]
+        test_ind = indexes[split_point:]
+        stats["train_size"] = len(train_ind)
+        stats["test_size"] = len(test_ind)
+
+        voltage_data_train = voltage_data[train_ind]
+        voltage_data_test = voltage_data[test_ind]
+
+        summary_features_train = summary_features[train_ind]
+        summary_features_test = summary_features[test_ind]
+
+        target_params_train = target_params[train_ind]
+        target_params_test = target_params[test_ind]
+
+        # Fit the training data, transform both train and test.
+        # The fit is not applied to the original dataset due to the possibility of data leakage
+        self.voltage_data_scaler.fit(voltage_data_train)
+        self.summary_feature_scaler.fit(summary_features_train)
+
+        voltage_data_train = self.voltage_data_scaler.transform(voltage_data_train)
+        voltage_data_test = self.voltage_data_scaler.transform(voltage_data_test)
+        summary_features_train = self.summary_feature_scaler.transform(
+            summary_features_train
+        )
+        summary_features_test = self.summary_feature_scaler.transform(
+            summary_features_test
+        )
+
+        stats["feature_importance"] = self.train_random_forest(
+            summary_features_train.cpu().numpy(),
+            target_params_train.cpu().numpy(),
+            columns=summary_feature_columns,
+        )
+
+        return stats
+
+    def predict_with_model(
+        self, target_V: torch.Tensor, lows, highs, summary_features
+    ) -> torch.Tensor:
+        sigmoid_mins = torch.tensor(lows)
+        sigmoid_maxs = torch.tensor(highs)
+
+        if self.config["run_mode"] == "segregated":
+            output_ind = []  # these are the indices that the model returned
+            for i, param in enumerate(self.params):
+                if (
+                    param not in self.preset_params
+                    and param not in self.hto_block_channels
+                ):
+                    output_ind.append(i)
+            output_ind.append(-1)
+            sigmoid_mins = sigmoid_mins[output_ind]
+            sigmoid_maxs = sigmoid_maxs[output_ind]
+
+        ret = torch.tensor(
+            self.predict_random_forest(summary_features.cpu().numpy())
+        ).float()
+
+        # return with preset params
+        if self.config["run_mode"] == "segregated":
+            seg_ret = torch.zeros((ret.shape[0], len(self.params) + 1))
+            seg_ret[:, output_ind] = ret
+            for param_ind, param in enumerate(self.params):
+                if param in self.preset_params:
+                    seg_ret[:, param_ind] = self.preset_params[param]
+            return seg_ret
+        else:
+            return ret
+
+
+class TCNNOptimizer(ACTOptimizer):
+    def __init__(
+        self,
+        simulation_config: SimulationConfig,
+        logger: object = None,
+        set_passive_properties=True,
+    ):
+        super().__init__(
+            simulation_config=simulation_config,
+            logger=logger,
+            set_passive_properties=set_passive_properties,
+        )
+
+        self.model = None
+        self.model_pool = None
+
+        self.voltage_data_scaler = TorchMinMaxScaler()
+        self.summary_feature_scaler = TorchMinMaxColScaler()
+
+        self.segregation_index = utils.get_segregation_index(simulation_config)
+        self.hto_block_channels = []
+
+    def optimize(self, target_V: torch.Tensor) -> tuple:
+        self.load_params()
+
+        (
+            simulated_V_for_next_stage,
+            param_samples_for_next_stage,
+            ampl_next_stage,
+            spiking_ind,
+            nonsaturated_ind,
+        ) = self.load_voltage_traces(target_V)
+
+        summary_features, summary_feature_columns, coefs_loaded = self.load_summary_features(
+            simulated_V_for_next_stage,
+            spiking_ind,
+            nonsaturated_ind,
+        )
+
+        # make amp output a learned parameter (target params)
+        param_samples_for_next_stage = torch.cat(
+            (param_samples_for_next_stage, ampl_next_stage.reshape((-1, 1))), dim=1
+        )
+
+        self.model = self.init_nn_model(
+            in_channels=target_V.shape[1],
+            out_channels=self.num_params + 1,  # +1 to learn amp input
+            summary_features=summary_features,
+            model_class=self.model_class,
+        )
+
+        # Resample to match the length of target data
+        resampled_data = self.resample_voltage(
+            simulated_V_for_next_stage, target_V.shape[1]
+        )
+
+        lows = [p["low"] for p in self.config["optimization_parameters"]["params"]]
+        highs = [p["high"] for p in self.config["optimization_parameters"]["params"]]
+
+        lows.append(round(float(ampl_next_stage.min()), 4))
+        highs.append(round(float(ampl_next_stage.max()), 4))
+        # remove any remaining nan values
+        summary_features[torch.isnan(summary_features)] = 0
+
+        # Train model
+        train_stats = self.train_model(
+            resampled_data.float(),
+            param_samples_for_next_stage,
+            lows,
+            highs,
+            train_test_split=self.train_test_split,
+            summary_features=summary_features,
+            learning_rate=self.learning_rate,
+            weight_decay=self.weight_decay,
+            num_epochs=self.num_epochs,
+            summary_feature_columns=summary_feature_columns,
+        )
+
+        target_summary_features = self.extract_target_v_summary_features(target_V)
+
+        predictions = self.predict_with_model(
+            target_V.float(), lows, highs, target_summary_features.float()
+        )
+        # predictions = torch.max(predictions, dim=0).values
+
+        return predictions, train_stats
+
+    def init_nn_model(
+        self, in_channels: int, out_channels: int, summary_features, model_class=None
+    ) -> torch.nn.Sequential:
+        if model_class:
+            print(f"Overriding model class to {model_class}")
+            ModelClass = eval(model_class)  # dangerous but ok
+        else:
+            print("Using ConvolutionEmbeddingNet for model class")
+            # ModelClass = SimpleNet
+            # ModelClass = BranchingNet
+            # ModelClass = EmbeddingNet
+            ModelClass = ConvolutionEmbeddingNet
+            # ModelClass = SummaryNet
+            # ModelClass = ConvolutionNet
+
+        model = ModelClass(in_channels, out_channels, summary_features)
+        return model
+
+    def train_model(
+        self,
+        voltage_data: torch.Tensor,
+        target_params: torch.Tensor,
+        lows,
+        highs,
+        summary_features,
+        summary_feature_columns=[],
+        train_test_split=0.85,
+        batch_size=8,
+        learning_rate=2e-5,
+        weight_decay=1e-4,
+        num_epochs=0,
+    ) -> dict:
+        if not learning_rate:
+            learning_rate = 2e-5
+        if not weight_decay:
+            weight_decay = 1e-4
+
+        sigmoid_mins = torch.tensor(lows)
+        sigmoid_maxs = torch.tensor(highs)
+
+        stats = {
+            "train_loss_batches": [],
+            "train_loss": [],
+            "test_loss": [],
+            "train_size": 0,
+            "test_size": 0,
+            "feature_importance": {},
+        }
+
+        # cut the target_params for segregation
+        if self.config["run_mode"] == "segregated":
+            if self.config["segregation"][self.segregation_index].get(
+                "use_hto_amps", False
+            ):
+                self.hto_block_channels = self.config["optimization_parameters"].get(
+                    "hto_block_channels", []
+                )
+            # get all the indices that we want to keep
+            keep_ind = []
+            for i, param in enumerate(self.params):
+                if (
+                    param not in self.preset_params
+                    and param not in self.hto_block_channels
+                ):
+                    keep_ind.append(i)
+            print(f"Training target param indices {keep_ind} only for segregation")
+            keep_ind.append(-1)  # we want to also keep the last element for amps
+            print(f"With amps {keep_ind}")
+            target_params = target_params[:, keep_ind]
+            sigmoid_mins = sigmoid_mins[keep_ind]
+            sigmoid_maxs = sigmoid_maxs[keep_ind]
+
+        # shuffle the training data
+        indexes = torch.randperm(voltage_data.shape[0])
+        split_point = int(voltage_data.shape[0] * train_test_split)
+
+        train_ind = indexes[:split_point]
+        test_ind = indexes[split_point:]
+        stats["train_size"] = len(train_ind)
+        stats["test_size"] = len(test_ind)
+
+        voltage_data_train = voltage_data[train_ind]
+        voltage_data_test = voltage_data[test_ind]
+
+        summary_features_train = summary_features[train_ind]
+        summary_features_test = summary_features[test_ind]
+
+        target_params_train = target_params[train_ind]
+        target_params_test = target_params[test_ind]
+
+        # Fit the training data, transform both train and test.
+        # The fit is not applied to the original dataset due to the possibility of data leakage
+        self.voltage_data_scaler.fit(voltage_data_train)
+        self.summary_feature_scaler.fit(summary_features_train)
+
+        voltage_data_train = self.voltage_data_scaler.transform(voltage_data_train)
+        voltage_data_test = self.voltage_data_scaler.transform(voltage_data_test)
+        summary_features_train = self.summary_feature_scaler.transform(
+            summary_features_train
+        )
+        summary_features_test = self.summary_feature_scaler.transform(
+            summary_features_test
+        )
+
+        optim = torch.optim.Adam(
+            self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
+        )
+        loss_fn = torch.nn.MSELoss()  # torch.nn.functional.l1_loss
+
+        self.logger.info(
+            f"Training a model with {optim} optimizer | lr = {learning_rate} | weight_decay = {weight_decay}."
+        )
+        self.logger.info(
+            f"Number of trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}"
+        )
+
+        batch_start = torch.arange(0, len(voltage_data_train), batch_size)
+
+        # Hold the best model
+        best_mse = np.inf  # init to infinity
+        best_weights = None
+
+        for epoch in range(num_epochs):
+            self.model.train()
+            with tqdm.tqdm(
+                batch_start, unit="batch", mininterval=0, disable=False
+            ) as bar:
+                bar.set_description(f"Epoch {epoch}/{num_epochs}")
+                for start in bar:
+                    voltage_data_batch = voltage_data_train[
+                        start : start + batch_size
+                    ]
+                    summary_features_batch = summary_features_train[
+                        start : start + batch_size
+                    ]
+                    target_params_batch = target_params_train[
+                        start : start + batch_size
+                    ]
+                    # forward pass
+                    pred = (
+                        self.model(voltage_data_batch, summary_features_batch)
+                        * (sigmoid_maxs - sigmoid_mins)
+                        + sigmoid_mins
+                    )
+                    loss = loss_fn(pred, target_params_batch)
+                    stats["train_loss_batches"].append(
+                        float(loss.cpu().detach().numpy())
+                    )
+
+                    # backward pass
+                    optim.zero_grad()  # reset gradients accumulated on the previous batch
+                    loss.backward()
+
+                    # update weights
+                    optim.step()
+
+                    # print progress
+                    bar.set_postfix(mse=float(loss))
+            # evaluate accuracy at end of each epoch
+            self.model.eval()
+            y_out = self.model(voltage_data_train, summary_features_train)
+            y_pred = y_out * (sigmoid_maxs - sigmoid_mins) + sigmoid_mins
+            target_params_train_norm = (target_params_train - sigmoid_mins) / (
+                sigmoid_maxs - sigmoid_mins
+            )
+            mse = loss_fn(y_pred, target_params_train)
+            # mse = loss_fn(y_out, target_params_train_norm)
+            mse = float(mse)
+            stats["train_loss"].append(mse)
+
+            y_out = self.model(voltage_data_test, summary_features_test)
+            y_pred = y_out * (sigmoid_maxs - sigmoid_mins) + sigmoid_mins
+            target_params_test_norm = (target_params_test - sigmoid_mins) / (
+                sigmoid_maxs - sigmoid_mins
+            )
+            mse = loss_fn(y_pred, target_params_test)
+            # mse = loss_fn(y_out, target_params_test_norm)
+            mse = float(mse)
+            stats["test_loss"].append(mse)
+            if mse < best_mse:
+                best_mse = mse
+                best_weights = copy.deepcopy(self.model.state_dict())
+
+        # restore model and return best accuracy
+        self.model.load_state_dict(best_weights)
+
+        return stats
+
+    def predict_with_model(
+        self, target_V: torch.Tensor, lows, highs, summary_features
+    ) -> torch.Tensor:
+        sigmoid_mins = torch.tensor(lows)
+        sigmoid_maxs = torch.tensor(highs)
+
+        if self.config["run_mode"] == "segregated":
+            output_ind = []  # these are the indices that the network returned
+            for i, param in enumerate(self.params):
+                if (
+                    param not in self.preset_params
+                    and param not in self.hto_block_channels
+                ):
+                    output_ind.append(i)
+            output_ind.append(-1)
+            sigmoid_mins = sigmoid_mins[output_ind]
+            sigmoid_maxs = sigmoid_maxs[output_ind]
+
+        self.model.eval()
+        outs = []
+        target_V_fit = self.voltage_data_scaler.transform(target_V)
+        summary_features_fit = self.summary_feature_scaler.transform(
+            summary_features
+        )
+        for i in range(target_V.shape[0]):
+            out = (
+                self.model(
+                    target_V_fit[i].reshape(1, -1),
+                    summary_features_fit[i].reshape(1, -1),
+                )
+                * (sigmoid_maxs - sigmoid_mins)
+                + sigmoid_mins
+            )
+            outs.append(out.reshape(1, -1))
+
+        ret = torch.cat(outs, dim=0)
+
+        # return with preset params
+        if self.config["run_mode"] == "segregated":
+            seg_ret = torch.zeros((ret.shape[0], len(self.params) + 1))
+            seg_ret[:, output_ind] = ret
+            for param_ind, param in enumerate(self.params):
+                if param in self.preset_params:
+                    seg_ret[:, param_ind] = self.preset_params[param]
+            return seg_ret
+        else:
+            return ret
diff --git a/act/simulator.py b/act/simulator.py
index c8373a6..6d1d304 100644
--- a/act/simulator.py
+++ b/act/simulator.py
@@ -14,7 +14,7 @@ from act.analysis import save_mse_corr, save_plot, save_prediction_plots
 from act.logger import ACTLogger
 from act.metrics import correlation_score, mse_score
-from act.optim import GeneralACTOptimizer
+from act.optim import RandomForestOptimizer, TCNNOptimizer
 from act.target_utils import (
     get_voltage_trace_from_params,
     save_target_traces,
@@ -148,12 +148,15 @@ def _run(config: SimulationConfig):
     params = [p["channel"] for p in config["optimization_parameters"]["params"]]
     for repeat_num in range(config["optimization_parameters"]["num_repeats"]):
         if config["run_mode"] == "original" or config["run_mode"] == "segregated":
-            optim = GeneralACTOptimizer(simulation_config=config, logger=logger, set_passive_properties=not ltohto)
+            if config["optimization_parameters"].get("use_random_forest"):
+                optim = RandomForestOptimizer(simulation_config=config, logger=logger, set_passive_properties=not ltohto)
+            else:
+                optim = TCNNOptimizer(simulation_config=config, logger=logger, set_passive_properties=not ltohto)
             predictions, train_stats = optim.optimize(target_V)
             predictions_amps = predictions[:, -1].reshape(-1, 1)
             predictions = predictions[:, :-1]
         # elif config["run_mode"] == "segregated":
-        #     optim = GeneralACTOptimizer(simulation_config=config, logger=logger)
+        #     optim = RandomForestOptimizer(simulation_config=config, logger=logger)
         #     predictions, train_stats = optim.optimize_with_segregation(
         #         target_V, "voltage"
         #     )