diff --git a/package/MDAnalysis/analysis/rms.py b/package/MDAnalysis/analysis/rms.py index aabd8ae8866..71661b2ed0d 100644 --- a/package/MDAnalysis/analysis/rms.py +++ b/package/MDAnalysis/analysis/rms.py @@ -148,6 +148,7 @@ from MDAnalysis.exceptions import SelectionError, NoDataError from MDAnalysis.lib.log import ProgressMeter from MDAnalysis.lib.util import asiterable, iterable, get_weights, deprecate +from MDAnalysis.core.groups import AtomGroup logger = logging.getLogger('MDAnalysis.analysis.rmsd') @@ -283,6 +284,8 @@ def process_selection(select): if isinstance(select, string_types): select = {'reference': str(select), 'mobile': str(select)} + elif isinstance(select, AtomGroup): + select = {'reference': select, 'mobile': select} elif type(select) is tuple: try: select = {'mobile': select[0], 'reference': select[1]} @@ -301,6 +304,10 @@ def process_selection(select): raise TypeError("'select' must be either a string, 2-tuple, or dict") select['mobile'] = asiterable(select['mobile']) select['reference'] = asiterable(select['reference']) + if isinstance(select['mobile'], AtomGroup): + select['mobile'] = [select['mobile']] + if isinstance(select['reference'], AtomGroup): + select['reference'] = [select['reference']] return select diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index e8fadce077f..1cb30e2100c 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -2850,15 +2850,27 @@ def select_atoms(self, sel, *othersel, **selgroups): "You provided {} for group '{}'".format( thing.__class__.__name__, group)) - selections = tuple((selection.Parser.parse(s, selgroups, periodic=periodic) - for s in sel_strs)) + new_sel_strs = [] + selections = [] + + for s in sel_strs: + if isinstance(s, AtomGroup): + new_sel_strs.append(repr(s)) + selections.append(s) + else: + new_sel_strs.append(s) + selections.append(selection.Parser.parse(s, selgroups, periodic=periodic)) + if updating: - atomgrp = UpdatingAtomGroup(self, selections, sel_strs) + atomgrp = UpdatingAtomGroup(self, selections, new_sel_strs) else: # Apply the first selection and sum to it atomgrp = sum([sel.apply(self) for sel in selections[1:]], selections[0].apply(self)) return atomgrp + + def apply(self, other): + return self.intersection(other) def split(self, level): """Split :class:`AtomGroup` into a :class:`list` of diff --git a/testsuite/MDAnalysisTests/analysis/test_rms.py b/testsuite/MDAnalysisTests/analysis/test_rms.py index 8971f83ad34..a9f63018d4b 100644 --- a/testsuite/MDAnalysisTests/analysis/test_rms.py +++ b/testsuite/MDAnalysisTests/analysis/test_rms.py @@ -212,6 +212,17 @@ def test_rmsd_atomgroup_selections(self, universe): R2 = MDAnalysis.analysis.rms.RMSD(universe.atoms.select_atoms("name CA"), select="resid 1-30").run() assert not np.allclose(R1.rmsd[:, 2], R2.rmsd[:, 2]) + + def test_rmsd_atomgroups(self, universe): + # compare string vs atomgroup + R1 = MDAnalysis.analysis.rms.RMSD(universe.atoms, + select="resid 1-30").run() + ag = universe.select_atoms('resid 1-30') + R2 = MDAnalysis.analysis.rms.RMSD(universe.atoms, select=ag).run() + + assert_almost_equal(R1.rmsd, R2.rmsd, 4, + err_msg='error: rmsd profile should match ' + + 'for selection strings and AtomGroups') def test_rmsd_single_frame(self, universe): RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA', @@ -235,7 +246,7 @@ def test_mass_weighted_and_save(self, universe, outfile, correct_values): assert_almost_equal(RMSD.rmsd, saved, 4, err_msg="error: rmsd profile should match " "saved test values") - + def test_custom_weighted(self, universe, correct_values_mass): RMSD = MDAnalysis.analysis.rms.RMSD(universe, weights="mass").run(step=49) @@ -288,7 +299,7 @@ def test_rmsd_group_selections(self, universe, correct_values_group): assert_almost_equal(RMSD.rmsd, correct_values_group, 4, err_msg="error: rmsd profile should match" "test values") - + def test_rmsd_backbone_and_group_selection(self, universe, correct_values_backbone_group): RMSD = MDAnalysis.analysis.rms.RMSD( @@ -300,6 +311,29 @@ def test_rmsd_backbone_and_group_selection(self, universe, assert_almost_equal( RMSD.rmsd, correct_values_backbone_group, 4, err_msg="error: rmsd profile should match test values") + + def test_rmsd_atomgroup_backbone_and_group_selection(self, universe, + correct_values_backbone_group): + RMSD = MDAnalysis.analysis.rms.RMSD( + universe, + reference=universe, + select=universe.select_atoms('backbone'), + groupselections=[universe.select_atoms('backbone and resid 1:10'), + universe.select_atoms('backbone and resid 10:20') + ]).run(step=49) + assert_almost_equal( + RMSD.rmsd, correct_values_backbone_group, 4, + err_msg="error: rmsd profile should match test values") + + def test_rmsd_select_atomgroup_tuple(self, universe, correct_values): + copy = universe.copy() + ag = universe.select_atoms('name CA') + ag2 = copy.select_atoms('name CA') + RMSD = MDAnalysis.analysis.rms.RMSD(universe, + reference=copy, + select=(ag, ag2)).run(step=49) + assert_almost_equal(RMSD.rmsd, correct_values, 4, + err_msg="error: rmsd profile should match test values") def test_ref_length_unequal_len(self, universe): reference = MDAnalysis.Universe(PSF, DCD) diff --git a/testsuite/MDAnalysisTests/core/test_atomselections.py b/testsuite/MDAnalysisTests/core/test_atomselections.py index 7b7d6010400..1ce39235720 100644 --- a/testsuite/MDAnalysisTests/core/test_atomselections.py +++ b/testsuite/MDAnalysisTests/core/test_atomselections.py @@ -363,6 +363,36 @@ def test_global(self, universe): "resname LYS and name NZ and around 4 backbone") ag2 = ag.select_atoms("around 4 global backbone") assert_equal(ag2.indices, ag1.indices) + + def test_returns_equal_AtomGroup_copy(self, universe): + ag = universe.select_atoms('protein') + assert ag == universe.select_atoms(ag) + assert ag is not universe.select_atoms(ag) + + def test_returns_equal_UpdatingAtomGroup_copy(self, universe): + ag = universe.select_atoms('resid 100') + uag = universe.select_atoms(ag, updating=True) + no_uag = universe.select_atoms(uag) + assert ag == uag + assert no_uag == uag + assert isinstance(uag, mda.core.groups.UpdatingAtomGroup) + assert isinstance(no_uag, mda.core.groups.AtomGroup) + assert not isinstance(no_uag, mda.core.groups.UpdatingAtomGroup) + + def test_returns_atomgroup_intersection(self, universe): + g1 = universe.select_atoms('resid 1:100') + g2 = universe.select_atoms('name CA') + g3 = universe.select_atoms('name O') + + ag1 = g1.select_atoms(g2) + ag2 = g2.select_atoms(g1) + ag3 = g1.select_atoms(g2, g3) + + assert ag1 == ag2 + assert ag1 == (g1 & g2) + assert ag3 == ((g1 & g2) + (g1 & g3)) + + class TestSelectionsAMBER(object):