Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions n3fit/src/n3fit/backends/keras_backend/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""

import logging
from pathlib import Path
from time import time

from keras import backend as K
Expand Down Expand Up @@ -196,6 +197,64 @@ def on_step_end(self, epoch, logs=None):
self._update_weights()


class StoreCallback(CallbackStep):
    """
    Periodically stores the trainable parameters of each replica during the
    fit, and stores the best parameters once training ends.

    Every ``check_freq`` epochs the multi-replica ``pdf_model`` is split into
    per-replica models and each replica's trainable weights are flattened and
    saved as an ``.npz`` file under ``<path>/parameters/`` for the
    corresponding entry of ``replica_paths``.

    Parameters
    ----------
    pdf_model: MetaModel
        The multi-replica PDF model
    replica_paths: list[Path]
        One path per replica. Parameters are saved under ``<path>/parameters/``.
    stopping_object: Stopping
        Stopping object holding the per-replica best weights and best epochs
        (read in :meth:`on_train_end` to store the best parameters).
    check_freq: int
        Save every this many epochs (default: 100)
    """

    def __init__(self, pdf_model, replica_paths, stopping_object, check_freq=100):
        super().__init__()
        self.check_freq = check_freq
        self.pdf_model = pdf_model
        self.weight_dirs = []
        self.stopping_object = stopping_object
        # Create one <path>/parameters directory per replica up front so that
        # saving during the fit never has to create directories.
        for path in replica_paths:
            weight_dir = path / "parameters"
            weight_dir.mkdir(parents=True, exist_ok=True)
            self.weight_dirs.append(weight_dir)

    def _save_weights(self, epoch, tr_weights, weight_dir):
        """Flatten ``tr_weights`` into a single 1D array and save it to
        ``<weight_dir>/params_<epoch>.npz`` under the key ``params``.

        NOTE(review): relies on module-level ``np`` and ``log`` being defined
        in this module (not visible in this excerpt) — confirm the imports.
        """
        filepath = weight_dir / f"params_{epoch}.npz"
        # save parameters as expected by colibri
        trainable_weights_flat = np.concatenate([np.asarray(w).flatten() for w in tr_weights])
        np.savez(filepath, params=trainable_weights_flat)
        log.info(f"Saved parameters at epoch {epoch} in {filepath}")

    def on_step_end(self, epoch, logs=None):
        """Function to be called at the end of every epoch
        Every ``check_freq`` number of epochs, the parameters of the model will
        be stored in the indicated directory.
        """
        # epoch is 0-indexed, so (epoch + 1) makes the save happen at
        # human-readable multiples of check_freq (100, 200, ...)
        if ((epoch + 1) % self.check_freq) == 0:
            pdf_replicas = self.pdf_model.split_replicas()
            for replica_model, weight_dir in zip(pdf_replicas, self.weight_dirs):
                weights = replica_model.trainable_weights
                self._save_weights(epoch + 1, weights, weight_dir)

    def on_train_end(self, logs=None):
        """Store the best parameters for every replica, as tracked by the
        stopping object; replicas for which no best weights were registered
        are skipped with a warning."""
        for idx, weight_dir in enumerate(self.weight_dirs):
            # NOTE(review): reaches into private attributes of the stopping
            # object (_best_weights, _best_epochs) — a public accessor would
            # be more robust to refactors.
            weights = self.stopping_object._best_weights[idx]
            if weights is not None:
                # Only the neural-network weights are stored, not e.g.
                # preprocessing exponents
                best_weights = weights['all_NNs']
                best_epoch = self.stopping_object._best_epochs[idx]
                self._save_weights(best_epoch, best_weights, weight_dir)
            else:
                log.warning(
                    f"No best weights found for replica {idx+1}, skipping saving best parameters."
                )


def gen_tensorboard_callback(log_dir, profiling=False, histogram_freq=0):
"""
Generate tensorboard logging details at ``log_dir``.
Expand Down
9 changes: 9 additions & 0 deletions n3fit/src/n3fit/io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ def _write_metadata_json(self, i, out_path):
# Note: the 2 arguments below are the same for all replicas, unless run separately
timing=self.timings,
stop_epoch=self.stopping_object.stop_epoch,
would_stop_epoch=(
self.stopping_object.would_stop_epoch
if self.stopping_object._dont_stop
else self.stopping_object.stop_epoch
),
)

with open(out_path, "w", encoding="utf-8") as fs:
Expand Down Expand Up @@ -347,6 +352,7 @@ def jsonfit(
true_chi2,
stop_epoch,
timing,
would_stop_epoch,
):
"""Generates a dictionary containing all relevant metadata for the fit

Expand All @@ -372,6 +378,8 @@ def jsonfit(
epoch at which the stopping stopped (not the one for the best fit!)
timing: dict
dictionary of the timing of the different events that happened
would_stop_epoch: int
epoch at which the stopping would have stopped if it were not set to "dont_stop"
"""
all_info = {}
# Generate preprocessing information
Expand All @@ -386,6 +394,7 @@ def jsonfit(
all_info["arc_lengths"] = arc_lengths
all_info["integrability"] = integrability_numbers
all_info["timing"] = timing
all_info["would_stop_epoch"] = would_stop_epoch
# Versioning info
all_info["version"] = version()
return all_info
Expand Down
62 changes: 59 additions & 3 deletions n3fit/src/n3fit/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import namedtuple
from itertools import zip_longest
import logging
import pickle

import numpy as np

Expand Down Expand Up @@ -112,6 +113,10 @@ def __init__(
theoryid=None,
lux_params=None,
replicas=None,
save_checkpoints=False,
replica_path=None,
checkpoint_freq=100,
dont_stop=False,
):
"""
Parameters
Expand Down Expand Up @@ -152,6 +157,15 @@ def __init__(
if not give, the photon is not generated
replicas: list
list with the replicas ids to be fitted
save_checkpoints: bool
whether to save checkpoints (i.e. model parameters) during the fit. This requires
`replica_path` to be set as well. Not doing this will raise an error.
replica_path: Path
root path for all replicas.
checkpoint_freq: int
frequency (in epochs) at which to save checkpoints. Only relevant if `save_checkpoints` is True.
dont_stop: bool
whether to disable the stopping mechanism, i.e. to run for all epochs regardless of the validation chi2
"""
# Save all input information
self.exp_info = list(exp_info)
Expand All @@ -168,6 +182,14 @@ def __init__(
self.lux_params = lux_params
self.replicas = replicas
self.experiments_data = experiments_data
self.dont_stop = dont_stop

# Checkpointing options
self.save_checkpoints = save_checkpoints
self.replica_path = replica_path
self.checkpoint_freq = checkpoint_freq
if self.save_checkpoints and self.replica_path is None:
raise ValueError("To save checkpoints, the 'replica_path' key must be set as well.")

# Initialise internal variables which define behaviour
if debug:
Expand Down Expand Up @@ -721,11 +743,24 @@ def _train_and_fit(self, training_model, stopping_object, epochs=100) -> bool:
self.training["integmultipliers"],
update_freq=PUSH_INTEGRABILITY_EACH,
)
callback_list = [callback_st, callback_pos, callback_integ]

if self.save_checkpoints:
pdf_model = training_model.get_layer("PDFs")
# Save parameters where colibri will look for checkpoints
replica_paths = [
self.replica_path.parent / f"fit_replicas/replica_{r}" for r in self.replicas
]
checpoint_callback = callbacks.StoreCallback(
pdf_model=pdf_model,
replica_paths=replica_paths,
check_freq=self.checkpoint_freq,
stopping_object=stopping_object,
)
callback_list.append(checpoint_callback)

training_model.perform_fit(
epochs=epochs,
verbose=False,
callbacks=self.callbacks + [callback_st, callback_pos, callback_integ],
epochs=epochs, verbose=False, callbacks=self.callbacks + callback_list
)

def _hyperopt_override(self, params):
Expand Down Expand Up @@ -921,6 +956,26 @@ def hyperparametrizable(self, params):
)
replicas_settings.append(tmp)

# TODO: tempoerary fix to use NTK utilities in colibri
# Create model pkl for colibri n3fit module
_init_args = {
"flav_info": self.flavinfo,
"replica_range_settings": {
"min_replica": np.sort(self.replicas)[0],
"max_replica": np.sort(self.replicas)[0],
},
"impose_sumrule": self.impose_sumrule,
"fitbasis": self.fitbasis,
"nodes": params["nodes_per_layer"],
"activations": params["activation_per_layer"],
"initializer_name": params["initializer"],
"layer_type": params["layer_type"],
}
state = {"_init_args": _init_args}

with open(self.replica_path.parent / "pdf_model.pkl", "wb") as file:
pickle.dump(state, file)

### Training loop
for k, partition in enumerate(self.kpartitions):

Expand Down Expand Up @@ -987,6 +1042,7 @@ def hyperparametrizable(self, params):
stopping_patience=stopping_epochs,
threshold_positivity=threshold_pos,
threshold_chi2=threshold_chi2,
dont_stop=self.dont_stop,
)

# Compile each of the models with the right parameters
Expand Down
9 changes: 9 additions & 0 deletions n3fit/src/n3fit/performfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def performfit(
maxcores=None,
double_precision=False,
parallel_models=True,
save_checkpoints=False,
checkpoint_freq=100,
dont_stop=False,
):
"""
This action will (upon having read a validcard) process a full PDF fit
Expand Down Expand Up @@ -128,6 +131,8 @@ def performfit(
whether to use double precision
parallel_models: bool
whether to run models in parallel
save_checkpoints: bool
whether to save checkpoints (i.e. model parameters) during the fit.
"""
from n3fit.backends import set_initial_state

Expand Down Expand Up @@ -197,6 +202,10 @@ def performfit(
theoryid=theoryid,
lux_params=fiatlux,
replicas=replica_idxs,
save_checkpoints=save_checkpoints,
replica_path=replica_path,
checkpoint_freq=checkpoint_freq,
dont_stop=dont_stop,
)

# This is just to give a descriptive name to the fit function
Expand Down
16 changes: 15 additions & 1 deletion n3fit/src/n3fit/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def __init__(

self._dont_stop = dont_stop
self._stop_now = False
self._would_stop_epoch = None
self.stopping_patience = stopping_patience
self.total_epochs = total_epochs

Expand Down Expand Up @@ -481,7 +482,20 @@ def make_stop(self):
and reload the history to the point of the best model if any
"""
self._stop_now = True
self._restore_best_weights()
if self._would_stop_epoch is None:
# final_epoch is the last registered epoch (0-indexed); +1 to match stop_epoch convention
self._would_stop_epoch = (
-1 if self._history.final_epoch is None else self._history.final_epoch + 1
)
if not self._dont_stop:
self._restore_best_weights()

@property
def would_stop_epoch(self):
"""Epoch at which early stopping would have triggered.
Returns None if stopping never triggered (fit converged within total_epochs).
When dont_stop=False this equals stop_epoch."""
return self._would_stop_epoch

def _restore_best_weights(self):
for i_replica, weights in enumerate(self._best_weights):
Expand Down
3 changes: 3 additions & 0 deletions validphys2/src/validphys/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,9 @@ def errorbar68(self):
up = np.nanpercentile(self.error_members(), 84.13, axis=0)
return down, up

    def median(self):
        """Return the median of the error members, taken along the
        member axis (axis 0)."""
        return np.median(self.error_members(), axis=0)

    def sample_values(self, size):
        """Return ``size`` values sampled uniformly, with replacement,
        from this object.

        NOTE(review): ``np.random.choice`` requires its population argument
        to be 1-D array-like — confirm this is only called on 1-D stats.
        """
        return np.random.choice(self, size=size)

Expand Down
4 changes: 4 additions & 0 deletions validphys2/src/validphys/pdfgrids.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def __post_init__(self):
if not isinstance(self.grid_values, Stats):
raise ValueError("`XPlottingGrid` grid_values can only be instances of `Stats`")

# Ensure that flavours is a list or tuple and not numpy array
if isinstance(self.flavours, np.ndarray):
self.flavours = self.flavours.tolist()

def select_flavour(self, flindex):
"""Return a new grid for one single flavour"""
if isinstance(flindex, str):
Expand Down
Loading