diff --git a/n3fit/src/n3fit/io/writer.py b/n3fit/src/n3fit/io/writer.py index 6b4879f5a8..4fa71eef58 100644 --- a/n3fit/src/n3fit/io/writer.py +++ b/n3fit/src/n3fit/io/writer.py @@ -4,13 +4,17 @@ The goal is to generate the same folder/file structure as the old nnfit code so previously active scripts can still work. """ -import os import json +import logging + import numpy as np -from reportengine.compat import yaml -import validphys + import n3fit from n3fit import vpinterface +from reportengine.compat import yaml +import validphys + +log = logging.getLogger(__name__) XGRID = np.array( [ @@ -213,98 +217,154 @@ ] ) + class WriterWrapper: - def __init__(self, replica_number, pdf_object, stopping_object, q2, timings): + def __init__(self, replica_numbers, pdf_objects, stopping_object, all_chi2s, q2, timings): """ - Initializes the writer for one given replica. This is decoupled from the writing - of the fit in order to fix some of the variables which would be, in principle, - be shared by several different history objects. + Initializes the writer for all replicas. + + This is decoupled from the writing of the fit in order to fix some of the variables + which would be, in principle, be shared by several different history objects. Parameters ---------- - `replica_number` - index of the replica - `pdf_object` + `replica_numbers` + indices of the replicas + `pdf_objects` function to evaluate with a grid in x to generate a pdf `stopping_object` a stopping.Stopping object + `all_chi2s` + list of all the chi2s, in the order: tr_chi2, vl_chi2, true_chi2 `q2` q^2 of the fit `timings` dictionary of the timing of the different events that happened """ - self.replica_number = replica_number - self.pdf_object = pdf_object + self.replica_numbers = replica_numbers + self.pdf_objects = pdf_objects self.stopping_object = stopping_object self.q2 = q2 self.timings = timings + self.tr_chi2, self.vl_chi2, self.true_chi2 = all_chi2s - def write_data(self, replica_path_set, fitname, tr_chi2, vl_chi2, true_chi2): + def write_data(self, save_path, fitname, weights_name): """ - Wrapper around the `storefit` function. + Save all the data of a fit, for all replicas. Parameters ---------- - `replica_path_set` - full path for the replica, ex: `${PWD}/runcard_name/nnfit/replica_1` + `save_path` + path for the fit results, ex: `${PWD}/runcard_name/nnfit` `fitname` - name of the fit - `tr_chi2` - training chi2 - `vl_chi2` - validation chi2 - `true_chi2` - chi2 of the replica to the central experimental data + name of the fit, ex: `Basic_runcard` + `weights_name` + name of the file to save weights to, if not empty """ - # Check the directory exist, if it doesn't, generate it - os.makedirs(replica_path_set, exist_ok=True) + save_path.mkdir(exist_ok=True, parents=True) + + self.preprocessing = [] + self.arc_lengths = [] + self.integrability_numbers = [] + for pdf_object in self.pdf_objects: + self.preprocessing.append(pdf_object.get_preprocessing_factors()) + self.arc_lengths.append(vpinterface.compute_arclength(pdf_object).tolist()) + self.integrability_numbers.append( + vpinterface.integrability_numbers(pdf_object).tolist() + ) + + for i, rn in enumerate(self.replica_numbers): + replica_path = save_path / f"replica_{rn}" + replica_path.mkdir(exist_ok=True, parents=True) + + self._write_chi2s(replica_path / "chi2exps.log") + self._write_metadata_json(i, replica_path / f"{fitname}.json") + self._export_pdf_grid(i, replica_path / f"{fitname}.exportgrid") + if weights_name: + self._write_weights(i, replica_path / f"{weights_name}") + + def _write_chi2s(self, out_path): + # Note: same for all replicas, unless run separately + chi2_log = self.stopping_object.chi2exps_json() + with open(out_path, "w", encoding="utf-8") as fs: + json.dump(chi2_log, fs, indent=2, cls=SuperEncoder) + + def _write_metadata_json(self, i, out_path): + json_dict = jsonfit( + best_epoch=self.stopping_object.e_best_chi2[i], + positivity_status=self.stopping_object.positivity_statuses[i], + preprocessing=self.preprocessing[i], + arc_lengths=self.arc_lengths[i], + integrability_numbers=self.integrability_numbers[i], + tr_chi2=self.tr_chi2[i], + vl_chi2=self.vl_chi2[i], + true_chi2=self.true_chi2[i], + # Note: the 2 arguments below are the same for all replicas, unless run separately + timing=self.timings, + stop_epoch=self.stopping_object.stop_epoch, + ) - stop_epoch = self.stopping_object.stop_epoch + with open(out_path, "w", encoding="utf-8") as fs: + json.dump(json_dict, fs, indent=2, cls=SuperEncoder) - # Get the replica status for this object - replica_status = self.stopping_object.get_next_replica() + log.info( + "Best fit for replica #%d, chi2=%.3f (tr=%.3f, vl=%.3f)", + self.replica_numbers[i], + self.true_chi2[i], + self.tr_chi2[i], + self.vl_chi2[i], + ) - # export PDF grid to file + def _export_pdf_grid(self, i, out_path): storefit( - self.pdf_object, - self.replica_number, - replica_path_set, - fitname, + self.pdf_objects[i], + self.replica_numbers[i], + out_path, self.q2, ) - # write the log file for the chi2 - chi2_log = self.stopping_object.chi2exps_json() - with (replica_path_set / "chi2exps.log").open("w", encoding="utf-8") as fs: - json.dump(chi2_log, fs, indent=2, cls = SuperEncoder) - - # export all metadata from the fit to a single yaml file - output_file = f"{replica_path_set}/{fitname}.json" - json_dict = jsonfit( - replica_status, self.pdf_object, tr_chi2, vl_chi2, true_chi2, stop_epoch, self.timings - ) - with open(output_file, "w", encoding="utf-8") as fs: - json.dump(json_dict, fs, indent=2, cls = SuperEncoder) + def _write_weights(self, i, out_path): + log.info(" > Saving the weights for future in %s", out_path) + # Extract model out of N3PDF + model = self.pdf_objects[i]._models[0] + model.save_weights(out_path, save_format="h5") class SuperEncoder(json.JSONEncoder): - """ Custom json encoder to get around the fact that np.float32 =/= float """ + """Custom json encoder to get around the fact that np.float32 =/= float""" + def default(self, o): if isinstance(o, np.float32): return float(o) return super().default(o) -def jsonfit(replica_status, pdf_object, tr_chi2, vl_chi2, true_chi2, stop_epoch, timing): +def jsonfit( + best_epoch, + positivity_status, + preprocessing, + arc_lengths, + integrability_numbers, + tr_chi2, + vl_chi2, + true_chi2, + stop_epoch, + timing, +): """Generates a dictionary containing all relevant metadata for the fit Parameters ---------- - replica_status: n3fit.stopping.ReplicaBest - a stopping.Validation object - pdf_object: n3fit.vpinterface.N3PDF - N3PDF object constructed from the pdf_model - that receives as input a point in x and returns an array of 14 flavours + best_epoch: int + epoch at which the best fit was found + positivity_status: str + string describing the positivity status of the fit + preprocessing: dict + dictionary of the preprocessing factors + arc_lengths: list + list of the arc lengths of the different PDFs + integrability_numbers: list + list of the integrability numbers of the different PDFs tr_chi2: float chi2 for the training vl_chi2: float @@ -318,16 +378,16 @@ def jsonfit(replica_status, pdf_object, tr_chi2, vl_chi2, true_chi2, stop_epoch, """ all_info = {} # Generate preprocessing information - all_info["preprocessing"] = pdf_object.get_preprocessing_factors() + all_info["preprocessing"] = preprocessing # .fitinfo-like info all_info["stop_epoch"] = stop_epoch - all_info["best_epoch"] = replica_status.best_epoch + all_info["best_epoch"] = best_epoch all_info["erf_tr"] = tr_chi2 all_info["erf_vl"] = vl_chi2 all_info["chi2"] = true_chi2 - all_info["pos_state"] = replica_status.positivity_status - all_info["arc_lengths"] = vpinterface.compute_arclength(pdf_object).tolist() - all_info["integrability"] = vpinterface.integrability_numbers(pdf_object).tolist() + all_info["pos_state"] = positivity_status + all_info["arc_lengths"] = arc_lengths + all_info["integrability"] = integrability_numbers all_info["timing"] = timing # Versioning info all_info["version"] = version() @@ -335,7 +395,7 @@ def jsonfit(replica_status, pdf_object, tr_chi2, vl_chi2, true_chi2, stop_epoch, def version(): - """ Generates a dictionary with misc version info for this run """ + """Generates a dictionary with misc version info for this run""" versions = {} try: # Wrap tf in try-except block as it could possible to run n3fit without tf @@ -373,61 +433,128 @@ def evln2lha(evln): lha[6] = evln[2] - lha[8] = ( 10*evln[1] - + 30*evln[9] + 10*evln[10] + 5*evln[11] + 3*evln[12] + 2*evln[13] - + 10*evln[3] + 30*evln[4] + 10*evln[5] + 5*evln[6] + 3*evln[7] + 2*evln[8] ) / 120 - - lha[4] = ( 10*evln[1] - + 30*evln[9] + 10*evln[10] + 5*evln[11] + 3*evln[12] + 2*evln[13] - - 10*evln[3] - 30*evln[4] - 10*evln[5] - 5*evln[6] - 3*evln[7] - 2*evln[8] ) / 120 - - lha[7] = ( 10*evln[1] - - 30*evln[9] + 10*evln[10] + 5*evln[11] + 3*evln[12] + 2*evln[13] - + 10*evln[3] - 30*evln[4] + 10*evln[5] + 5*evln[6] + 3*evln[7] + 2*evln[8] ) / 120 - - lha[5] = ( 10*evln[1] - - 30*evln[9] + 10*evln[10] + 5*evln[11] + 3*evln[12] + 2*evln[13] - - 10*evln[3] + 30*evln[4] - 10*evln[5] - 5*evln[6] - 3*evln[7] - 2*evln[8] ) / 120 - - lha[9] = ( 10*evln[1] - - 20*evln[10] + 5*evln[11] + 3*evln[12] + 2*evln[13] - + 10*evln[3] - 20*evln[5] + 5*evln[6] + 3*evln[7] + 2*evln[8] ) / 120 - - lha[3] = ( 10*evln[1] - - 20*evln[10] + 5*evln[11] + 3*evln[12] + 2*evln[13] - - 10*evln[3] + 20*evln[5] - 5*evln[6] - 3*evln[7] - 2*evln[8] ) / 120 - - lha[10] = ( 10*evln[1] - - 15*evln[11] + 3*evln[12] + 2*evln[13] - + 10*evln[3] - 15*evln[6] + 3*evln[7] + 2*evln[8] ) / 120 - - lha[2] = ( 10*evln[1] - - 15*evln[11] + 3*evln[12] + 2*evln[13] - - 10*evln[3] + 15*evln[6] - 3*evln[7] - 2*evln[8] ) / 120 - - lha[11] = ( 5*evln[1] - - 6*evln[12] + evln[13] - + 5*evln[3] - 6*evln[7] + evln[8] ) / 60 - - lha[1] = ( 5*evln[1] - - 6*evln[12] + evln[13] - - 5*evln[3] + 6*evln[7] - evln[8] ) / 60 - - lha[12] = ( evln[1] - - evln[13] - + evln[3] - evln[8] ) / 12 - - lha[0] = ( evln[1] - - evln[13] - - evln[3] + evln[8] ) / 12 + lha[8] = ( + 10 * evln[1] + + 30 * evln[9] + + 10 * evln[10] + + 5 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + + 10 * evln[3] + + 30 * evln[4] + + 10 * evln[5] + + 5 * evln[6] + + 3 * evln[7] + + 2 * evln[8] + ) / 120 + + lha[4] = ( + 10 * evln[1] + + 30 * evln[9] + + 10 * evln[10] + + 5 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + - 10 * evln[3] + - 30 * evln[4] + - 10 * evln[5] + - 5 * evln[6] + - 3 * evln[7] + - 2 * evln[8] + ) / 120 + + lha[7] = ( + 10 * evln[1] + - 30 * evln[9] + + 10 * evln[10] + + 5 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + + 10 * evln[3] + - 30 * evln[4] + + 10 * evln[5] + + 5 * evln[6] + + 3 * evln[7] + + 2 * evln[8] + ) / 120 + + lha[5] = ( + 10 * evln[1] + - 30 * evln[9] + + 10 * evln[10] + + 5 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + - 10 * evln[3] + + 30 * evln[4] + - 10 * evln[5] + - 5 * evln[6] + - 3 * evln[7] + - 2 * evln[8] + ) / 120 + + lha[9] = ( + 10 * evln[1] + - 20 * evln[10] + + 5 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + + 10 * evln[3] + - 20 * evln[5] + + 5 * evln[6] + + 3 * evln[7] + + 2 * evln[8] + ) / 120 + + lha[3] = ( + 10 * evln[1] + - 20 * evln[10] + + 5 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + - 10 * evln[3] + + 20 * evln[5] + - 5 * evln[6] + - 3 * evln[7] + - 2 * evln[8] + ) / 120 + + lha[10] = ( + 10 * evln[1] + - 15 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + + 10 * evln[3] + - 15 * evln[6] + + 3 * evln[7] + + 2 * evln[8] + ) / 120 + + lha[2] = ( + 10 * evln[1] + - 15 * evln[11] + + 3 * evln[12] + + 2 * evln[13] + - 10 * evln[3] + + 15 * evln[6] + - 3 * evln[7] + - 2 * evln[8] + ) / 120 + + lha[11] = (5 * evln[1] - 6 * evln[12] + evln[13] + 5 * evln[3] - 6 * evln[7] + evln[8]) / 60 + + lha[1] = (5 * evln[1] - 6 * evln[12] + evln[13] - 5 * evln[3] + 6 * evln[7] - evln[8]) / 60 + + lha[12] = (evln[1] - evln[13] + evln[3] - evln[8]) / 12 + + lha[0] = (evln[1] - evln[13] - evln[3] + evln[8]) / 12 return lha def storefit( pdf_object, replica, - replica_path, - fitname, + out_path, q20, ): """ @@ -441,16 +568,14 @@ def storefit( that receives as input a point in x and returns an array of 14 flavours `replica` the replica index - `replica_path` - path for this replica - `fitname` - name of the fit + `out_path` + the path where to store the output `q20` q_0^2 """ # build exportgrid xgrid = XGRID.reshape(-1, 1) - + result = pdf_object(xgrid, flavours="n3fit").squeeze() lha = evln2lha(result.T).T @@ -458,9 +583,24 @@ def storefit( "replica": replica, "q20": q20, "xgrid": xgrid.T.tolist()[0], - "labels": ["TBAR", "BBAR", "CBAR", "SBAR", "UBAR", "DBAR", "GLUON", "D", "U", "S", "C", "B", "T", "PHT"], + "labels": [ + "TBAR", + "BBAR", + "CBAR", + "SBAR", + "UBAR", + "DBAR", + "GLUON", + "D", + "U", + "S", + "C", + "B", + "T", + "PHT", + ], "pdfgrid": lha.tolist(), } - with open(f"{replica_path}/{fitname}.exportgrid", "w") as fs: + with open(out_path, "w") as fs: yaml.dump(data, fs) diff --git a/n3fit/src/n3fit/performfit.py b/n3fit/src/n3fit/performfit.py index 431bb1f5a6..85c393c7df 100644 --- a/n3fit/src/n3fit/performfit.py +++ b/n3fit/src/n3fit/performfit.py @@ -264,49 +264,20 @@ def performfit( log.info("Stopped at epoch=%d", stopping_object.stop_epoch) final_time = stopwatch.stop() - all_training_chi2, all_val_chi2, all_exp_chi2 = the_model_trainer.evaluate(stopping_object) + all_chi2s = the_model_trainer.evaluate(stopping_object) pdf_models = result["pdf_models"] - for i, (replica_number, pdf_model) in enumerate(zip(replica_idxs, pdf_models)): - # Each model goes into its own replica folder - replica_path_set = replica_path / f"replica_{replica_number}" - - # Create a pdf instance - q0 = theoryid.get_description().get("Q0") - pdf_instance = N3PDF(pdf_model, fit_basis=basis, Q=q0) - - # Generate the writer wrapper - writer_wrapper = WriterWrapper( - replica_number, - pdf_instance, - stopping_object, - q0**2, - final_time, - ) - - # Get the right chi2s - training_chi2 = np.take(all_training_chi2, i) - val_chi2 = np.take(all_val_chi2, i) - exp_chi2 = np.take(all_exp_chi2, i) - - # And write the data down - writer_wrapper.write_data( - replica_path_set, output_path.name, training_chi2, val_chi2, exp_chi2 - ) - log.info( - "Best fit for replica #%d, chi2=%.3f (tr=%.3f, vl=%.3f)", - replica_number, - exp_chi2, - training_chi2, - val_chi2, - ) - - # Save the weights to some file for the given replica - if save: - model_file_path = replica_path_set / save - log.info(" > Saving the weights for future in %s", model_file_path) - # Need to use "str" here because TF 2.2 has a bug for paths objects (fixed in 2.3) - pdf_model.save_weights(str(model_file_path), save_format="h5") + q0 = theoryid.get_description().get("Q0") + pdf_instances = [N3PDF(pdf_model, fit_basis=basis, Q=q0) for pdf_model in pdf_models] + writer_wrapper = WriterWrapper( + replica_idxs, + pdf_instances, + stopping_object, + all_chi2s, + q0**2, + final_time, + ) + writer_wrapper.write_data(replica_path, output_path.name, save) if tensorboard is not None: log.info("Tensorboard logging information is stored at %s", log_path) diff --git a/n3fit/src/n3fit/stopping.py b/n3fit/src/n3fit/stopping.py index d08a0e04ef..0c9bc99e9a 100644 --- a/n3fit/src/n3fit/stopping.py +++ b/n3fit/src/n3fit/stopping.py @@ -28,6 +28,7 @@ be used instead. """ import logging + import numpy as np log = logging.getLogger(__name__) @@ -133,11 +134,7 @@ def parse_losses(history_object, data, suffix="loss"): class FitState: """ - Holds the state of the chi2 during the fit for all replicas - - It holds the necessary information to reload the fit - to a specific point in time if we are interested on reloading - (otherwise the relevant variables stay empty to save memory) + Holds the state of the chi2 during the fit, for all replicas and one epoch Note: the training chi2 is computed before the update of the weights so it is the chi2 that informed the updated corresponding to this state. @@ -160,11 +157,11 @@ def __init__(self, training_info, validation_info): raise ValueError( "FitState cannot be instantiated until vl_ndata, tr_ndata and vl_suffix are filled" ) - self.training = training_info + self._training = training_info self.validation = validation_info self._parsed = False - self._vl_chi2 = None - self._tr_chi2 = None + self._vl_chi2 = None # These are per replica + self._tr_chi2 = None # This is an overall training chi2 self._vl_dict = None self._tr_dict = None @@ -176,7 +173,7 @@ def vl_loss(self): @property def tr_loss(self): """Return the total validation loss as it comes from the info dictionaries""" - return self.training.get("loss") + return self._training.get("loss") def _parse_chi2(self): """ @@ -185,8 +182,8 @@ def _parse_chi2(self): """ if self._parsed: return - if self.training is not None: - self._tr_chi2, self._tr_dict = parse_losses(self.training, self.tr_ndata) + if self._training is not None: + self._tr_chi2, self._tr_dict = parse_losses(self._training, self.tr_ndata) if self.validation is not None: self._vl_chi2, self._vl_dict = parse_losses( self.validation, self.vl_ndata, suffix=self.vl_suffix @@ -212,115 +209,47 @@ def all_vl_chi2(self): self._parse_chi2() return self._vl_dict - def all_tr_chi2_for_replica(self, r): - """" Return the tr chi2 per dataset for a given replica """ - return {k: np.take(i, r) for k, i in self.all_tr_chi2.items()} + def all_tr_chi2_for_replica(self, i_replica): + """Return the tr chi2 per dataset for a given replica""" + return {k: np.take(v, i_replica) for k, v in self.all_tr_chi2.items()} - def all_vl_chi2_for_replica(self, r): - """" Return the vl chi2 per dataset for a given replica """ - return {k: np.take(i, r) for k, i in self.all_vl_chi2.items()} + def all_vl_chi2_for_replica(self, i_replica): + """Return the vl chi2 per dataset for a given replica""" + return {k: np.take(v, i_replica) for k, v in self.all_vl_chi2.items()} def total_partial_tr_chi2(self): - """ Return the tr chi2 summed over replicas per experiment""" - return {k: np.sum(i) for k, i in self.all_tr_chi2.items()} + """Return the tr chi2 summed over replicas per experiment""" + return {k: np.sum(v) for k, v in self.all_tr_chi2.items()} def total_partial_vl_chi2(self): - """ Return the vl chi2 summed over replicas per experiment""" - return {k: np.sum(i) for k, i in self.all_tr_chi2.items()} + """Return the vl chi2 summed over replicas per experiment""" + return {k: np.sum(v) for k, v in self.all_tr_chi2.items()} def total_tr_chi2(self): - """ Return the total tr chi2 summed over replicas """ + """Return the total tr chi2 summed over replicas""" return np.sum(self.tr_chi2) def total_vl_chi2(self): - """ Return the total vl chi2 summed over replicas """ + """Return the total vl chi2 summed over replicas""" return np.sum(self.vl_chi2) def __str__(self): return f"chi2: tr={self.tr_chi2} vl={self.vl_chi2}" -class ReplicaState: - """Extra complication which eventually will be merged with someone else - but it is here only for development.""" - - def __init__(self, pdf_model): - self._pdf_model = pdf_model - self._weights = None - self._best_epoch = None - self._stop_epoch = None - self._best_vl_chi2 = INITIAL_CHI2 - - def positivity_pass(self): - """ By definition, if we have a ``best_epoch`` then positivity passed """ - if self._best_epoch is None: - return False - else: - return True - - @property - def best_epoch(self): - if self._best_epoch is None: - return self.stop_epoch - return self._best_epoch - - @property - def stop_epoch(self): - return self._stop_epoch - - @property - def best_vl(self): - return float(self._best_vl_chi2) - - @property - def positivity_status(self): - if self.positivity_pass(): - return POS_OK - else: - return POS_BAD - - def register_best(self, chi2, epoch): - """ Register a new best state and some metadata about it """ - self._weights = self._pdf_model.get_weights() - self._best_epoch = epoch - self._best_vl_chi2 = chi2 - - def reload(self): - """ Reload the weights of the best state """ - if self._weights: - self._pdf_model.set_weights(self._weights) - - def stop_training(self, epoch = None): - """ Stop training this replica if not stopped before """ - if self._pdf_model.trainable: - self._pdf_model.trainable = False - self._stop_epoch = epoch - - class FitHistory: """ - Keeps a list of FitState items holding the full history of the fit. - - It also keeps track of the best epoch and the associated weights. - - Can be iterated when there are snapshots of the fit being saved. - When iterated it will rewind the fit to each of the point in history - that have been saved. + Keeps a list of FitState items holding the full chi2 history of the fit. Parameters ---------- - pdf_models: n3fit.backends.MetaModel - list of PDF models being trained, used to saved the weights + tr_ndata: dict + dictionary of {dataset: n_points} for the training data + vl_ndata: dict + dictionary of {dataset: n_points} for the validation data """ - def __init__(self, pdf_models, tr_ndata, vl_ndata): - # Create a ReplicaState object for all models - # which will hold the best chi2 and weights per replica - self._replicas = [] - for pdf_model in pdf_models: - self._replicas.append(ReplicaState(pdf_model)) - self._iter_replicas = iter(self._replicas) - + def __init__(self, tr_ndata, vl_ndata): if vl_ndata is None: vl_ndata = tr_ndata vl_suffix = "loss" @@ -335,13 +264,8 @@ def __init__(self, pdf_models, tr_ndata, vl_ndata): self._history = [] self.final_epoch = None - @property - def best_epoch(self): - """ Return the best epoch per replica """ - return [i.best_epoch for i in self._replicas] - def get_state(self, epoch): - """ Get the FitState of the system for a given epoch """ + """Get the FitState of the system for a given epoch""" try: return self._history[epoch] except IndexError as e: @@ -349,34 +273,21 @@ def get_state(self, epoch): f"Tried to get obtain the state for epoch {epoch} when only {len(self._history)} epochs have been saved" ) from e - def save_best_replica(self, i, epoch=None): - """Save the state of replica ``i`` as a best fit so far. - If an epoch is given, save the best as the given epoch, otherwise - use the last one - """ - if epoch is None: - epoch = self.final_epoch - loss = self.get_state(epoch).vl_loss[i] - self._replicas[i].register_best(loss, epoch) - - def all_positivity_status(self): - """ Returns whether the positivity passed or not per replica """ - return np.array([i.positivity_status for i in self._replicas]) - - def all_best_vl_loss(self): - """ Returns the best validation loss for each replica """ - return np.array([i.best_vl for i in self._replicas]) - def register(self, epoch, training_info, validation_info): """Save a new fitstate and updates the current final epoch Parameters ---------- - fitstate: FitState - FitState object - the fitstate of the object to save epoch: int the current epoch of the fit + training_info: dict + all losses for the training model + validation_info: dict + all losses for the validation model + + Returns + ------- + FitState """ # Save all the information in a fitstate object fitstate = FitState(training_info, validation_info) @@ -384,21 +295,6 @@ def register(self, epoch, training_info, validation_info): self._history.append(fitstate) return fitstate - def stop_training_replica(self, i, e): - """ Stop training replica i in epoch e""" - self._replicas[i].stop_training(e) - - def reload(self): - """Reloads the best fit weights into the model if there are models to be reloaded - Ensure that all replicas have stopped at this point. - """ - for replica in self._replicas: - replica.stop_training(self.final_epoch) - replica.reload() - - def __next__(self): - return next(self._iter_replicas) - class Stopping: """ @@ -412,7 +308,7 @@ class Stopping: validation_model: n3fit.backends.MetaModel the model with the validation mask applied (and compiled with the validation data and covmat) - all_data_dict: dict + all_data_dicts: dict list containg all dictionaries containing all information about the experiments/validation/regularizers/etc to be parsed by Stopping pdf_models: list(n3fit.backends.MetaModel) @@ -423,6 +319,8 @@ class Stopping: total number of epochs stopping_patience: int how many epochs to wait for the validation loss to improve + threshold_chi2: float + maximum value allowed for chi2 dont_stop: bool dont care about early stopping """ @@ -438,49 +336,61 @@ def __init__( threshold_chi2=10.0, dont_stop=False, ): + self._pdf_models = pdf_models + # Save the validation object self._validation = validation_model # Create the History object tr_ndata, vl_ndata, pos_sets = parse_ndata(all_data_dicts) - self._history = FitHistory(pdf_models, tr_ndata, vl_ndata) + self._history = FitHistory(tr_ndata, vl_ndata) # And the positivity checker self._positivity = Positivity(threshold_positivity, pos_sets) # Initialize internal variables for the stopping - self.n_replicas = len(pdf_models) - self.threshold_chi2 = threshold_chi2 - self.stopping_degree = np.zeros(self.n_replicas, dtype=int) - self.count = np.zeros(self.n_replicas, dtype=int) - - self.dont_stop = dont_stop - self.stop_now = False - self.stopping_patience = stopping_patience - self.total_epochs = total_epochs + self._n_replicas = len(pdf_models) + self._threshold_chi2 = threshold_chi2 + self._stopping_degrees = np.zeros(self._n_replicas, dtype=int) + self._counts = np.zeros(self._n_replicas, dtype=int) + + self._dont_stop = dont_stop + self._stop_now = False + self._stopping_patience = stopping_patience + self._total_epochs = total_epochs + + self._stop_epochs = [total_epochs - 1] * self._n_replicas + self._best_epochs = [None] * self._n_replicas + self.positivity_statuses = [POS_BAD] * self._n_replicas + self._best_weights = [None] * self._n_replicas + self._best_val_chi2s = [INITIAL_CHI2] * self._n_replicas @property def vl_chi2(self): - """ Current validation chi2 """ + """Current validation chi2""" validation_info = self._validation.compute_losses() fitstate = FitState(None, validation_info) return fitstate.vl_chi2 @property def e_best_chi2(self): - """ Epoch of the best chi2, if there is no best epoch, return last""" - return self._history.best_epoch + """Epoch of the best chi2, if there is no best epoch, return last""" + best_or_last_epochs = [ + best if best is not None else last + for best, last in zip(self._best_epochs, self._stop_epochs) + ] + return best_or_last_epochs @property def stop_epoch(self): - """ Epoch in which the fit is stopped """ + """Epoch in which the fit is stopped""" return self._history.final_epoch + 1 @property def positivity_status(self): """Returns POS_PASS if positivity passes or veto if it doesn't for each replica""" - return self._history.all_positivity_status() + return self.positivity_statuses def evaluate_training(self, training_model): """Given the training model, evaluates the @@ -543,26 +453,32 @@ def monitor_chi2(self, training_info, epoch, print_stats=False): # this means improving vl_chi2 and passing positivity # Don't start counting until the chi2 of the validation goes below a certain threshold # once we start counting, don't bother anymore - passes = self.count | (fitstate.vl_chi2 < self.threshold_chi2) - passes &= fitstate.vl_loss < self._history.all_best_vl_loss() + passes = self._counts | (fitstate.vl_chi2 < self._threshold_chi2) + passes &= fitstate.vl_loss < self._best_val_chi2s # And the ones that pass positivity passes &= self._positivity(fitstate) - self.stopping_degree += self.count + self._stopping_degrees += self._counts # Step 5. loop over the valid indices to check whether the vl improved - for i in np.where(passes)[0]: - self._history.save_best_replica(i) - self.stopping_degree[i] = 0 - self.count[i] = 1 + for i_replica in np.where(passes)[0]: + self._best_epochs[i_replica] = epoch + # By definition, if we have a ``best_epoch`` then positivity passed + self.positivity_statuses[i_replica] = POS_OK + + self._best_val_chi2s[i_replica] = self._history.get_state(epoch).vl_loss[i_replica] + self._best_weights[i_replica] = self._pdf_models[i_replica].get_weights() + + self._stopping_degrees[i_replica] = 0 + self._counts[i_replica] = 1 - stop_replicas = self.count & (self.stopping_degree > self.stopping_patience) - for i in np.where(stop_replicas)[0]: - self.count[i] = 0 - self._history.stop_training_replica(i, epoch) + stop_replicas = self._counts & (self._stopping_degrees > self._stopping_patience) + for i_replica in np.where(stop_replicas)[0]: + self._stop_epochs[i_replica] = epoch + self._counts[i_replica] = 0 # By using the stopping degree we only stop when none of the replicas are improving anymore - if min(self.stopping_degree) > self.stopping_patience: + if min(self._stopping_degrees) > self._stopping_patience: self.make_stop() return True @@ -570,8 +486,13 @@ def make_stop(self): """Convenience method to set the stop_now flag and reload the history to the point of the best model if any """ - self.stop_now = True - self._history.reload() + self._stop_now = True + self._restore_best_weights() + + def _restore_best_weights(self): + for replica, weights in zip(self._pdf_models, self._best_weights): + if weights is not None: + replica.set_weights(weights) def print_current_stats(self, epoch, fitstate): """ @@ -580,10 +501,10 @@ def print_current_stats(self, epoch, fitstate): epoch_index = epoch + 1 tr_chi2 = fitstate.total_tr_chi2() vl_chi2 = fitstate.total_vl_chi2() - total_str = f"At epoch {epoch_index}/{self.total_epochs}, total chi2: {tr_chi2}\n" + total_str = f"At epoch {epoch_index}/{self._total_epochs}, total chi2: {tr_chi2}\n" # The partial chi2 makes no sense for more than one replica at once: - if self.n_replicas == 1: + if self._n_replicas == 1: partial_tr_chi2 = fitstate.total_partial_tr_chi2() partials = [] for experiment, chi2 in partial_tr_chi2.items(): @@ -596,22 +517,18 @@ def stop_here(self): """Returns the stopping status If `dont_stop` is set returns always False (i.e., never stop) """ - if self.dont_stop: + if self._dont_stop: return False else: - return self.stop_now - - def get_next_replica(self): - """ Return the next ReplicaState object""" - return next(self._history) + return self._stop_now - def chi2exps_json(self, replica=0, log_each=100): + def chi2exps_json(self, i_replica=0, log_each=100): """ Returns and apt-for-json dictionary with the status of the fit every `log_each` epochs Parameters ---------- - replica: int + i_replica: int which replica are we writing the log for log_each: int every how many epochs to print the log @@ -624,10 +541,10 @@ def chi2exps_json(self, replica=0, log_each=100): final_epoch = self._history.final_epoch json_dict = {} - for i in range(log_each - 1, final_epoch + 1, log_each): - fitstate = self._history.get_state(i) - all_tr = fitstate.all_tr_chi2_for_replica(replica) - all_vl = fitstate.all_vl_chi2_for_replica(replica) + for epoch in range(log_each - 1, final_epoch + 1, log_each): + fitstate = self._history.get_state(epoch) + all_tr = fitstate.all_tr_chi2_for_replica(i_replica) + all_vl = fitstate.all_vl_chi2_for_replica(i_replica) tmp = {exp: {"training": tr_chi2} for exp, tr_chi2 in all_tr.items()} for exp, vl_chi2 in all_vl.items(): @@ -635,7 +552,7 @@ def chi2exps_json(self, replica=0, log_each=100): tmp[exp] = {"training": None} tmp[exp]["validation"] = vl_chi2 - json_dict[i + 1] = tmp + json_dict[epoch + 1] = tmp return json_dict