From eae2d967ff793e2470546e1682ef183f09191b35 Mon Sep 17 00:00:00 2001 From: siranipour Date: Mon, 19 Jul 2021 12:50:11 +0100 Subject: [PATCH 1/2] Changing logic of reading pseudodata --- validphys2/src/validphys/pseudodata.py | 67 +++++++++++--------------- 1 file changed, 28 insertions(+), 39 deletions(-) diff --git a/validphys2/src/validphys/pseudodata.py b/validphys2/src/validphys/pseudodata.py index 1e773cd8f7..12cb04a2c7 100644 --- a/validphys2/src/validphys/pseudodata.py +++ b/validphys2/src/validphys/pseudodata.py @@ -21,9 +21,11 @@ DataTrValSpec = namedtuple('DataTrValSpec', ['pseudodata', 'tr_idx', 'val_idx']) context_index = collect("groups_index", ("fitcontext",)) +read_fit_pseudodata = collect('read_replica_pseudodata', ('fitreplicas',)) +read_pdf_pseudodata = collect('read_replica_pseudodata', ('pdfreplicas',)) @check_cuts_fromfit -def read_fit_pseudodata(fitcontext, context_index): +def read_replica_pseudodata(fit, context_index, replica): """Function to handle the reading of training and validation splits for a fit that has been produced with the ``savepseudodata`` flag set to ``True``. @@ -68,45 +70,32 @@ def read_fit_pseudodata(fitcontext, context_index): # The [0] is because of how pandas handles sorting a MultiIndex sorted_index = context_index.sortlevel(level=range(1,3))[0] - pdf = fitcontext["pdf"] - log.debug(f"Using same pseudodata & training/validation splits as {pdf.name}.") - nrep = len(pdf) - path = pathlib.Path(pdf.infopath) - - data_indices_list = [] - for rep_number in range(1, nrep): - # This is a symlink (usually). - replica = path.with_name(pdf.name + "_" + str(rep_number).zfill(4) + ".dat") - # we resolve the symlink - if replica.parent.is_symlink(): - replica = pathlib.Path(os.path.realpath(replica)) - - training_path = replica.with_name("training.dat") - validation_path = replica.with_name("validation.dat") - - try: - tr = pd.read_csv(training_path, index_col=[0, 1, 2], sep="\t", names=["data"]) - val = pd.read_csv(validation_path, index_col=[0, 1, 2], sep="\t", names=["data"]) - except FileNotFoundError as e: - raise FileNotFoundError( - "Could not find saved training and validation data files. " - f"Please ensure {pdf} was generated with the savepseudodata flag set to true" - ) from e - tr["type"], val["type"] = "training", "validation" - - pseudodata = pd.concat((tr, val)) - pseudodata.sort_index(level=range(1,3), inplace=True) - - pseudodata.index = sorted_index - - tr = pseudodata[pseudodata["type"]=="training"] - val = pseudodata[pseudodata["type"]=="validation"] - - data_indices_list.append( - DataTrValSpec(pseudodata.drop("type", axis=1), tr.index, val.index) - ) + log.debug(f"Reading pseudodata & training/validation splits from {fit.name}.") + path = pathlib.Path(fit.path) / 'nnfit' + replica_path = path / ("replica_" + str(replica)) + + training_path = replica_path / "training.dat" + validation_path = replica_path / "validation.dat" + + try: + tr = pd.read_csv(training_path, index_col=[0, 1, 2], sep="\t", names=[f"replica {replica}"]) + val = pd.read_csv(validation_path, index_col=[0, 1, 2], sep="\t", names=[f"replica {replica}"]) + except FileNotFoundError as e: + raise FileNotFoundError( + "Could not find saved training and validation data files. " + f"Please ensure {fit} was generated with the savepseudodata flag set to true" + ) from e + tr["type"], val["type"] = "training", "validation" + + pseudodata = pd.concat((tr, val)) + pseudodata.sort_index(level=range(1,3), inplace=True) + + pseudodata.index = sorted_index + + tr = pseudodata[pseudodata["type"]=="training"] + val = pseudodata[pseudodata["type"]=="validation"] - return data_indices_list + return DataTrValSpec(pseudodata.drop("type", axis=1), tr.index, val.index) def make_replica(dataset_inputs_loaded_cd_with_cuts, replica_mcseed): From 4cf6d570163ef3a3b62b5bd2bd130e9b3a83b7fc Mon Sep 17 00:00:00 2001 From: siranipour Date: Wed, 21 Jul 2021 10:47:33 +0100 Subject: [PATCH 2/2] Removing os import --- validphys2/src/validphys/pseudodata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/validphys2/src/validphys/pseudodata.py b/validphys2/src/validphys/pseudodata.py index 12cb04a2c7..349d72c4dc 100644 --- a/validphys2/src/validphys/pseudodata.py +++ b/validphys2/src/validphys/pseudodata.py @@ -5,7 +5,6 @@ """ from collections import namedtuple import logging -import os import pathlib import numpy as np