Merged
Changes from all commits (55 commits)
433ab62
Join replicas into single MetaModel
APJansen Oct 16, 2023
f44a96b
Next step in joining replicas, in model_generation
APJansen Oct 16, 2023
97d3e5d
Next step in joining replicas, in stopping, added MetaModel methods t…
APJansen Oct 16, 2023
a145d11
Next step in joining replicas, split back into separate replicas afte…
APJansen Oct 16, 2023
0c0ee66
Rewrite hyperoptimization penalties in terms of single pdf_model
APJansen Oct 16, 2023
b871ec6
Make hyperopt losses dependent on multi-replica PDF, moving the conve…
APJansen Oct 18, 2023
8316562
Merge branch 'master' into multi-dense-logistics
APJansen Oct 18, 2023
3cbc9e2
Rewrite weight loading to multi-replica pdf
APJansen Oct 18, 2023
c80ec48
Remove loop over pdf_models when registering photon
APJansen Oct 18, 2023
ca65c95
Fix saturation penalty to give loss per replica, fix test
APJansen Oct 18, 2023
d20968d
Fix test_vpinterface by splitting pdf into single replicas in the test
APJansen Oct 18, 2023
dbd1fb6
Change individual replica PDF from Lambda layer to MetaModel (tempora…
APJansen Oct 18, 2023
c63f331
Fix getting of add photon layer
APJansen Oct 18, 2023
4b84d1c
range(replicas) -> replicas
APJansen Oct 18, 2023
a39d792
load weights using replica indices starting from 0 rather than the va…
APJansen Oct 19, 2023
ee43aab
Another fix...
APJansen Oct 19, 2023
ebdbac6
Move joining of replicas back further to directly after creation of N…
APJansen Oct 19, 2023
6254018
Add default 1 for replicas in msr
APJansen Oct 19, 2023
e69926d
Fix msr test, adding missing replica dimensions
APJansen Oct 20, 2023
08ea3e5
Rewrite photon layer to work on all replicas
APJansen Oct 20, 2023
f03fd68
concatenate -> stack
APJansen Oct 23, 2023
01d9f0f
Reshape photon integrals tensor
APJansen Oct 23, 2023
1e2b50d
Fix loading of weights from file
APJansen Oct 23, 2023
b919773
PDF_0 -> PDFs in registering photon
APJansen Oct 23, 2023
60b97f3
Temporary fix for loading of replica other than 0
APJansen Oct 23, 2023
d5d88a8
Stack photon replicas
APJansen Oct 23, 2023
590cebc
Adapt photon test to new structure
APJansen Oct 23, 2023
e03255a
Fix splitting to single replicas
APJansen Oct 24, 2023
0354097
Add wrapper function to attach single replica models
APJansen Oct 24, 2023
78c973f
Fix bug replacing replica numbers in weight names
APJansen Oct 24, 2023
060274b
Cleanup
APJansen Oct 24, 2023
105da0f
Fix tests
APJansen Oct 24, 2023
f4c9b3a
replicas=[1] instead of [0] for photon?
APJansen Oct 24, 2023
47c250e
Remove reference to pdf_models
APJansen Oct 25, 2023
d70698b
Fix bug with get_replica_weights passing a reference rather than the …
APJansen Oct 25, 2023
8fd6615
Fix issue where Keras recognised single replica models as layers of m…
APJansen Oct 25, 2023
aa2d228
Fix issue with weight names
APJansen Oct 25, 2023
a84559f
Fix issue with running multiple replicas in sequence.
APJansen Oct 26, 2023
a1c5a0e
Fix test, update argument parallel_models->num_replicas
APJansen Oct 26, 2023
27701f4
Fix bug with subtract_one
APJansen Oct 26, 2023
515f7a7
Add test of split_replicas
APJansen Nov 20, 2023
97eb8c9
remove unused variable
APJansen Nov 24, 2023
261cabb
remove unused argument
APJansen Nov 24, 2023
d34554c
Remove unused arguments
APJansen Nov 24, 2023
94222e6
Pass on out argument
APJansen Nov 24, 2023
12f9939
Use a single replica generator rather than a fixed list
APJansen Nov 28, 2023
a28da06
Define NN and preprocessing_factor as constants
APJansen Nov 28, 2023
2ea4878
Simplify input generation in msr test
APJansen Nov 28, 2023
552bc97
Remove unused photon attributes
APJansen Nov 28, 2023
f0359d7
Revert "Remove unused photon attributes"
APJansen Nov 28, 2023
9d7c493
Add exception in case photons and multiple replicas are combined
APJansen Nov 29, 2023
dcf9684
Merge branch 'master' into multi-dense-logistics
APJansen Nov 29, 2023
ab219a6
Rename NN, preprocessing prefixes
APJansen Nov 30, 2023
8cb24e9
Add error if no generator set
APJansen Nov 30, 2023
50d77f5
Add comments
APJansen Nov 30, 2023
156 changes: 143 additions & 13 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -6,11 +6,14 @@
"""

import re

import h5py
import numpy as np
import tensorflow as tf
-from tensorflow.keras.models import Model
from tensorflow.keras import optimizers as Kopt
+from tensorflow.keras.models import Model
from tensorflow.python.keras.utils import tf_utils # pylint: disable=no-name-in-module

import n3fit.backends.keras_backend.operations as op

# Check the TF version to check if legacy-mode is needed (TF < 2.2)
@@ -42,18 +45,20 @@
"SGD": (Kopt.SGD, {"learning_rate": 0.01, "momentum": 0.0, "nesterov": False}),
}

NN_PREFIX = "NN"
PREPROCESSING_PREFIX = "preprocessing_factor"

# Some keys need to work for everyone
for k, v in optimizers.items():
v[1]["clipnorm"] = 1.0


-def _default_loss(y_true, y_pred): # pylint: disable=unused-argument
+def _default_loss(y_true, y_pred):  # pylint: disable=unused-argument
"""Default loss to be used when the model is compiled with loss = Null
(for instance if the prediction of the model is already the loss"""
return op.sum(y_pred)



class MetaModel(Model):
"""
The model wraps keras.Model and adds some custom behaviour. Most notably it
@@ -95,7 +100,6 @@ def __init__(self, input_tensors, output_tensors, scaler=None, input_values=None
if not isinstance(input_values, dict):
raise TypeError("Expecting input_values to be a dict or None")


x_in = {}
# Go over the inputs. If we can deduce a constant value, either because
# it is set in input_values or because it has a tensor_content, we
@@ -111,13 +115,13 @@ def __init__(self, input_tensors, output_tensors, scaler=None, input_values=None
super().__init__(input_tensors, output_tensors, **kwargs)

self.x_in = x_in
-self.tensors_in = input_tensors
+self.input_tensors = input_tensors
self.single_replica_generator = None

self.target_tensors = None
self.compute_losses_function = None
self._scaler = scaler


@tf.autograph.experimental.do_not_convert
def _parse_input(self, extra_input=None):
"""Returns the input data the model was compiled with.
@@ -127,9 +131,7 @@ def _parse_input(self, extra_input=None):
"""
if extra_input is None:
if self.required_slots:
-raise ValueError(
-f"The following inputs must be provided: {self.required_slots}"
-)
+raise ValueError(f"The following inputs must be provided: {self.required_slots}")
return self.x_in

if not isinstance(extra_input, dict):
@@ -172,7 +174,7 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
return loss_dict

def predict(self, x=None, **kwargs):
""" Call super().predict with the right input arguments """
"""Call super().predict with the right input arguments"""
x = self._parse_input(x)
result = super().predict(x=x, **kwargs)
return result
@@ -325,11 +327,139 @@ def reset_layer_weights_to(self, layer_names, reference_vals):
w.assign(v)

def apply_as_layer(self, x):
""" Apply the model as a layer """
all_input = {**self.tensors_in, **x}
"""Apply the model as a layer"""
all_input = {**self.input_tensors, **x}
return all_input, super().__call__(all_input)

def get_layer_re(self, regex):
""" Get all layers matching the given regular expression """
"""Get all layers matching the given regular expression"""
check = lambda x: re.match(regex, x.name)
return list(filter(check, self.layers))

def get_replica_weights(self, i_replica):
"""
Get the weights of replica i_replica.

This assumes that the only weights are in layers called
``NN_{i_replica}`` and ``preprocessing_factor_{i_replica}``


Parameters
----------
i_replica: int

Returns
-------
dict
dictionary with the weights of the replica
"""
NN_weights = [
tf.Variable(w, name=w.name) for w in self.get_layer(f"{NN_PREFIX}_{i_replica}").weights
]
prepro_weights = [
tf.Variable(w, name=w.name)
for w in self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").weights
]
weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights}

return weights

def set_replica_weights(self, weights, i_replica=0):
"""
Set the weights of replica i_replica.

This assumes that the only weights are in layers called
``NN_{i_replica}`` and ``preprocessing_factor_{i_replica}``

Parameters
----------
weights: dict
dictionary with the weights of the replica
i_replica: int
the replica number to set, defaulting to 0
"""
self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX])
self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").set_weights(
weights[PREPROCESSING_PREFIX]
)

def split_replicas(self):
"""
Split the single multi-replica model into a list of separate single replica models,
maintaining the current state of the weights.

Returns
-------
list
list of single replica models
"""
if self.single_replica_generator is None:
raise ValueError("Trying to generate single replica models with no generator set.")
replicas = []
num_replicas = self.output.shape[-1]
for i_replica in range(num_replicas):
replica = self.single_replica_generator()
replica.set_replica_weights(self.get_replica_weights(i_replica))

# pick single photon
if "add_photons" in self.layers:
replica.get_layer("add_photons").set_photon(
self.get_layer("add_photons").get_photon(i_replica)
)
replicas.append(replica)

return replicas

def load_identical_replicas(self, model_file):
"""
From a single replica model, load the same weights into all replicas.
"""
weights = self._format_weights_from_file(model_file)

num_replicas = self.output.shape[-1]
for i_replica in range(num_replicas):
self.set_replica_weights(weights, i_replica)

def _format_weights_from_file(self, model_file):
"""Read weights from a .h5 file and format into a dictionary of tf.Variables"""
weights = {}

with h5py.File(model_file, 'r') as f:
# look at layers of the form NN_i and take the lowest i
i_replica = 0
while f"{NN_PREFIX}_{i_replica}" not in f:
i_replica += 1

weights[NN_PREFIX] = self._extract_weights(
f[f"{NN_PREFIX}_{i_replica}"], NN_PREFIX, i_replica
)
weights[PREPROCESSING_PREFIX] = self._extract_weights(
f[f"{PREPROCESSING_PREFIX}_{i_replica}"], PREPROCESSING_PREFIX, i_replica
)

return weights

def _extract_weights(self, h5_group, weights_key, i_replica):
"""Extract weights from a h5py group, turning them into Tensorflow variables"""
weights = []

def append_weights(name, node):
if isinstance(node, h5py.Dataset):
weight_name = node.name.split("/", 2)[-1]
weight_name = weight_name.replace(f"{NN_PREFIX}_{i_replica}", f"{NN_PREFIX}_0")
weight_name = weight_name.replace(
f"{PREPROCESSING_PREFIX}_{i_replica}", f"{PREPROCESSING_PREFIX}_0"
)
weights.append(tf.Variable(node[()], name=weight_name))

h5_group.visititems(append_weights)

# have to put them in the same order
weights_ordered = []
weights_model_order = [w.name for w in self.get_replica_weights(0)[weights_key]]
for w in weights_model_order:
for w_h5 in weights:
if w_h5.name == w:
weights_ordered.append(w_h5)

return weights_ordered
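For orientation, the replica bookkeeping added above rests on two patterns: each replica owns layers named NN_{i_replica} and preprocessing_factor_{i_replica}, so moving a replica in or out of the joined model is a layer lookup by name, and weight files are traversed with h5py's visititems while the replica index in each weight name is rewritten to 0 so it matches a single replica model. The sketches below illustrate the two patterns in isolation, with toy layer sizes and invented helper names (make_multi_replica_model, make_single_replica_model); they are assumptions for illustration, not the n3fit implementation.

import tensorflow as tf

NN_PREFIX = "NN"

def make_multi_replica_model(num_replicas=3):
    # toy stand-in for the joined multi-replica MetaModel
    x = tf.keras.Input(shape=(1,), name="pdf_input")
    outs = [tf.keras.layers.Dense(4, name=f"{NN_PREFIX}_{i}")(x) for i in range(num_replicas)]
    return tf.keras.Model(x, tf.keras.layers.Concatenate()(outs))

def make_single_replica_model():
    # toy stand-in for what single_replica_generator would produce
    x = tf.keras.Input(shape=(1,), name="pdf_input")
    return tf.keras.Model(x, tf.keras.layers.Dense(4, name=f"{NN_PREFIX}_0")(x))

multi = make_multi_replica_model()
single = make_single_replica_model()

# copy replica 2 out of the joined model by layer name, passing values
# rather than references, as the tf.Variable copies above do
values = [w.numpy() for w in multi.get_layer(f"{NN_PREFIX}_2").weights]
single.get_layer(f"{NN_PREFIX}_0").set_weights(values)

And the visititems traversal with the replica index renamed to 0, here on a throwaway in-memory file:

import h5py
import numpy as np

with h5py.File("toy.h5", "w", driver="core", backing_store=False) as f:
    # fake weight layout for a replica stored as NN_3
    f.create_dataset("NN_3/dense/kernel:0", data=np.zeros((2, 2)))
    f.create_dataset("NN_3/dense/bias:0", data=np.zeros(2))

    collected = []
    def append_weights(name, node):
        # keep only datasets; rename the replica index to 0 like _extract_weights
        if isinstance(node, h5py.Dataset):
            collected.append((node.name.replace("NN_3", "NN_0"), node[()]))

    f["NN_3"].visititems(append_weights)
    print([name for name, _ in collected])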
37 changes: 22 additions & 15 deletions n3fit/src/n3fit/hyper_optimization/penalties.py
@@ -3,9 +3,9 @@

Penalties in this module usually take as signature the positional arguments:

-pdf_models: list(:py:class:`n3fit.backends.keras_backend.MetaModel`)
-list of models or functions taking a ``(1, xgrid_size, 1)`` array as input
-and returns a ``(1, xgrid_size, 14)`` pdf.
+pdf_model: :py:class:`n3fit.backends.keras_backend.MetaModel`
+model taking a ``(1, xgrid_size, 1)`` array as input
+and returning a ``(1, xgrid_size, 14, replicas)`` pdf.

stopping_object: :py:class:`n3fit.stopping.Stopping`
object holding the information about the validation model
@@ -19,11 +19,12 @@
The name in the runcard must match the name used in this module.
"""
import numpy as np
-from validphys import fitveto

from n3fit.vpinterface import N3PDF, integrability_numbers
+from validphys import fitveto


-def saturation(pdf_models=None, n=100, min_x=1e-6, max_x=1e-4, flavors=None, **_kwargs):
+def saturation(pdf_model=None, n=100, min_x=1e-6, max_x=1e-4, flavors=None, **_kwargs):
"""Checks the pdf models for saturation at small x
by checking the slope from ``min_x`` to ``max_x``.
Sum the saturation loss of all pdf models
@@ -52,15 +53,21 @@ def saturation(pdf_model=None, n=100, min_x=1e-6, max_x=1e-4, flavors=None, **_kwargs):
if flavors is None:
flavors = [1, 2]
x = np.logspace(np.log10(min_x), np.log10(max_x), n)
-xin = np.expand_dims(x, axis=[0, -1])
+x = np.expand_dims(x, axis=[0, -1])
extra_loss = 0.0
-for pdf_model in pdf_models:
-y = pdf_model.predict({"pdf_input": xin})
-xpdf = y[0, :, flavors]
-slope = np.diff(xpdf) / np.diff(np.log10(x))
-pen = abs(np.mean(slope, axis=1)) + np.std(slope, axis=1)
-# Add a small offset to avoid ZeroDivisionError
-extra_loss += np.sum(1.0 / (1e-7 + pen))
+y = pdf_model.predict({"pdf_input": x})
+xpdf = y[0, :, flavors]

+delta_logx = np.diff(np.log10(x), axis=1)
+delta_xpdf = np.diff(xpdf, axis=1)
+slope = delta_xpdf / delta_logx

+pen = abs(np.mean(slope, axis=1)) + np.std(slope, axis=1)

+# sum over flavors
+# Add a small offset to avoid ZeroDivisionError
+extra_loss += np.sum(1.0 / (1e-7 + pen), axis=0)
return extra_loss


@@ -94,7 +101,7 @@ def patience(stopping_object=None, alpha=1e-4, **_kwargs):
return vl_loss * np.exp(alpha * diff)


-def integrability(pdf_models=None, **_kwargs):
+def integrability(pdf_model=None, **_kwargs):
"""Adds a penalty proportional to the value of the integrability integration
It adds a 0-penalty when the value of the integrability is equal or less than the value
of the threshold defined in validphys::fitveto
@@ -111,7 +118,7 @@ def integrability(pdf_model=None, **_kwargs):
True

"""
-pdf_instance = N3PDF(pdf_models)
+pdf_instance = N3PDF(pdf_model.split_replicas())
integ_values = integrability_numbers(pdf_instance)
integ_overflow = np.sum(integ_values[integ_values > fitveto.INTEG_THRESHOLD])
if integ_overflow > 50.0:
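As a sanity check of the vectorised slope computation above, here is a self-contained numpy toy. The (flavors, n) shape of xpdf is a simplification of the real y[0, :, flavors] slice, and the pdf values are invented:

import numpy as np

n = 100
x = np.logspace(np.log10(1e-6), np.log10(1e-4), n)[np.newaxis, :]  # shape (1, n)
# flavor 0 falls linearly in log10(x) (slope -0.5); flavor 1 saturates (slope 0)
xpdf = np.stack([-0.5 * np.log10(x[0]), np.ones(n)])  # shape (2, n)

slope = np.diff(xpdf, axis=1) / np.diff(np.log10(x), axis=1)  # shape (2, n-1)
pen = abs(np.mean(slope, axis=1)) + np.std(slope, axis=1)     # roughly [0.5, 0.0]
extra_loss = np.sum(1.0 / (1e-7 + pen), axis=0)
print(pen, extra_loss)

The saturated flavor has pen close to 0, so its 1.0 / (1e-7 + pen) term blows up and dominates the loss, which is exactly the behaviour the penalty is meant to flag.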
16 changes: 12 additions & 4 deletions n3fit/src/n3fit/hyper_optimization/rewards.py
@@ -5,7 +5,7 @@
Keyword arguments that model_trainer.py will pass to this file are:

- fold_losses: a list with the loss of each fold
-- n3pdfs: a list of N3PDF objects for each fit (which can contain more than 1 replica)
+- pdfs_per_fold: a list of (multi replica) PDFs for each fold
- experimental_models: a reference to the model that contains the cv for all data (no masks)

New loss functions can be added directly in this module
@@ -25,7 +25,14 @@

"""
import numpy as np
-from validphys.pdfgrids import xplotting_grid, distance_grids

+from n3fit.vpinterface import N3PDF
+from validphys.pdfgrids import distance_grids, xplotting_grid


def _pdfs_to_n3pdfs(pdfs_per_fold):
"""Convert a list of multi-replica PDFs to a list of N3PDFs"""
return [N3PDF(pdf.split_replicas(), name=f"fold_{k}") for k, pdf in enumerate(pdfs_per_fold)]


def average(fold_losses=None, **_kwargs):
@@ -43,10 +50,11 @@ def std(fold_losses=None, **_kwargs):
return np.std(fold_losses)


-def fit_distance(n3pdfs=None, **_kwargs):
+def fit_distance(pdfs_per_fold=None, **_kwargs):
"""Loss function for hyperoptimization based on the distance of
the fits of all folds to the first fold
"""
n3pdfs = _pdfs_to_n3pdfs(pdfs_per_fold)
if n3pdfs is None:
raise ValueError("fit_distance needs n3pdf models to act upon")
xgrid = np.concatenate([np.logspace(-6, -1, 20), np.linspace(0.11, 0.9, 30)])
@@ -100,6 +108,7 @@ def fit_future_tests(n3pdfs=None, experimental_models=None, **_kwargs):
compatibility_mode = False
try:
import tensorflow as tf

from n3fit.backends import set_eager

tf_version = tf.__version__.split(".")
@@ -119,7 +128,6 @@ def fit_future_tests(n3pdfs=None, experimental_models=None, **_kwargs):
# Loop over all models but the last (our reference!)
total_loss = 0.0
for n3pdf, exp_model in zip(n3pdfs[:-1], experimental_models[:-1]):

_set_central_value(n3pdf, exp_model)

# Get the full input and the total chi2
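To close, a minimal illustration of how the simple fold-loss rewards defined in this file aggregate per-fold results; the loss numbers are made up:

import numpy as np

fold_losses = [2.31, 2.28, 2.44]  # hypothetical per-fold hyperopt losses

print(np.average(fold_losses))  # the "average" reward
print(np.std(fold_losses))      # the "std" reward, penalising fold-to-fold spread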