diff --git a/package/CHANGELOG b/package/CHANGELOG index fa077b39b26..b2e72758409 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -25,6 +25,7 @@ Fixes Enhancements * Expanded selection wildcards to the start and middle of strings (Issue #2370) + * Added type checking and conversion to Connection TopologyAttrs (Issue #2373) 09/05/19 IAlibay, richardjgowers diff --git a/package/MDAnalysis/core/topologyattrs.py b/package/MDAnalysis/core/topologyattrs.py index 5366bca67c8..532e91eeb1e 100644 --- a/package/MDAnalysis/core/topologyattrs.py +++ b/package/MDAnalysis/core/topologyattrs.py @@ -1642,7 +1642,14 @@ def _get_named_segment(group, segid): class _Connection(AtomAttr): """Base class for connectivity between atoms""" def __init__(self, values, types=None, guessed=False, order=None): - self.values = list(values) + values = [tuple(x) for x in values] + if not all(len(x) == self._n_atoms + and all(isinstance(y, (int, np.integer)) for y in x) + for x in values): + raise ValueError(("{} must be an iterable of tuples with {}" + " atom indices").format(self.attrname, + self._n_atoms)) + self.values = values if types is None: types = [None] * len(values) self.types = types @@ -1743,6 +1750,7 @@ class Bonds(_Connection): # many bonds, so still asks for "bonds" in the plural singular = 'bonds' transplants = defaultdict(list) + _n_atoms = 2 def bonded_atoms(self): """An :class:`~MDAnalysis.core.groups.AtomGroup` of all @@ -1891,6 +1899,7 @@ class Angles(_Connection): attrname = 'angles' singular = 'angles' transplants = defaultdict(list) + _n_atoms = 3 class Dihedrals(_Connection): @@ -1898,6 +1907,7 @@ class Dihedrals(_Connection): attrname = 'dihedrals' singular = 'dihedrals' transplants = defaultdict(list) + _n_atoms = 4 class Impropers(_Connection): @@ -1905,3 +1915,4 @@ class Impropers(_Connection): attrname = 'impropers' singular = 'impropers' transplants = defaultdict(list) + _n_atoms = 4 diff --git a/testsuite/MDAnalysisTests/core/test_universe.py b/testsuite/MDAnalysisTests/core/test_universe.py index c3539f6eaed..bf68553d2f4 100644 --- a/testsuite/MDAnalysisTests/core/test_universe.py +++ b/testsuite/MDAnalysisTests/core/test_universe.py @@ -628,6 +628,42 @@ def test_add_charges(self, universe, toadd, attrname, default): assert hasattr(universe.atoms, attrname) assert getattr(universe.atoms, attrname)[0] == default + + @pytest.mark.parametrize( + 'attr,values', ( + ('bonds', [(1, 0), (1, 2)]), + ('bonds', [[1, 0], [1, 2]]), + ('bonds', set([(1, 0), (1, 2)])), + ('angles', [(1, 0, 2), (1, 2, 3), (2, 1, 4)]), + ('dihedrals', [[1, 2, 3, 1], (3, 1, 5, 2)]), + ('impropers', [[1, 2, 3, 1], (3, 1, 5, 2)]), + ) + ) + def test_add_connection(self, universe, attr, values): + universe.add_TopologyAttr(attr, values) + assert hasattr(universe, attr) + attrgroup = getattr(universe, attr) + assert len(attrgroup) == len(values) + for x in attrgroup: + ix = x.indices + assert ix[0] <= ix[-1] + + @pytest.mark.parametrize( + 'attr,values', ( + ('bonds', [(1, 0, 0), (1, 2)]), + ('bonds', [['x', 'y'], [1, 2]]), + ('bonds', 'rubbish'), + ('bonds', [[1.01, 2.0]]), + ('angles', [(1, 0), (1, 2)]), + ('angles', 'rubbish'), + ('dihedrals', [[1, 1, 1, 0.1]]), + ('impropers', [(1, 2, 3)]), + ) + ) + def add_connection_error(self, universe, attr, values): + with pytest.raises(ValueError): + universe.add_TopologyAttr(attr, values) + class TestAllCoordinatesKwarg(object):