Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion n3fit/runcards/Basic_runcard.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ fitting:
trvlseed: 1
nnseed: 2
mcseed: 3
epochs: 900
save: 'weights.h5'
# load: '/path/to/weights.h5/file'

Expand Down
23 changes: 17 additions & 6 deletions n3fit/src/n3fit/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,22 @@ def check_consistent_layers(parameters):
raise CheckError(f"Number of layers ({npl}) does not match activation functions: {apl}")


def check_stopping(parameters):
def check_stopping(parameters, epochs_legacy=None):
"""Checks whether the stopping-related options are sane:
stopping patience as a ratio between 0 and 1
and positive number of epochs
"""
spt = parameters.get("stopping_patience")
if spt is not None and not 0.0 <= spt <= 1.0:
raise CheckError(f"The stopping_patience must be between 0 and 1, got: {spt}")
epochs = parameters["epochs"]
epochs = parameters.get("epochs")
if epochs is None:
raise CheckError("A number of epochs must be given under fitting::parameters")
if epochs_legacy is not None and not epochs == epochs_legacy:
raise CheckError(f"Received contradictory values for epochs: {epochs_legacy} and {epochs}. "
"Note that 'epochs' is only supported under fitting::parameters")
if not isinstance(epochs, int):
raise CheckError(f"Only integer number of epochs allowed, received: {epochs}")
if epochs < 1:
raise CheckError(f"Needs to run at least 1 epoch, got: {epochs}")

Expand Down Expand Up @@ -155,14 +162,14 @@ def check_model_file(save, load):
raise CheckError(f"Model file {load} seems to be empty")

@make_argcheck
def wrapper_check_NN(basis, tensorboard, save, load, parameters):
""" Wrapper function for all NN-related checks """
def wrapper_check_NN(basis, tensorboard, save, load, parameters, fitting):
""" Wrapper function for all sensible NN-related checks """
check_tensorboard(tensorboard)
check_model_file(save, load)
check_existing_parameters(parameters)
check_consistent_layers(parameters)
check_basis_with_layers(basis, parameters)
check_stopping(parameters)
check_stopping(parameters, epochs_legacy=fitting.get("epochs"))
check_dropout(parameters)
check_lagrange_multipliers(parameters, "integrability")
check_lagrange_multipliers(parameters, "positivity")
Expand Down Expand Up @@ -317,12 +324,16 @@ def check_sumrules(sum_rules):

# Checks on the physics
@make_argcheck
def check_consistent_basis(sum_rules, fitbasis, basis, theoryid):
def check_consistent_basis(sum_rules, fitbasis, basis, theoryid, parameters):
"""Checks the fitbasis setup for inconsistencies
- Checks the sum rules can be imposed
- Correct flavours for the selected basis
- Correct ranges (min < max) for the small and large-x exponents
"""
# Check that the sum rules are not included in the wrong dictionary
if "sum_rules" in parameters:
raise CheckError("'sum_rules' included under fitting::parameters, they should be under "
"fiting: instead")
check_sumrules(sum_rules)
# Check that there are no duplicate flavours and that parameters are sane
flavs = []
Expand Down