diff --git a/abfe_tutorial/abfe_analysis.ipynb b/abfe_tutorial/abfe_analysis.ipynb new file mode 100644 index 0000000..9d5738a --- /dev/null +++ b/abfe_tutorial/abfe_analysis.ipynb @@ -0,0 +1,704 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2c27fe6a-0c34-48b2-aae6-c4ba47ff545d", + "metadata": {}, + "source": [ + "# Analysis of results from ABFE calculations\n", + "\n", + "In this notebook we show how to analyze results obtained with the OpenFE ABFE protocol.\n", + "This notebook shows you how to extract\n", + "\n", + "- The overall binding free energy of each ligand in the dataset (DG)\n", + "- The contribution from the different legs (complex and solvent) of a transformation." + ] + }, + { + "cell_type": "markdown", + "id": "09b36a2a-f1ea-435f-a950-00058c4aae50", + "metadata": {}, + "source": [ + "### Downloading the example dataset\n", + "First let's download some example ABFE results. Please skip this section if you have already done this!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8365a5cd-81f2-45ad-9b7b-9225f972da09", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2025-10-21 11:53:17-- https://zenodo.org/records/17348229/files/abfe_results.zip\n", + "Resolving zenodo.org (zenodo.org)... 188.185.45.92, 188.185.43.25, 188.185.48.194, ...\n", + "Connecting to zenodo.org (zenodo.org)|188.185.45.92|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1005319 (982K) [application/octet-stream]\n", + "Saving to: ‘abfe_results.zip’\n", + "\n", + "abfe_results.zip 100%[===================>] 981.76K 718KB/s in 1.4s \n", + "\n", + "2025-10-21 11:53:20 (718 KB/s) - ‘abfe_results.zip’ saved [1005319/1005319]\n", + "\n", + "Archive: abfe_results.zip\n", + " creating: abfe_results/\n", + " creating: abfe_results/results_2/\n", + " inflating: abfe_results/toluene_results.json \n", + " creating: abfe_results/results_1/\n", + " creating: abfe_results/results_0/\n", + " inflating: abfe_results/results_2/1.json \n", + " inflating: abfe_results/results_1/1.json \n", + " inflating: abfe_results/results_0/1.json \n" + ] + } + ], + "source": [ + "!wget https://zenodo.org/records/17348229/files/abfe_results.zip\n", + "!unzip abfe_results.zip" + ] + }, + { + "cell_type": "markdown", + "id": "f2cff1e7-cf62-4839-b53d-62b59b7189b6", + "metadata": {}, + "source": [ + "### Imports\n", + "Here are a bunch of imports we will need later in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7fbf1482-25ca-427b-a881-af88a983461c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import glob\n", + "import json\n", + "import csv\n", + "import os\n", + "import pathlib\n", + "from typing import Literal, List\n", + "from gufe.tokenization import JSON_HANDLER\n", + "import pandas as pd\n", + "from openff.units import unit\n", + "from openfecli.commands.gather import (\n", + " format_estimate_uncertainty,\n", + " _collect_result_jsons,\n", + " load_json,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c62eb0c3-0423-49b3-9d02-ffc1c49210ef", + "metadata": {}, + "source": [ + "### Some helper methods to load and format the ABFE results\n", + "Over the next few cells, we define some helper methods that we will use to load and format the ABFE results.\n", + "\n", + "Note: you do not need to directly interact with any of these, unless you are looking to change the behaviour of how data is being processed" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "04ccdedd-84e3-4ccc-ae42-4f3772acb115", + "metadata": {}, + "outputs": [], + "source": [ + "def _load_valid_result_json(\n", + " fpath: os.PathLike | str,\n", + ") -> tuple[tuple | None, dict | None]:\n", + " \"\"\"Load the data from a results JSON into a dict.\n", + "\n", + " Parameters\n", + " ----------\n", + " fpath : os.PathLike | str\n", + " The path to deserialized results.\n", + "\n", + " Returns\n", + " -------\n", + " dict | None\n", + " A dict containing data from the results JSON,\n", + " or None if the JSON file is invalid or missing.\n", + "\n", + " Raises\n", + " ------\n", + " ValueError\n", + " If the JSON file contains an ``estimate`` or ``uncertainty`` key with the\n", + " value ``None``.\n", + " If\n", + " \"\"\"\n", + "\n", + " # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function\n", + " # for now though, it's not the bottleneck on performance\n", + " result = load_json(fpath)\n", + " try:\n", + " name = _get_name(result)\n", + " except (ValueError, IndexError):\n", + " print(f\"{fpath}: Missing ligand names. Skipping.\")\n", + " return None, None\n", + " if result[\"estimate\"] is None:\n", + " errormsg = f\"{fpath}: No 'estimate' found, assuming to be a failed simulation.\"\n", + " raise ValueError(errormsg)\n", + "\n", + " return name, result" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3733c540-de62-45a0-ba66-268397a38b49", + "metadata": {}, + "outputs": [], + "source": [ + "def _get_legs_from_result_jsons(\n", + " result_fns: list[pathlib.Path]\n", + ") -> dict[tuple[str, str], dict[str, list]]:\n", + " \"\"\"\n", + " Iterate over a list of result JSONs and populate a dict of dicts with all data needed\n", + " for results processing.\n", + "\n", + "\n", + " Parameters\n", + " ----------\n", + " result_fns : list[pathlib.Path]\n", + " List of filepaths containing results formatted as JSON.\n", + " report : Literal[\"dg\", \"raw\"]\n", + " Type of report to generate.\n", + "\n", + " Returns\n", + " -------\n", + " legs: dict[str,,dict[str, list]]\n", + " Data extracted from the given result JSONs, organized by the ligand name and simulation type.\n", + " \"\"\"\n", + " from collections import defaultdict\n", + "\n", + " dgs = defaultdict(lambda: defaultdict(list))\n", + "\n", + " for result_fn in result_fns:\n", + " name, result = _load_valid_result_json(result_fn)\n", + " if name is None: # this means it couldn't find name and/or simtype\n", + " continue\n", + "\n", + " dgs[name]['overall'].append([result[\"estimate\"], result[\"uncertainty\"]])\n", + " proto_key = [\n", + " k\n", + " for k in result[\"unit_results\"].keys()\n", + " if k.startswith(\"ProtocolUnitResult\") \n", + " ]\n", + " for p in proto_key:\n", + " if \"unit_estimate\" in result[\"unit_results\"][p][\"outputs\"]:\n", + " simtype = result[\"unit_results\"][p][\"outputs\"][\"simtype\"]\n", + " dg = result[\"unit_results\"][p][\"outputs\"][\"unit_estimate\"]\n", + " dg_error = result[\"unit_results\"][p][\"outputs\"][\"unit_estimate_error\"]\n", + " \n", + " dgs[name][simtype].append([dg, dg_error])\n", + " if \"standard_state_correction\" in result[\"unit_results\"][p][\"outputs\"]:\n", + " corr = result[\"unit_results\"][p][\"outputs\"][\"standard_state_correction\"]\n", + " dgs[name][\"standard_state_correction\"].append([corr, 0*unit.kilocalorie_per_mole])\n", + " else:\n", + " continue\n", + "\n", + " return dgs" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "08afcbcf-34a8-4450-a967-eb4ed5800ee1", + "metadata": {}, + "outputs": [], + "source": [ + "def _get_name(result:dict) -> str:\n", + " \"\"\"Get the ligand name from a unit's results data.\n", + "\n", + " Parameters\n", + " ----------\n", + " result : dict\n", + " A results dict.\n", + "\n", + " Returns\n", + " -------\n", + " str\n", + " Ligand name corresponding to the results.\n", + " \"\"\"\n", + " try:\n", + " nm = list(result['unit_results'].values())[0]['name']\n", + "\n", + " except KeyError:\n", + " raise ValueError(\"Failed to guess name\")\n", + "\n", + " toks = nm.split('Binding, ')\n", + " if 'solvent' in toks[1]:\n", + " name = toks[1].split(' solvent')[0]\n", + " if 'complex' in toks[1]:\n", + " name = toks[1].split(' complex')[0]\n", + " return name" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a5d4aef3-1fab-40b4-848c-d59df8ee1441", + "metadata": {}, + "outputs": [], + "source": [ + "def _error_std(r):\n", + " \"\"\"\n", + " Calculate the error of the estimate as the std of the repeats\n", + " \"\"\"\n", + " return np.std([v[0].m for v in r[\"overall\"]])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0621e3a2-7906-4661-a640-c414456f8869", + "metadata": {}, + "outputs": [], + "source": [ + "def _error_mbar(r):\n", + " \"\"\"\n", + " Calculate the error of the estimate using the reported MBAR errors.\n", + "\n", + " This also takes into account that repeats may have been run for this edge by using the average MBAR error\n", + " \"\"\"\n", + " complex_errors = [x[1].m for x in r[\"complex\"]]\n", + " solvent_errors = [x[1].m for x in r[\"solvent\"]]\n", + " return np.sqrt(np.mean(complex_errors) ** 2 + np.mean(solvent_errors) ** 2)" + ] + }, + { + "cell_type": "markdown", + "id": "3550510d-435b-4322-afed-9e521cce8937", + "metadata": {}, + "source": [ + "### Methods to extract and manipulate the ABFE results\n", + "The next three methods allow you to extract ABFE results (extract_results_dict) and then manipulate them to get different types of results.\n", + "\n", + "These manipulation methods include:\n", + "\n", + "- `generate_dg`: to get the dG values.\n", + "- `generate_dg_raw`: to get the raw dG values for each individual legs in the ABFE transformation cycles." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5821b7f7-6aed-4138-9502-44df3af52647", + "metadata": {}, + "outputs": [], + "source": [ + "def extract_results_dict(\n", + " results_files: list[os.PathLike | str],\n", + ") -> dict[str, dict[str, list]]:\n", + " \"\"\"\n", + " Get a dictionary of ABFE results from a list of directories.\n", + "\n", + " Parameters\n", + " ----------\n", + " results_files : list[ps.PathLike | str]\n", + " A list of directors with ABFE result files to process.\n", + "\n", + " Returns\n", + " -------\n", + " sim_results : dict[str, dict[str, list]]\n", + " Simulation results, organized by the leg's ligand names and simulation type.\n", + " \"\"\"\n", + " # find and filter result jsons\n", + " result_fns = _collect_result_jsons(results_files)\n", + " # pair legs of simulations together into dict of dicts\n", + " sim_results = _get_legs_from_result_jsons(result_fns)\n", + "\n", + " return sim_results" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e021b2ea-db13-4e47-a6d1-cf5229b5b494", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_dg(results_dict: dict[str, dict[str, list]]) -> pd.DataFrame:\n", + " \"\"\"Compute and write out DG values for the given results.\n", + "\n", + " Parameters\n", + " ----------\n", + " results_dict : dict[str, dict[str, list]]\n", + " Dictionary of results created by ``extract_results_dict``.\n", + "\n", + " Returns\n", + " -------\n", + " pd.DataFrame\n", + " A pandas DataFrame with the dG results for each ligand.\n", + " \"\"\"\n", + " data = []\n", + " # check the type of error which should be used based on the number of repeats\n", + " repeats = {len(v[\"overall\"]) for v in results_dict.values()}\n", + " error_func = _error_mbar if 1 in repeats else _error_std\n", + " for lig, results in sorted(results_dict.items()):\n", + " dg = np.mean([v[0].m for v in results[\"overall\"]])\n", + " error = error_func(results)\n", + " m, u = format_estimate_uncertainty(dg, error, unc_prec=2)\n", + " data.append((lig, m, u))\n", + "\n", + " df = pd.DataFrame(\n", + " data,\n", + " columns=[\n", + " \"ligand\",\n", + " \"DG (kcal/mol)\",\n", + " \"uncertainty (kcal/mol)\",\n", + " ],\n", + " )\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b077c4e2-373a-4314-a527-186bb4683262", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_dg_raw(results_dict: dict[str, dict[str, list]]) -> pd.DataFrame:\n", + " \"\"\"\n", + " Get all the transformation cycle legs found and their DG values.\n", + "\n", + " Parameters\n", + " ----------\n", + " results_dict : dict[str, dict[str, list]]\n", + " Dictionary of results created by ``extract_results_dict``.\n", + "\n", + " Returns\n", + " -------\n", + " pd.DataFrame\n", + " A pandas DataFrame with the individual cycle leg dG results.\n", + " \"\"\"\n", + " data = []\n", + " for lig, results in sorted(results_dict.items()):\n", + " for simtype, repeats in sorted(results.items()):\n", + " if simtype != \"overall\":\n", + " for repeat in repeats:\n", + " m, u = format_estimate_uncertainty(\n", + " repeat[0].m, repeat[1].m, unc_prec=2\n", + " )\n", + " data.append((simtype, lig, m, u))\n", + "\n", + " df = pd.DataFrame(\n", + " data,\n", + " columns=[\n", + " \"leg\",\n", + " \"ligand\",\n", + " \"DG (kcal/mol)\",\n", + " \"uncertainty (kcal/mol)\",\n", + " ],\n", + " )\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "id": "98a84e37-91c3-4f78-8a68-a54f8835f7d2", + "metadata": {}, + "source": [ + "## Analyzing your results\n", + "Now that we have defined a set of methods to help us extract results, let's analyze the results!\n", + "\n", + "### Specify result directories and gather results\n", + "Let's start by gathering all our simulation results. First we define all the directories where our ABFE results exist. Here we assume that our simulation repeats sit in three different results directories under abfe_results, named from results_0 to results_2." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7bc49c0e-6fec-409c-a01c-42c35f57dcc6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ialibay/software/mambaforge/install/envs/openfe/lib/python3.12/site-packages/openmoltools/utils.py:9: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import resource_filename\n", + "/home/ialibay/software/mambaforge/install/envs/openfe/lib/python3.12/site-packages/Bio/Application/__init__.py:39: BiopythonDeprecationWarning: The Bio.Application modules and modules relying on it have been deprecated.\n", + "\n", + "Due to the on going maintenance burden of keeping command line application\n", + "wrappers up to date, we have decided to deprecate and eventually remove these\n", + "modules.\n", + "\n", + "We instead now recommend building your command line and invoking it directly\n", + "with the subprocess module.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "# Specify paths to result directories\n", + "results_dir = [\n", + " pathlib.Path(\"abfe_results/results_0\"),\n", + " pathlib.Path(\"abfe_results/results_1\"),\n", + " pathlib.Path(\"abfe_results/results_2\"),\n", + "]\n", + "dgs = extract_results_dict(results_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "d6e47322-bd5b-4b3b-b601-c406c03f8284", + "metadata": {}, + "source": [ + "### Obtain the absolute binding free energy for all nodes in the network\n", + "With these extracted results, we can now get the dG prediction between each ligand.\n", + "\n", + "Note: if only a single repeat was run, the MBAR error is used as uncertainty estimate, while the standard deviation is used when results from more than one repeat are provided." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "46996a74-709c-41f2-ac39-0f77fb33371e", + "metadata": {}, + "outputs": [], + "source": [ + "df_dg = generate_dg(dgs)\n", + "df_dg.to_csv('dg.tsv', sep=\"\\t\", lineterminator=\"\\n\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d1a6ad61-1e5a-4d8a-9067-9ed428ef145c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
| \n", + " | ligand | \n", + "DG (kcal/mol) | \n", + "uncertainty (kcal/mol) | \n", + "
|---|---|---|---|
| 0 | \n", + "1 | \n", + "-18.36 | \n", + "0.98 | \n", + "
| \n", + " | leg | \n", + "ligand | \n", + "DG (kcal/mol) | \n", + "uncertainty (kcal/mol) | \n", + "
|---|---|---|---|---|
| 0 | \n", + "complex | \n", + "1 | \n", + "36.87 | \n", + "0.36 | \n", + "
| 1 | \n", + "complex | \n", + "1 | \n", + "39.24 | \n", + "0.44 | \n", + "
| 2 | \n", + "complex | \n", + "1 | \n", + "37.46 | \n", + "0.48 | \n", + "
| 3 | \n", + "solvent | \n", + "1 | \n", + "10.57 | \n", + "0.66 | \n", + "
| 4 | \n", + "solvent | \n", + "1 | \n", + "10.48 | \n", + "0.65 | \n", + "
| 5 | \n", + "solvent | \n", + "1 | \n", + "10.41 | \n", + "0.66 | \n", + "
| 6 | \n", + "standard_state_correction | \n", + "1 | \n", + "-8.9 | \n", + "0.0 | \n", + "
| 7 | \n", + "standard_state_correction | \n", + "1 | \n", + "-9.1 | \n", + "0.0 | \n", + "
| 8 | \n", + "standard_state_correction | \n", + "1 | \n", + "-9.0 | \n", + "0.0 | \n", + "