Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 105 additions & 18 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,85 @@ def format_estimate_uncertainty(
return est_str, unc_str


def format_df_with_precision(
df: pd.DataFrame, est_col_name: str, unc_col_name: str, unc_prec: int = 1
) -> pd.DataFrame:
"""
Returns a new DataFrame with the columns `est_col_name` and `unc_col_name` formatted as strings reported to `unc_prec` precision.

The uncertainty column will be rounded to `unc_prec` precision, then the estimate column will be reported to the same precision.
Any entries that are not floats (such as strings indicating errors), will not be modified.

Parameters
----------
df : pd.DataFrame
DataFrame to format
est_col_name : str
Name of the column containing estimates to format.
unc_col_name : str
Name of the column containing uncertainties to format.

unc_prec : int, optional
Precision to round the uncertainty column to, by default 1.

Returns
-------
pd.DataFrame
DataFrame with formatted uncertainty and estimate columns.

Example
-------
>>> df
ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol)
0 lig_ejm_31 lig_ejm_42 Error Error
1 lig_ejm_31 lig_ejm_46 -0.891077 0.064825
2 lig_ejm_31 lig_ejm_47 0.023341 0.145625
3 lig_ejm_31 lig_ejm_48 0.614103 0.088704
4 lig_ejm_31 lig_ejm_50 0.999904 0.044457
5 lig_ejm_42 lig_ejm_43 1.354348 0.156009
6 lig_ejm_46 lig_jmc_23 0.294761 0.086632
7 lig_ejm_46 lig_jmc_27 -0.101737 0.100997
8 lig_ejm_46 lig_jmc_28 Error Error
>>> df_out = format_df_with_precision(df, "DDG(i->j) (kcal/mol)", "uncertainty (kcal/mol)")
>>> df_formatted
ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol)
0 lig_ejm_31 lig_ejm_42 Error Error
1 lig_ejm_31 lig_ejm_46 -0.89 0.06
2 lig_ejm_31 lig_ejm_47 0.0 0.1
3 lig_ejm_31 lig_ejm_48 0.61 0.09
4 lig_ejm_31 lig_ejm_50 1.00 0.04
5 lig_ejm_42 lig_ejm_43 1.4 0.2
6 lig_ejm_46 lig_jmc_23 0.29 0.09
7 lig_ejm_46 lig_jmc_27 -0.1 0.1
8 lig_ejm_46 lig_jmc_28 Error Error

"""

# find all entries in both columns that contain strings:
df_is_string = df[[est_col_name, unc_col_name]].applymap(lambda x: isinstance(x, str))

# if either the estimate or uncertainty entries are strings, dont format
no_strings_mask = ~(df_is_string[est_col_name] | df_is_string[unc_col_name])

# skip rows that contain striangs and only round and format numerical vals
df_floats_formatted = df[no_strings_mask].apply(
lambda row: format_estimate_uncertainty(row[est_col_name], row[unc_col_name], unc_prec),
axis=1,
result_type="expand",
)

# explicitly cast to string to make pandas happy
df[[est_col_name, unc_col_name]] = df[[est_col_name, unc_col_name]].astype(str)

# if there are no floats, assigning an empty array will break things
if df_floats_formatted.empty:
pass
else:
df.loc[no_strings_mask, [est_col_name, unc_col_name]] = df_floats_formatted.values

return df


def is_results_json(fpath: os.PathLike | str) -> bool:
"""Sanity check that file is a result json before we try to deserialize"""
return "estimate" in open(fpath, "r").read(20)
Expand Down Expand Up @@ -254,7 +333,7 @@ def _generate_bad_legs_error_message(bad_legs: list[tuple[set[str], tuple[str]]]
return msg


def _get_ddgs(legs: dict, allow_partial=False) -> None:
def _get_ddgs(legs: dict, allow_partial=False) -> pd.DataFrame:
import numpy as np

from openfe.protocols.openmm_rfe.equil_rfe_methods import (
Expand Down Expand Up @@ -316,10 +395,14 @@ def _get_ddgs(legs: dict, allow_partial=False) -> None:
)
click.secho(err_msg, err=True, fg="red")
sys.exit(1)
return DDGs
df_ddg = pd.DataFrame(
DDGs,
columns=["ligand_i", "ligand_j", "DDG_bind", "bind_unc", "DDG_hyd", "hyd_unc"],
)
return df_ddg


def _generate_ddg(legs: dict, allow_partial: bool) -> None:
def _generate_ddg(legs: dict, allow_partial: bool) -> pd.DataFrame:
"""Compute and write out DDG values for the given legs.

Parameters
Expand All @@ -332,23 +415,23 @@ def _generate_ddg(legs: dict, allow_partial: bool) -> None:
"""
DDGs = _get_ddgs(legs, allow_partial=allow_partial)
data = []
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
if DDGbind is not None:
DDGbind, bind_unc = format_estimate_uncertainty(DDGbind, bind_unc)
for _, row in DDGs.iterrows():
ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc = row.to_list()
if not pd.isna(DDGbind):
data.append((ligA, ligB, DDGbind, bind_unc))
if DDGhyd is not None:
DDGhyd, hyd_unc = format_estimate_uncertainty(DDGhyd, hyd_unc)
if not pd.isna(DDGhyd):
data.append((ligA, ligB, DDGhyd, hyd_unc))
elif DDGbind is None and DDGhyd is None:
elif pd.isna(DDGbind) and pd.isna(DDGhyd):
data.append((ligA, ligB, FAIL_STR, FAIL_STR))
df = pd.DataFrame(
data,
columns=["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)", "uncertainty (kcal/mol)"],
)
return df
df_out = format_df_with_precision(df, "DDG(i->j) (kcal/mol)", "uncertainty (kcal/mol)")
return df_out


def _generate_raw(legs: dict, allow_partial=True) -> None:
def _generate_raw(legs: dict, allow_partial=True) -> pd.DataFrame:
"""
Write out all legs found and their DG values, or indicate that they have failed.

Expand All @@ -367,7 +450,7 @@ def _generate_raw(legs: dict, allow_partial=True) -> None:
if m is None:
m, u = FAIL_STR, FAIL_STR
else:
m, u = format_estimate_uncertainty(m.m, u.m)
m, u = (m.m, u.m)
data.append((simtype, ligpair[0], ligpair[1], m, u))

df = pd.DataFrame(
Expand All @@ -380,7 +463,9 @@ def _generate_raw(legs: dict, allow_partial=True) -> None:
"MBAR uncertainty (kcal/mol)",
],
)
return df
df_out = format_df_with_precision(df, "DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)")

return df_out


def _check_legs_have_sufficient_repeats(legs):
Expand All @@ -394,7 +479,7 @@ def _check_legs_have_sufficient_repeats(legs):
sys.exit(1)


def _generate_dg_mle(legs: dict, allow_partial: bool) -> None:
def _generate_dg_mle(legs: dict, allow_partial: bool) -> pd.DataFrame:
"""Compute and write out DG values for the given legs.

Parameters
Expand All @@ -419,12 +504,13 @@ def _generate_dg_mle(legs: dict, allow_partial: bool) -> None:
g = nx.DiGraph()
nm_to_idx = {}
DDGbind_count = 0
for ligA, ligB, DDGbind, bind_unc, _, _ in DDGs:
for _, row in DDGs.iterrows():
ligA, ligB, DDGbind, bind_unc, _, _ = row.to_list()
for lig in (ligA, ligB):
if lig not in expected_ligs:
expected_ligs.append(lig)

if DDGbind is None or DDGbind == FAIL_STR:
if pd.isna(DDGbind) or DDGbind == FAIL_STR:
continue
DDGbind_count += 1

Expand Down Expand Up @@ -476,15 +562,16 @@ def _generate_dg_mle(legs: dict, allow_partial: bool) -> None:

data = []
for ligA, DG, unc_DG in MLEs:
DG, unc_DG = format_estimate_uncertainty(DG, unc_DG)
data.append({"ligand": ligA, "DG(MLE) (kcal/mol)": DG, "uncertainty (kcal/mol)": unc_DG})
expected_ligs.remove(ligA)

for ligA in expected_ligs:
data.append({"ligand": ligA, "DG(MLE) (kcal/mol)": FAIL_STR, "uncertainty (kcal/mol)": FAIL_STR}) # fmt: skip

df = pd.DataFrame(data)
return df
df_out = format_df_with_precision(df, "DG(MLE) (kcal/mol)", "uncertainty (kcal/mol)")

return df_out


def _collect_result_jsons(results: List[os.PathLike | str]) -> List[pathlib.Path]:
Expand Down
Loading