Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ Changes
* Maximum pinned versions in setup.py removed for python 3.6+ (PR #3139)

Deprecations
* Deprecated current Group.bonds, angles, dihedrals, impropers
behavior where it returns connections to atoms outside the group
by default (Issue #1264, #2821, PR #3159)
* ParmEdConverter no longer accepts Timestep objects at all
(Issue #3031, PR #3172)
* NCDFWriter `scale_factor` writing will change in version 2.0 to
Expand All @@ -43,11 +46,14 @@ Deprecations
replaced with hydrogenbonds.WaterBridgeAnalysis (#3111)
* TPRParser indexing resids from 0 by default is deprecated.
From 2.0 TPRParser will index resids from 1 by default.


Enhancements
* Added `get_connections` method to get bonds, angles, dihedrals, etc.
with or without atoms outside the group (Issues #1264, #2821, PR #3160)
* Added `tpr_resid_from_one` keyword to select whether TPRParser
indexes resids from 0 or 1 (Issue #2364, PR #3153)


01/17/21 richardjgowers, IAlibay, orbeckst, tylerjereddy, jbarnoud,
yuxuanzhuang, lilyminium, VOD555, p-j-smith, bieniekmateusz,
Expand Down
35 changes: 35 additions & 0 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,41 @@ def __getattr__(self, attr):
err += 'Did you mean {match}?'.format(match=match)
raise AttributeError(err)

def get_connections(self, typename, outside=True):
"""
Get bonded connections between atoms as a
:class:`~MDAnalysis.core.topologyobjects.TopologyGroup`.

Parameters
----------
typename : str
group name. One of {"bonds", "angles", "dihedrals",
"impropers", "ureybradleys", "cmaps"}
outside : bool (optional)
Whether to include connections involving atoms outside
this group.

Returns
-------
TopologyGroup
containing the bonded group of choice, i.e. bonds, angles,
dihedrals, impropers, ureybradleys or cmaps.

.. versionadded:: 1.1.0
"""
# AtomGroup has handy error messages for missing attributes
ugroup = getattr(self.universe.atoms, typename)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we're not using atomgroup_intersection here because this method can handle not just things in a TopologyGroup?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was more to avoid having to distinguish between Atoms and AtomGroups (etc) with ix_array, as atomgroup_intersection only accepts AtomGroups

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realised that Residue.bonds is not existing functionality, so I added tests.

if not len(ugroup):
return ugroup
func = np.any if outside else np.all
try:
indices = self.atoms.ix_array
except AttributeError: # if self is an Atom
indices = self.ix_array
seen = [np.in1d(col, indices) for col in ugroup._bix.T]
mask = func(seen, axis=0)
return ugroup[mask]


class _ImmutableBase(object):
"""Class used to shortcut :meth:`__new__` to :meth:`object.__new__`.
Expand Down
38 changes: 35 additions & 3 deletions package/MDAnalysis/core/topologyattrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2250,15 +2250,47 @@ def _bondDict(self):
def set_atoms(self, ag):
return NotImplementedError("Cannot set bond information")

def get_atoms(self, ag):
def get_atoms(self, ag, outside=True):
"""
Get subset for atoms.

Parameters
----------
ag : AtomGroup
outside : bool (optional)
Whether to include connections to atoms outside the given
AtomGroup.

.. versionchanged:: 1.1.0
Added the ``outside`` keyword. Set to ``True`` by default
to give the same behavior as previously
"""
warn = True
try:
unique_bonds = set(itertools.chain(
*[self._bondDict[a] for a in ag.ix]))
except TypeError:
# maybe we got passed an Atom
unique_bonds = self._bondDict[ag.ix]
bond_idx, types, guessed, order = np.hsplit(
np.array(sorted(unique_bonds), dtype=object), 4)
warn = False
unique_bonds = np.array(sorted(unique_bonds), dtype=object)
if not outside:
Copy link
Copy Markdown
Member

@cbouy cbouy Mar 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use a not here?

if outside: 
    <deprecation code>
else:
    <np.all>

It makes it more readable imo

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started with strict=False 😅

This will make it easier to add to 2.0, though, where we will not have the deprecation warning.

indices = np.array([list(bd[0]) for bd in unique_bonds])
try:
mask = np.all(np.isin(indices, ag.ix), axis=1)
except np.AxisError:
mask = []
unique_bonds = unique_bonds[mask]
elif warn:
warnings.warn("This group contains all connections "
"where at least one atom in the "
"AtomGroup is involved. In MDAnalysis "
"2.0 this behavior will change so that "
"the group only contains connections "
"where all atoms are in the AtomGroup.",
DeprecationWarning)

bond_idx, types, guessed, order = np.hsplit(unique_bonds, 4)
bond_idx = np.array(bond_idx.ravel().tolist(), dtype=np.int32)
types = types.ravel()
guessed = guessed.ravel()
Expand Down
146 changes: 145 additions & 1 deletion testsuite/MDAnalysisTests/core/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import MDAnalysis as mda
from MDAnalysis.exceptions import NoDataError
from MDAnalysisTests import make_Universe, no_deprecated_call
from MDAnalysisTests.datafiles import PSF, DCD
from MDAnalysisTests.datafiles import PSF, DCD, TPR
from MDAnalysis.core import groups


Expand Down Expand Up @@ -1466,3 +1466,147 @@ def test_decorator(self, compound, pbc, unwrap):
self.dummy_funtion(compound=compound, pbc=pbc, unwrap=unwrap)
else:
assert_equal(self.dummy_funtion(compound=compound, pbc=pbc, unwrap=unwrap), 0)


@pytest.fixture()
def tpr():
return mda.Universe(TPR)


class TestGetConnectionsAtoms(object):
"""Test Atom and AtomGroup.get_connections"""

@pytest.mark.parametrize("typename",
["bonds", "angles", "dihedrals", "impropers"])
def test_connection_from_atom_not_outside(self, tpr, typename):
cxns = tpr.atoms[1].get_connections(typename, outside=False)
assert len(cxns) == 0

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 1),
("angles", 3),
("dihedrals", 4),
])
def test_connection_from_atom_outside(self, tpr, typename, n_atoms):
cxns = tpr.atoms[10].get_connections(typename, outside=True)
assert len(cxns) == n_atoms

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 9),
("angles", 15),
("dihedrals", 12),
])
def test_connection_from_atoms_not_outside(self, tpr, typename,
n_atoms):
ag = tpr.atoms[:10]
cxns = ag.get_connections(typename, outside=False)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert np.all(np.in1d(indices, ag.indices))

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 13),
("angles", 27),
("dihedrals", 38),
])
def test_connection_from_atoms_outside(self, tpr, typename, n_atoms):
ag = tpr.atoms[:10]
cxns = ag.get_connections(typename, outside=True)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert not np.all(np.in1d(indices, ag.indices))

def test_invalid_connection_error(self, tpr):
with pytest.raises(AttributeError, match="does not contain"):
ag = tpr.atoms[:10]
ag.get_connections("ureybradleys")

@pytest.mark.parametrize("outside", [True, False])
def test_get_empty_group(self, tpr, outside):
imp = tpr.impropers
ag = tpr.atoms[:10]
cxns = ag.get_connections("impropers", outside=outside)
assert len(imp) == 0
assert len(cxns) == 0


class TestGetConnectionsResidues(object):
"""Test Residue and ResidueGroup.get_connections"""

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 9),
("angles", 14),
("dihedrals", 9),
("impropers", 0),
])
def test_connection_from_res_not_outside(self, tpr, typename, n_atoms):
cxns = tpr.residues[10].get_connections(typename, outside=False)
assert len(cxns) == n_atoms

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 11),
("angles", 22),
("dihedrals", 27),
("impropers", 0),
])
def test_connection_from_res_outside(self, tpr, typename, n_atoms):
cxns = tpr.residues[10].get_connections(typename, outside=True)
assert len(cxns) == n_atoms

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 157),
("angles", 290),
("dihedrals", 351),
])
def test_connection_from_residues_not_outside(self, tpr, typename,
n_atoms):
ag = tpr.residues[:10]
cxns = ag.get_connections(typename, outside=False)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert np.all(np.in1d(indices, ag.atoms.indices))

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 158),
("angles", 294),
("dihedrals", 360),
])
def test_connection_from_residues_outside(self, tpr, typename, n_atoms):
ag = tpr.residues[:10]
cxns = ag.get_connections(typename, outside=True)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert not np.all(np.in1d(indices, ag.atoms.indices))

def test_invalid_connection_error(self, tpr):
with pytest.raises(AttributeError, match="does not contain"):
ag = tpr.residues[:10]
ag.get_connections("ureybradleys")

@pytest.mark.parametrize("outside", [True, False])
def test_get_empty_group(self, tpr, outside):
imp = tpr.impropers
ag = tpr.residues[:10]
cxns = ag.get_connections("impropers", outside=outside)
assert len(imp) == 0
assert len(cxns) == 0


@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 13),
("angles", 27),
("dihedrals", 38),
])
def test_get_topologygroup_property_deprecated(tpr, typename, n_atoms):
werr = ("This group contains all connections "
"where at least one atom in the "
"AtomGroup is involved. In MDAnalysis "
"2.0 this behavior will change so that "
"the group only contains connections "
"where all atoms are in the AtomGroup.")
with pytest.warns(DeprecationWarning, match=werr):
ag = tpr.atoms[:10]
cxns = getattr(ag, typename)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert not np.all(np.in1d(indices, ag.indices))