diff --git a/process/main.py b/process/main.py index ddf5acac31..be7052a465 100644 --- a/process/main.py +++ b/process/main.py @@ -197,6 +197,11 @@ def parse_args(self, args): action="store_true", help="Print the version of PROCESS to the terminal", ) + parser.add_argument( + "--update-obsolete", + action="store_true", + help="Automatically update obsolete variables in the IN.DAT file", + ) # If args is not None, then parse the supplied arguments. This is likely # to come from the test suite when testing command-line arguments; the @@ -215,7 +220,11 @@ def run_mode(self): if self.args.varyiterparams: self.run = VaryRun(self.args.varyiterparamsconfig, self.args.solver) else: - self.run = SingleRun(self.args.input, self.args.solver) + self.run = SingleRun( + self.args.input, + self.args.solver, + update_obsolete=self.args.update_obsolete, + ) self.run.run() def post_process(self): @@ -355,9 +364,8 @@ def run(self): class SingleRun: """Perform a single run of PROCESS.""" - def __init__(self, input_file, solver="vmcon"): + def __init__(self, input_file, solver="vmcon", *, update_obsolete=False): """Read input file and initialise variables. - :param input_file: input file named IN.DAT :type input_file: str :param solver: which solver to use, as specified in solver.py @@ -365,7 +373,7 @@ def __init__(self, input_file, solver="vmcon"): """ self.input_file = input_file - self.validate_input() + self.validate_input(update_obsolete) self.init_module_vars() self.set_filenames() self.initialise() @@ -524,49 +532,103 @@ def append_input(self): mfile_file.write("***********************************************") mfile_file.writelines(input_lines) - def validate_input(self): - """Checks the input IN.DAT file for any obsolete variables in the OBS_VARS dict contained - within obsolete_variables.py. - Then will print out what the used obsolete variables are (if any) before continuing the proces run. + def validate_input(self, replace_obsolete=False): + """ + Checks the input IN.DAT file for any obsolete variables in the OBS_VARS dict contained + within obsolete_variables.py. If obsolete variables are found, and if `replace_obsolete` + is set to True, they are either removed or replaced by their updated names as specified + in the OBS_VARS dictionary. + + Parameters: + replace_obsolete (bool): If True, modifies the IN.DAT file to replace or comment out + obsolete variables. If False, only reports obsolete variables. """ obsolete_variables = ov.OBS_VARS obsolete_vars_help_message = ov.OBS_VARS_HELP filename = self.input_file - variables_in_in_dat = [] + modified_lines = [] + changes_made = [] # To store details of the changes + with open(filename, "r") as file: for line in file: - if line[0] == "*" or "=" not in line: + # Skip comment lines or lines without an assignment + if line.startswith("*") or "=" not in line: + modified_lines.append(line) continue + # Extract the variable name before the separator + variable_name = line.split("=", 1)[0].strip() + variables_in_in_dat.append(variable_name) + + # Check if the variable is obsolete and needs replacing + if variable_name in obsolete_variables: + replacement = obsolete_variables.get(variable_name) + + if replace_obsolete: + # Prepare replacement or removal + if replacement is None: + # If no replacement is defined, comment out the line + modified_lines.append(f"* Obsolete: {line}") + changes_made.append( + f"Commented out obsolete variable: {variable_name}" + ) + else: + if isinstance(replacement, list): + # Raise an error if replacement is a list + replacement_str = ", ".join(replacement) + raise ValueError( + f"The variable '{variable_name}' is obsolete and should be replaced by the following variables: {replacement_str}. " + "Please set their values accordingly." + ) + else: + # Replace obsolete variable with updated variable + modified_line = line.replace( + variable_name, replacement, 1 + ) + modified_lines.append( + f"* Replaced '{variable_name}' with '{replacement}'\n{modified_line}" + ) + changes_made.append( + f"Replaced '{variable_name}' with '{replacement}'" + ) + else: + # If replacement is False, add the line as-is + modified_lines.append(line) else: - sep = " " - variables = line.strip().split(sep, 1)[0] - variables_in_in_dat.append(variables) - - obs_vars_in_in_dat = [] - replace_hints = {} - for var in variables_in_in_dat: - if var in obsolete_variables: - obs_vars_in_in_dat.append(var) - replace_hints[var] = obsolete_variables.get(var) - - if len(obs_vars_in_in_dat) > 0: - message = ( - "The IN.DAT file contains obsolete variables from the OBS_VARS dictionary. The obsolete variables in your IN.DAT file are: " - f"{obs_vars_in_in_dat}. " - "Either remove these or replace them with their updated variable names. " - ) - for obs_var in obs_vars_in_in_dat: - if replace_hints[obs_var] is None: - message += f"\n\n {obs_var} is an obsolete variable and needs to be removed. " - else: - message += f"\n \n {obs_var} is an obsolete variable and needs to be replaced by {str(replace_hints[obs_var])}. " - message += f"{obsolete_vars_help_message.get(obs_var, '')}" - - raise ValueError(message) + modified_lines.append(line) + + obs_vars_in_in_dat = [ + var for var in variables_in_in_dat if var in obsolete_variables + ] + if obs_vars_in_in_dat: + if replace_obsolete: + # If replace_obsolete is True, write the modified content to the file + with open(filename, "w") as file: + file.writelines(modified_lines) + print( + "The IN.DAT file has been updated to replace or comment out obsolete variables." + ) + print("Summary of changes made:") + for change in changes_made: + print(f" - {change}") + else: + # Only print the report if replace_obsolete is False + message = ( + "The IN.DAT file contains obsolete variables from the OBS_VARS dictionary. " + f"The obsolete variables in your IN.DAT file are: {obs_vars_in_in_dat}. " + "Either remove these or replace them with their updated variable names. " + ) + for obs_var in obs_vars_in_in_dat: + replacement = obsolete_variables.get(obs_var) + if replacement is None: + message += f"\n\n{obs_var} is an obsolete variable and needs to be removed." + else: + message += f"\n\n{obs_var} is an obsolete variable and needs to be replaced by {replacement}." + message += f" {obsolete_vars_help_message.get(obs_var, '')}" + raise ValueError(message) else: print("The IN.DAT file does not contain any obsolete variables.") diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 858886cf67..b348d58080 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -102,6 +102,8 @@ def test_run_mode(process_obj, monkeypatch): monkeypatch.setattr(process_obj, "args", argparse.Namespace(), raising=False) monkeypatch.setattr(process_obj.args, "varyiterparams", True, raising=False) monkeypatch.setattr(process_obj.args, "version", False, raising=False) + monkeypatch.setattr(process_obj.args, "update_obsolete", False, raising=False) + monkeypatch.setattr( process_obj.args, "varyiterparamsconfig", "file.conf", raising=False )