diff --git a/n3fit/runcards/Basic_runcard.yml b/n3fit/runcards/Basic_runcard.yml index 6d0d567e55..b044c6a9e5 100644 --- a/n3fit/runcards/Basic_runcard.yml +++ b/n3fit/runcards/Basic_runcard.yml @@ -38,7 +38,6 @@ fitting: trvlseed: 1 nnseed: 2 mcseed: 3 - epochs: 900 save: 'weights.h5' # load: '/path/to/weights.h5/file' diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index edab10031f..1b4811ce73 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -54,7 +54,7 @@ def check_consistent_layers(parameters): raise CheckError(f"Number of layers ({npl}) does not match activation functions: {apl}") -def check_stopping(parameters): +def check_stopping(parameters, epochs_legacy=None): """Checks whether the stopping-related options are sane: stopping patience as a ratio between 0 and 1 and positive number of epochs @@ -62,7 +62,14 @@ def check_stopping(parameters): spt = parameters.get("stopping_patience") if spt is not None and not 0.0 <= spt <= 1.0: raise CheckError(f"The stopping_patience must be between 0 and 1, got: {spt}") - epochs = parameters["epochs"] + epochs = parameters.get("epochs") + if epochs is None: + raise CheckError("A number of epochs must be given under fitting::parameters") + if epochs_legacy is not None and not epochs == epochs_legacy: + raise CheckError(f"Received contradictory values for epochs: {epochs_legacy} and {epochs}. " + "Note that 'epochs' is only supported under fitting::parameters") + if not isinstance(epochs, int): + raise CheckError(f"Only integer number of epochs allowed, received: {epochs}") if epochs < 1: raise CheckError(f"Needs to run at least 1 epoch, got: {epochs}") @@ -155,14 +162,14 @@ def check_model_file(save, load): raise CheckError(f"Model file {load} seems to be empty") @make_argcheck -def wrapper_check_NN(basis, tensorboard, save, load, parameters): - """ Wrapper function for all NN-related checks """ +def wrapper_check_NN(basis, tensorboard, save, load, parameters, fitting): + """ Wrapper function for all sensible NN-related checks """ check_tensorboard(tensorboard) check_model_file(save, load) check_existing_parameters(parameters) check_consistent_layers(parameters) check_basis_with_layers(basis, parameters) - check_stopping(parameters) + check_stopping(parameters, epochs_legacy=fitting.get("epochs")) check_dropout(parameters) check_lagrange_multipliers(parameters, "integrability") check_lagrange_multipliers(parameters, "positivity") @@ -317,12 +324,16 @@ def check_sumrules(sum_rules): # Checks on the physics @make_argcheck -def check_consistent_basis(sum_rules, fitbasis, basis, theoryid): +def check_consistent_basis(sum_rules, fitbasis, basis, theoryid, parameters): """Checks the fitbasis setup for inconsistencies - Checks the sum rules can be imposed - Correct flavours for the selected basis - Correct ranges (min < max) for the small and large-x exponents """ + # Check that the sum rules are not included in the wrong dictionary + if "sum_rules" in parameters: + raise CheckError("'sum_rules' included under fitting::parameters, they should be under " + "fiting: instead") check_sumrules(sum_rules) # Check that there are no duplicate flavours and that parameters are sane flavs = []