diff --git a/n3fit/src/n3fit/backends/keras_backend/callbacks.py b/n3fit/src/n3fit/backends/keras_backend/callbacks.py index 8dc6bbee48..aa851ab036 100644 --- a/n3fit/src/n3fit/backends/keras_backend/callbacks.py +++ b/n3fit/src/n3fit/backends/keras_backend/callbacks.py @@ -13,6 +13,7 @@ """ import logging +from pathlib import Path from time import time from keras import backend as K @@ -196,6 +197,64 @@ def on_step_end(self, epoch, logs=None): self._update_weights() +class StoreCallback(CallbackStep): + """ + Given one path per replica (``replica_paths``), the callback will store the model parameters in + the corresponding directory every ``check_freq`` epochs. + + Parameters + ---------- + pdf_model: MetaModel + The multi-replica PDF model + replica_paths: list[Path] + One path per replica. Parameters are saved under <path>/parameters/. + check_freq: int + Save every this many epochs (default: 100) + """ + + def __init__(self, pdf_model, replica_paths, stopping_object, check_freq=100): + super().__init__() + self.check_freq = check_freq + self.pdf_model = pdf_model + self.weight_dirs = [] + self.stopping_object = stopping_object + for path in replica_paths: + weight_dir = path / "parameters" + weight_dir.mkdir(parents=True, exist_ok=True) + self.weight_dirs.append(weight_dir) + + def _save_weights(self, epoch, tr_weights, weight_dir): + filepath = weight_dir / f"params_{epoch}.npz" + # save parameters as expected by colibri + trainable_weights_flat = np.concatenate([np.asarray(w).flatten() for w in tr_weights]) + np.savez(filepath, params=trainable_weights_flat) + log.info(f"Saved parameters at epoch {epoch} in {filepath}") + + def on_step_end(self, epoch, logs=None): + """Function to be called at the end of every epoch + Every ``check_freq`` number of epochs, the parameters of the model will + be stored in the indicated directory.
+ """ + if ((epoch + 1) % self.check_freq) == 0: + pdf_replicas = self.pdf_model.split_replicas() + for replica_model, weight_dir in zip(pdf_replicas, self.weight_dirs): + weights = replica_model.trainable_weights + self._save_weights(epoch + 1, weights, weight_dir) + + def on_train_end(self, logs=None): + """Store the best parameters""" + for idx, weight_dir in enumerate(self.weight_dirs): + weights = self.stopping_object._best_weights[idx] + if weights is not None: + best_weights = weights['all_NNs'] + best_epoch = self.stopping_object._best_epochs[idx] + self._save_weights(best_epoch, best_weights, weight_dir) + else: + log.warning( + f"No best weights found for replica {idx+1}, skipping saving best parameters." + ) + + def gen_tensorboard_callback(log_dir, profiling=False, histogram_freq=0): """ Generate tensorboard logging details at ``log_dir``. diff --git a/n3fit/src/n3fit/io/writer.py b/n3fit/src/n3fit/io/writer.py index c6d02e9569..45cc053a49 100644 --- a/n3fit/src/n3fit/io/writer.py +++ b/n3fit/src/n3fit/io/writer.py @@ -303,6 +303,11 @@ def _write_metadata_json(self, i, out_path): # Note: the 2 arguments below are the same for all replicas, unless run separately timing=self.timings, stop_epoch=self.stopping_object.stop_epoch, + would_stop_epoch=( + self.stopping_object.would_stop_epoch + if self.stopping_object._dont_stop + else self.stopping_object.stop_epoch + ), ) with open(out_path, "w", encoding="utf-8") as fs: @@ -347,6 +352,7 @@ def jsonfit( true_chi2, stop_epoch, timing, + would_stop_epoch, ): """Generates a dictionary containing all relevant metadata for the fit @@ -372,6 +378,8 @@ def jsonfit( epoch at which the stopping stopped (not the one for the best fit!) 
timing: dict dictionary of the timing of the different events that happened + would_stop_epoch: int + epoch at which the stopping would have stopped if it were not set to "dont_stop" """ all_info = {} # Generate preprocessing information @@ -386,6 +394,7 @@ def jsonfit( all_info["arc_lengths"] = arc_lengths all_info["integrability"] = integrability_numbers all_info["timing"] = timing + all_info["would_stop_epoch"] = would_stop_epoch # Versioning info all_info["version"] = version() return all_info diff --git a/n3fit/src/n3fit/model_trainer.py b/n3fit/src/n3fit/model_trainer.py index 9917e22c89..0b45eb2c61 100644 --- a/n3fit/src/n3fit/model_trainer.py +++ b/n3fit/src/n3fit/model_trainer.py @@ -12,6 +12,7 @@ from collections import namedtuple from itertools import zip_longest import logging +import pickle import numpy as np @@ -112,6 +113,10 @@ def __init__( theoryid=None, lux_params=None, replicas=None, + save_checkpoints=False, + replica_path=None, + checkpoint_freq=100, + dont_stop=False, ): """ Parameters @@ -152,6 +157,15 @@ def __init__( if not give, the photon is not generated replicas: list list with the replicas ids to be fitted + save_checkpoints: bool + whether to save checkpoints (i.e. model parameters) during the fit. This requires + `replica_path` to be set as well. Not doing this will raise an error. + replica_path: Path + root path for all replicas. + checkpoint_freq: int + frequency (in epochs) at which to save checkpoints. Only relevant if `save_checkpoints` is True. + dont_stop: bool + whether to disable the stopping mechanism, i.e. 
to run for all epochs regardless of the validation chi2 """ # Save all input information self.exp_info = list(exp_info) @@ -168,6 +182,14 @@ def __init__( self.lux_params = lux_params self.replicas = replicas self.experiments_data = experiments_data + self.dont_stop = dont_stop + + # Checkpointing options + self.save_checkpoints = save_checkpoints + self.replica_path = replica_path + self.checkpoint_freq = checkpoint_freq + if self.save_checkpoints and self.replica_path is None: + raise ValueError("To save checkpoints, the 'replica_path' key must be set as well.") # Initialise internal variables which define behaviour if debug: @@ -721,11 +743,24 @@ def _train_and_fit(self, training_model, stopping_object, epochs=100) -> bool: self.training["integmultipliers"], update_freq=PUSH_INTEGRABILITY_EACH, ) + callback_list = [callback_st, callback_pos, callback_integ] + + if self.save_checkpoints: + pdf_model = training_model.get_layer("PDFs") + # Save parameters where colibri will look for checkpoints + replica_paths = [ + self.replica_path.parent / f"fit_replicas/replica_{r}" for r in self.replicas + ] + checpoint_callback = callbacks.StoreCallback( + pdf_model=pdf_model, + replica_paths=replica_paths, + check_freq=self.checkpoint_freq, + stopping_object=stopping_object, + ) + callback_list.append(checpoint_callback) training_model.perform_fit( - epochs=epochs, - verbose=False, - callbacks=self.callbacks + [callback_st, callback_pos, callback_integ], + epochs=epochs, verbose=False, callbacks=self.callbacks + callback_list ) def _hyperopt_override(self, params): @@ -921,6 +956,26 @@ def hyperparametrizable(self, params): ) replicas_settings.append(tmp) + # TODO: temporary fix to use NTK utilities in colibri + # Create model pkl for colibri n3fit module (NOTE(review): max_replica uses [0] too; confirm) + _init_args = { + "flav_info": self.flavinfo, + "replica_range_settings": { + "min_replica": np.sort(self.replicas)[0], + "max_replica": np.sort(self.replicas)[0], + }, + "impose_sumrule": self.impose_sumrule,
"fitbasis": self.fitbasis, + "nodes": params["nodes_per_layer"], + "activations": params["activation_per_layer"], + "initializer_name": params["initializer"], + "layer_type": params["layer_type"], + } + state = {"_init_args": _init_args} + + with open(self.replica_path.parent / "pdf_model.pkl", "wb") as file: + pickle.dump(state, file) + ### Training loop for k, partition in enumerate(self.kpartitions): @@ -987,6 +1042,7 @@ def hyperparametrizable(self, params): stopping_patience=stopping_epochs, threshold_positivity=threshold_pos, threshold_chi2=threshold_chi2, + dont_stop=self.dont_stop, ) # Compile each of the models with the right parameters diff --git a/n3fit/src/n3fit/performfit.py b/n3fit/src/n3fit/performfit.py index 558e9690b6..af9edcc2d6 100644 --- a/n3fit/src/n3fit/performfit.py +++ b/n3fit/src/n3fit/performfit.py @@ -42,6 +42,9 @@ def performfit( maxcores=None, double_precision=False, parallel_models=True, + save_checkpoints=False, + checkpoint_freq=100, + dont_stop=False, ): """ This action will (upon having read a validcard) process a full PDF fit @@ -128,6 +131,8 @@ def performfit( whether to use double precision parallel_models: bool whether to run models in parallel + save_checkpoints: bool + whether to save checkpoints (i.e. model parameters) during the fit. 
""" from n3fit.backends import set_initial_state @@ -197,6 +202,10 @@ def performfit( theoryid=theoryid, lux_params=fiatlux, replicas=replica_idxs, + save_checkpoints=save_checkpoints, + replica_path=replica_path, + checkpoint_freq=checkpoint_freq, + dont_stop=dont_stop, ) # This is just to give a descriptive name to the fit function diff --git a/n3fit/src/n3fit/stopping.py b/n3fit/src/n3fit/stopping.py index 99be8f45e7..577f353db6 100644 --- a/n3fit/src/n3fit/stopping.py +++ b/n3fit/src/n3fit/stopping.py @@ -345,6 +345,7 @@ def __init__( self._dont_stop = dont_stop self._stop_now = False + self._would_stop_epoch = None self.stopping_patience = stopping_patience self.total_epochs = total_epochs @@ -481,7 +482,20 @@ def make_stop(self): and reload the history to the point of the best model if any """ self._stop_now = True - self._restore_best_weights() + if self._would_stop_epoch is None: + # final_epoch is the last registered epoch (0-indexed); +1 to match stop_epoch convention + self._would_stop_epoch = ( + -1 if self._history.final_epoch is None else self._history.final_epoch + 1 + ) + if not self._dont_stop: + self._restore_best_weights() + + @property + def would_stop_epoch(self): + """Epoch at which early stopping would have triggered. + Returns None if stopping never triggered (fit converged within total_epochs). 
+ When dont_stop=False this equals stop_epoch.""" + return self._would_stop_epoch def _restore_best_weights(self): for i_replica, weights in enumerate(self._best_weights): diff --git a/validphys2/src/validphys/core.py b/validphys2/src/validphys/core.py index 4021a36559..5c4ae2fc23 100644 --- a/validphys2/src/validphys/core.py +++ b/validphys2/src/validphys/core.py @@ -886,6 +886,9 @@ def errorbar68(self): up = np.nanpercentile(self.error_members(), 84.13, axis=0) return down, up + def median(self): + return np.median(self.error_members(), axis=0) + def sample_values(self, size): return np.random.choice(self, size=size) diff --git a/validphys2/src/validphys/pdfgrids.py b/validphys2/src/validphys/pdfgrids.py index 039de9cc9a..247fdd8c8f 100644 --- a/validphys2/src/validphys/pdfgrids.py +++ b/validphys2/src/validphys/pdfgrids.py @@ -65,6 +65,10 @@ def __post_init__(self): if not isinstance(self.grid_values, Stats): raise ValueError("`XPlottingGrid` grid_values can only be instances of `Stats`") + # Ensure that flavours is a list or tuple and not numpy array + if isinstance(self.flavours, np.ndarray): + self.flavours = self.flavours.tolist() + def select_flavour(self, flindex): """Return a new grid for one single flavour""" if isinstance(flindex, str):