diff --git a/openfe/protocols/restraint_utils/geometry/boresch/geometry.py b/openfe/protocols/restraint_utils/geometry/boresch/geometry.py index 7613e4b06..c5194cfa8 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch/geometry.py +++ b/openfe/protocols/restraint_utils/geometry/boresch/geometry.py @@ -251,12 +251,12 @@ def find_boresch_restraint( host_pool = find_host_atom_candidates( universe=universe, host_idxs=host_idxs, - l1_idx=guest_anchor[0], + guest_anchor_idx=guest_anchor[0], host_selection=host_selection, dssp_filter=dssp_filter, rmsf_cutoff=rmsf_cutoff, - min_distance=host_min_distance, - max_distance=host_max_distance, + min_search_distance=host_min_distance, + max_search_distance=host_max_distance, ) host_anchor = find_host_anchor( diff --git a/openfe/protocols/restraint_utils/geometry/boresch/host.py b/openfe/protocols/restraint_utils/geometry/boresch/host.py index 52e9cdf28..1a7cc4e13 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch/host.py +++ b/openfe/protocols/restraint_utils/geometry/boresch/host.py @@ -29,15 +29,64 @@ from openff.units import Quantity, unit +def _host_atoms_search( + atomgroup: mda.AtomGroup, + guest_anchor_idx: int, + rmsf_cutoff: Quantity, + min_search_distance: Quantity, + max_search_distance: Quantity, +) -> npt.NDArray: + """ + Helper method to get a set of host atoms with minimal RMSF + within a given distance of a guest anchor. + + Parameters + ---------- + atomgroup : mda.AtomGroup + An AtomGroup to find host atoms in. + guest_anchor_idx : int + The index of the proposed guest anchor binding atom. + rmsf_cutoff : Quantity + The maximum allowed RMSF value for any candidate host atom. + min_search_distance : Quantity + The minimum host atom search distance around the guest anchor. + max_search_distance : Quantity + The maximum host atom search distance around the guest anchor. + + Return + ------ + NDArray + Array of host atom indexes + """ + # 0 Deal with the empty case + if len(atomgroup) == 0: + return np.array([], dtype=int) + + # 1 Get the RMSF & filter to create a new AtomGroup + rmsf = get_local_rmsf(atomgroup) + filtered_atomgroup = atomgroup.atoms[rmsf < rmsf_cutoff] + + # 2. Search for atoms within the min/max cutoff of the guest anchor + atom_finder = FindHostAtoms( + host_atoms=filtered_atomgroup, + guest_atoms=atomgroup.universe.atoms[guest_anchor_idx], + min_search_distance=min_search_distance, + max_search_distance=max_search_distance, + ) + atom_finder.run() + + return atom_finder.results.host_idxs + + def find_host_atom_candidates( universe: mda.Universe, host_idxs: list[int], - l1_idx: int, + guest_anchor_idx: int, host_selection: str, dssp_filter: bool = False, rmsf_cutoff: Quantity = 0.1 * unit.nanometer, - min_distance: Quantity = 1 * unit.nanometer, - max_distance: Quantity = 3 * unit.nanometer, + min_search_distance: Quantity = 0.5 * unit.nanometer, + max_search_distance: Quantity = 1.5 * unit.nanometer, ) -> npt.NDArray: """ Get a list of suitable host atoms. @@ -48,17 +97,17 @@ def find_host_atom_candidates( An MDAnalysis Universe defining the system and its coordinates. host_idxs : list[int] A list of the host indices in the system topology. - l1_idx : int + guest_anchor_idx : int The index of the proposed l1 binding atom. host_selection : str An MDAnalysis selection string to filter the host by. dssp_filter : bool Whether or not to apply a DSSP filter on the host selection. rmsf_cutoff : openff.units.Quantity - The maximum RMSF value allowwed for any candidate host atom. - min_distance : openff.units.Quantity + The maximum RMSF value allowed for any candidate host atom. + min_search_distance : openff.units.Quantity The minimum search distance around l1 for suitable candidate atoms. - max_distance : openff.units.Quantity + max_search_distance : openff.units.Quantity The maximum search distance around l1 for suitable candidate atoms. Return @@ -80,55 +129,65 @@ def find_host_atom_candidates( ) raise ValueError(errmsg) + # None filtered_host_ixs for condition checking later + filtered_host_idxs = None + # If requested, filter the host atoms based on if their residues exist # within stable secondary structures. if dssp_filter: - # TODO: allow user-supplied kwargs here - stable_ag = stable_secondary_structure_selection(selected_host_ag) + # TODO: allow more user-supplied kwargs here + filtered_host_idxs = _host_atoms_search( + atomgroup=stable_secondary_structure_selection(selected_host_ag), + guest_anchor_idx=guest_anchor_idx, + rmsf_cutoff=rmsf_cutoff, + min_search_distance=min_search_distance, + max_search_distance=max_search_distance, + ) - if len(stable_ag) < 20: + if len(filtered_host_idxs) < 20: wmsg = ( - "Secondary structure filtering: " - "Too few atoms found via secondary structure filtering will " - "try to only select all residues in protein chains instead." + "Restraint generation: DSSP filter found too few host atoms " + f"({len(filtered_host_idxs)} found). Will attempt to use all protein chains." ) warnings.warn(wmsg) - stable_ag = protein_chain_selection(selected_host_ag) + filtered_host_idxs = _host_atoms_search( + atomgroup=protein_chain_selection(selected_host_ag), + guest_anchor_idx=guest_anchor_idx, + rmsf_cutoff=rmsf_cutoff, + min_search_distance=min_search_distance, + max_search_distance=max_search_distance, + ) - if len(stable_ag) < 20: + if len(filtered_host_idxs) < 20: wmsg = ( - "Secondary structure filtering: " - "Too few atoms found in protein residue chains, will just " - "use all atoms." + "Restraint generation: protein chain filter found too few " + f"host atoms ({len(filtered_host_idxs)} found). Will attempt to use all host atoms in " + f"selection: {host_selection}." ) warnings.warn(wmsg) - else: - selected_host_ag = stable_ag - - # 1. Get the RMSF & filter to create a new AtomGroup - rmsf = get_local_rmsf(selected_host_ag) - filtered_host_ag = selected_host_ag.atoms[rmsf < rmsf_cutoff] - - # 2. Search of atoms within the min/max cutoff - atom_finder = FindHostAtoms( - host_atoms=filtered_host_ag, - guest_atoms=universe.atoms[l1_idx], - min_search_distance=min_distance, - max_search_distance=max_distance, - ) - atom_finder.run() + filtered_host_idxs = None + + if filtered_host_idxs is None: + filtered_host_idxs = _host_atoms_search( + atomgroup=selected_host_ag, + guest_anchor_idx=guest_anchor_idx, + rmsf_cutoff=rmsf_cutoff, + min_search_distance=min_search_distance, + max_search_distance=max_search_distance, + ) - if not atom_finder.results.host_idxs.any(): + # Crash out if no atoms were found + if len(filtered_host_idxs) == 0: errmsg = ( f"No host atoms found within the search distance " - f"{min_distance}-{max_distance} consider widening the search window." + f"{min_search_distance}-{max_search_distance}. Consider widening the search window." ) raise ValueError(errmsg) - # Now we sort them! + # Now we sort them by their distance from the guest anchor atom_sorter = CentroidDistanceSort( - sortable_atoms=universe.atoms[atom_finder.results.host_idxs], - reference_atoms=universe.atoms[l1_idx], + sortable_atoms=universe.atoms[filtered_host_idxs], + reference_atoms=universe.atoms[guest_anchor_idx], ) atom_sorter.run() diff --git a/openfe/tests/protocols/restraints/test_geometry_boresch_host.py b/openfe/tests/protocols/restraints/test_geometry_boresch_host.py index 00e14100f..249710fbc 100644 --- a/openfe/tests/protocols/restraints/test_geometry_boresch_host.py +++ b/openfe/tests/protocols/restraints/test_geometry_boresch_host.py @@ -3,6 +3,7 @@ import MDAnalysis as mda import numpy as np +from numpy.testing import assert_equal import pytest from openfe.protocols.restraint_utils.geometry.boresch.host import ( EvaluateHostAtoms1, @@ -23,24 +24,46 @@ def eg5_protein_ligand_universe(eg5_protein_pdb, eg5_ligands): def test_host_atom_candidates_dssp(eg5_protein_ligand_universe): + """ + Run DSSP search normally + """ + host_atoms = eg5_protein_ligand_universe.select_atoms("protein") + + idxs = find_host_atom_candidates( + universe=eg5_protein_ligand_universe, + host_idxs=host_atoms.ix, + # hand picked + guest_anchor_idx=5508, + host_selection="backbone and resnum 212:221", + dssp_filter=True, + ) + expected = np.array( + [3144, 3146, 3145, 3143, 3162, 3206, 3200, 3207, 3126, 3201, 3127, + 3163, 3199, 3202, 3164, 3125, 3165, 3177, 3208, 3179, 3124, 3216, + 3209, 3109, 3107, 3178, 3110, 3180, 3108, 3248, 3217, 3249, 3226, + 3218, 3228, 3227, 3250, 3219, 3251, 3229] + ) + assert_equal(idxs, expected) + + +def test_host_atom_candidates_dssp_too_few_atoms(eg5_protein_ligand_universe): """ Make sure both dssp warnings are triggered """ host_atoms = eg5_protein_ligand_universe.select_atoms("protein") with ( - pytest.warns( - match="Too few atoms found via secondary structure filtering will" - ), - pytest.warns(match="Too few atoms found in protein residue chains,"), + pytest.warns(match="DSSP filter found"), + pytest.warns(match="protein chain filter found"), ): _ = find_host_atom_candidates( universe=eg5_protein_ligand_universe, - host_idxs=[a.ix for a in host_atoms], + host_idxs=host_atoms.ix, # hand picked - l1_idx=5508, + guest_anchor_idx=5508, host_selection="backbone and resnum 15:25", dssp_filter=True, + max_search_distance=2*unit.nanometer ) @@ -52,12 +75,12 @@ def test_host_atom_candidate_small_search(eg5_protein_ligand_universe): ): _ = find_host_atom_candidates( universe=eg5_protein_ligand_universe, - host_idxs=[a.ix for a in host_atoms], + host_idxs=host_atoms.ix, # hand picked - l1_idx=5508, + guest_anchor_idx=5508, host_selection="backbone", dssp_filter=False, - max_distance=0.1 * unit.angstrom, + max_search_distance=0.1 * unit.angstrom, )