diff --git a/package/CHANGELOG b/package/CHANGELOG index e59c09023c6..d0e28171e36 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -31,6 +31,9 @@ Changes * Maximum pinned versions in setup.py removed for python 3.6+ (PR #3139) Deprecations + * Deprecated current Group.bonds, angles, dihedrals, impropers + behavior where it returns connections to atoms outside the group + by default (Issue #1264, #2821, PR #3159) * ParmEdConverter no longer accepts Timestep objects at all (Issue #3031, PR #3172) * NCDFWriter `scale_factor` writing will change in version 2.0 to @@ -43,11 +46,14 @@ Deprecations replaced with hydrogenbonds.WaterBridgeAnalysis (#3111) * TPRParser indexing resids from 0 by default is deprecated. From 2.0 TPRParser will index resids from 1 by default. + Enhancements + * Added `get_connections` method to get bonds, angles, dihedrals, etc. + with or without atoms outside the group (Issues #1264, #2821, PR #3160) * Added `tpr_resid_from_one` keyword to select whether TPRParser indexes resids from 0 or 1 (Issue #2364, PR #3153) - + 01/17/21 richardjgowers, IAlibay, orbeckst, tylerjereddy, jbarnoud, yuxuanzhuang, lilyminium, VOD555, p-j-smith, bieniekmateusz, diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index d47cdb35dd6..9fe6de650b5 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -388,6 +388,41 @@ def __getattr__(self, attr): err += 'Did you mean {match}?'.format(match=match) raise AttributeError(err) + def get_connections(self, typename, outside=True): + """ + Get bonded connections between atoms as a + :class:`~MDAnalysis.core.topologyobjects.TopologyGroup`. + + Parameters + ---------- + typename : str + group name. One of {"bonds", "angles", "dihedrals", + "impropers", "ureybradleys", "cmaps"} + outside : bool (optional) + Whether to include connections involving atoms outside + this group. + + Returns + ------- + TopologyGroup + containing the bonded group of choice, i.e. bonds, angles, + dihedrals, impropers, ureybradleys or cmaps. + + .. versionadded:: 1.1.0 + """ + # AtomGroup has handy error messages for missing attributes + ugroup = getattr(self.universe.atoms, typename) + if not len(ugroup): + return ugroup + func = np.any if outside else np.all + try: + indices = self.atoms.ix_array + except AttributeError: # if self is an Atom + indices = self.ix_array + seen = [np.in1d(col, indices) for col in ugroup._bix.T] + mask = func(seen, axis=0) + return ugroup[mask] + class _ImmutableBase(object): """Class used to shortcut :meth:`__new__` to :meth:`object.__new__`. diff --git a/package/MDAnalysis/core/topologyattrs.py b/package/MDAnalysis/core/topologyattrs.py index fa9530bf1ed..8d27b23fac8 100644 --- a/package/MDAnalysis/core/topologyattrs.py +++ b/package/MDAnalysis/core/topologyattrs.py @@ -2250,15 +2250,47 @@ def _bondDict(self): def set_atoms(self, ag): return NotImplementedError("Cannot set bond information") - def get_atoms(self, ag): + def get_atoms(self, ag, outside=True): + """ + Get subset for atoms. + + Parameters + ---------- + ag : AtomGroup + outside : bool (optional) + Whether to include connections to atoms outside the given + AtomGroup. + + .. versionchanged:: 1.1.0 + Added the ``outside`` keyword. Set to ``True`` by default + to give the same behavior as previously + """ + warn = True try: unique_bonds = set(itertools.chain( *[self._bondDict[a] for a in ag.ix])) except TypeError: # maybe we got passed an Atom unique_bonds = self._bondDict[ag.ix] - bond_idx, types, guessed, order = np.hsplit( - np.array(sorted(unique_bonds), dtype=object), 4) + warn = False + unique_bonds = np.array(sorted(unique_bonds), dtype=object) + if not outside: + indices = np.array([list(bd[0]) for bd in unique_bonds]) + try: + mask = np.all(np.isin(indices, ag.ix), axis=1) + except np.AxisError: + mask = [] + unique_bonds = unique_bonds[mask] + elif warn: + warnings.warn("This group contains all connections " + "where at least one atom in the " + "AtomGroup is involved. In MDAnalysis " + "2.0 this behavior will change so that " + "the group only contains connections " + "where all atoms are in the AtomGroup.", + DeprecationWarning) + + bond_idx, types, guessed, order = np.hsplit(unique_bonds, 4) bond_idx = np.array(bond_idx.ravel().tolist(), dtype=np.int32) types = types.ravel() guessed = guessed.ravel() diff --git a/testsuite/MDAnalysisTests/core/test_groups.py b/testsuite/MDAnalysisTests/core/test_groups.py index d0113187f10..20d070835cf 100644 --- a/testsuite/MDAnalysisTests/core/test_groups.py +++ b/testsuite/MDAnalysisTests/core/test_groups.py @@ -37,7 +37,7 @@ import MDAnalysis as mda from MDAnalysis.exceptions import NoDataError from MDAnalysisTests import make_Universe, no_deprecated_call -from MDAnalysisTests.datafiles import PSF, DCD +from MDAnalysisTests.datafiles import PSF, DCD, TPR from MDAnalysis.core import groups @@ -1466,3 +1466,147 @@ def test_decorator(self, compound, pbc, unwrap): self.dummy_funtion(compound=compound, pbc=pbc, unwrap=unwrap) else: assert_equal(self.dummy_funtion(compound=compound, pbc=pbc, unwrap=unwrap), 0) + + +@pytest.fixture() +def tpr(): + return mda.Universe(TPR) + + +class TestGetConnectionsAtoms(object): + """Test Atom and AtomGroup.get_connections""" + + @pytest.mark.parametrize("typename", + ["bonds", "angles", "dihedrals", "impropers"]) + def test_connection_from_atom_not_outside(self, tpr, typename): + cxns = tpr.atoms[1].get_connections(typename, outside=False) + assert len(cxns) == 0 + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 1), + ("angles", 3), + ("dihedrals", 4), + ]) + def test_connection_from_atom_outside(self, tpr, typename, n_atoms): + cxns = tpr.atoms[10].get_connections(typename, outside=True) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 9), + ("angles", 15), + ("dihedrals", 12), + ]) + def test_connection_from_atoms_not_outside(self, tpr, typename, + n_atoms): + ag = tpr.atoms[:10] + cxns = ag.get_connections(typename, outside=False) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert np.all(np.in1d(indices, ag.indices)) + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 13), + ("angles", 27), + ("dihedrals", 38), + ]) + def test_connection_from_atoms_outside(self, tpr, typename, n_atoms): + ag = tpr.atoms[:10] + cxns = ag.get_connections(typename, outside=True) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert not np.all(np.in1d(indices, ag.indices)) + + def test_invalid_connection_error(self, tpr): + with pytest.raises(AttributeError, match="does not contain"): + ag = tpr.atoms[:10] + ag.get_connections("ureybradleys") + + @pytest.mark.parametrize("outside", [True, False]) + def test_get_empty_group(self, tpr, outside): + imp = tpr.impropers + ag = tpr.atoms[:10] + cxns = ag.get_connections("impropers", outside=outside) + assert len(imp) == 0 + assert len(cxns) == 0 + + +class TestGetConnectionsResidues(object): + """Test Residue and ResidueGroup.get_connections""" + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 9), + ("angles", 14), + ("dihedrals", 9), + ("impropers", 0), + ]) + def test_connection_from_res_not_outside(self, tpr, typename, n_atoms): + cxns = tpr.residues[10].get_connections(typename, outside=False) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 11), + ("angles", 22), + ("dihedrals", 27), + ("impropers", 0), + ]) + def test_connection_from_res_outside(self, tpr, typename, n_atoms): + cxns = tpr.residues[10].get_connections(typename, outside=True) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 157), + ("angles", 290), + ("dihedrals", 351), + ]) + def test_connection_from_residues_not_outside(self, tpr, typename, + n_atoms): + ag = tpr.residues[:10] + cxns = ag.get_connections(typename, outside=False) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert np.all(np.in1d(indices, ag.atoms.indices)) + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 158), + ("angles", 294), + ("dihedrals", 360), + ]) + def test_connection_from_residues_outside(self, tpr, typename, n_atoms): + ag = tpr.residues[:10] + cxns = ag.get_connections(typename, outside=True) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert not np.all(np.in1d(indices, ag.atoms.indices)) + + def test_invalid_connection_error(self, tpr): + with pytest.raises(AttributeError, match="does not contain"): + ag = tpr.residues[:10] + ag.get_connections("ureybradleys") + + @pytest.mark.parametrize("outside", [True, False]) + def test_get_empty_group(self, tpr, outside): + imp = tpr.impropers + ag = tpr.residues[:10] + cxns = ag.get_connections("impropers", outside=outside) + assert len(imp) == 0 + assert len(cxns) == 0 + + +@pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 13), + ("angles", 27), + ("dihedrals", 38), +]) +def test_get_topologygroup_property_deprecated(tpr, typename, n_atoms): + werr = ("This group contains all connections " + "where at least one atom in the " + "AtomGroup is involved. In MDAnalysis " + "2.0 this behavior will change so that " + "the group only contains connections " + "where all atoms are in the AtomGroup.") + with pytest.warns(DeprecationWarning, match=werr): + ag = tpr.atoms[:10] + cxns = getattr(ag, typename) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert not np.all(np.in1d(indices, ag.indices))