Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions package/MDAnalysis/analysis/rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
from MDAnalysis.exceptions import SelectionError, NoDataError
from MDAnalysis.lib.log import ProgressMeter
from MDAnalysis.lib.util import asiterable, iterable, get_weights, deprecate
from MDAnalysis.core.groups import AtomGroup


logger = logging.getLogger('MDAnalysis.analysis.rmsd')
Expand Down Expand Up @@ -283,6 +284,8 @@ def process_selection(select):

if isinstance(select, string_types):
select = {'reference': str(select), 'mobile': str(select)}
elif isinstance(select, AtomGroup):
select = {'reference': select, 'mobile': select}
elif type(select) is tuple:
try:
select = {'mobile': select[0], 'reference': select[1]}
Expand All @@ -301,6 +304,10 @@ def process_selection(select):
raise TypeError("'select' must be either a string, 2-tuple, or dict")
select['mobile'] = asiterable(select['mobile'])
select['reference'] = asiterable(select['reference'])
if isinstance(select['mobile'], AtomGroup):
select['mobile'] = [select['mobile']]
if isinstance(select['reference'], AtomGroup):
select['reference'] = [select['reference']]
return select


Expand Down
18 changes: 15 additions & 3 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2850,15 +2850,27 @@ def select_atoms(self, sel, *othersel, **selgroups):
"You provided {} for group '{}'".format(
thing.__class__.__name__, group))

selections = tuple((selection.Parser.parse(s, selgroups, periodic=periodic)
for s in sel_strs))
new_sel_strs = []
selections = []

for s in sel_strs:
if isinstance(s, AtomGroup):
new_sel_strs.append(repr(s))
selections.append(s)
else:
new_sel_strs.append(s)
selections.append(selection.Parser.parse(s, selgroups, periodic=periodic))

if updating:
atomgrp = UpdatingAtomGroup(self, selections, sel_strs)
atomgrp = UpdatingAtomGroup(self, selections, new_sel_strs)
else:
# Apply the first selection and sum to it
atomgrp = sum([sel.apply(self) for sel in selections[1:]],
selections[0].apply(self))
return atomgrp

def apply(self, other):
return self.intersection(other)

def split(self, level):
"""Split :class:`AtomGroup` into a :class:`list` of
Expand Down
38 changes: 36 additions & 2 deletions testsuite/MDAnalysisTests/analysis/test_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ def test_rmsd_atomgroup_selections(self, universe):
R2 = MDAnalysis.analysis.rms.RMSD(universe.atoms.select_atoms("name CA"),
select="resid 1-30").run()
assert not np.allclose(R1.rmsd[:, 2], R2.rmsd[:, 2])

def test_rmsd_atomgroups(self, universe):
# compare string vs atomgroup
R1 = MDAnalysis.analysis.rms.RMSD(universe.atoms,
select="resid 1-30").run()
ag = universe.select_atoms('resid 1-30')
R2 = MDAnalysis.analysis.rms.RMSD(universe.atoms, select=ag).run()

assert_almost_equal(R1.rmsd, R2.rmsd, 4,
err_msg='error: rmsd profile should match ' +
'for selection strings and AtomGroups')

def test_rmsd_single_frame(self, universe):
RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA',
Expand All @@ -235,7 +246,7 @@ def test_mass_weighted_and_save(self, universe, outfile, correct_values):
assert_almost_equal(RMSD.rmsd, saved, 4,
err_msg="error: rmsd profile should match "
"saved test values")

def test_custom_weighted(self, universe, correct_values_mass):
RMSD = MDAnalysis.analysis.rms.RMSD(universe, weights="mass").run(step=49)

Expand Down Expand Up @@ -288,7 +299,7 @@ def test_rmsd_group_selections(self, universe, correct_values_group):
assert_almost_equal(RMSD.rmsd, correct_values_group, 4,
err_msg="error: rmsd profile should match"
"test values")

def test_rmsd_backbone_and_group_selection(self, universe,
correct_values_backbone_group):
RMSD = MDAnalysis.analysis.rms.RMSD(
Expand All @@ -300,6 +311,29 @@ def test_rmsd_backbone_and_group_selection(self, universe,
assert_almost_equal(
RMSD.rmsd, correct_values_backbone_group, 4,
err_msg="error: rmsd profile should match test values")

def test_rmsd_atomgroup_backbone_and_group_selection(self, universe,
correct_values_backbone_group):
RMSD = MDAnalysis.analysis.rms.RMSD(
universe,
reference=universe,
select=universe.select_atoms('backbone'),
groupselections=[universe.select_atoms('backbone and resid 1:10'),
universe.select_atoms('backbone and resid 10:20')
]).run(step=49)
assert_almost_equal(
RMSD.rmsd, correct_values_backbone_group, 4,
err_msg="error: rmsd profile should match test values")

def test_rmsd_select_atomgroup_tuple(self, universe, correct_values):
copy = universe.copy()
ag = universe.select_atoms('name CA')
ag2 = copy.select_atoms('name CA')
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
reference=copy,
select=(ag, ag2)).run(step=49)
assert_almost_equal(RMSD.rmsd, correct_values, 4,
err_msg="error: rmsd profile should match test values")

def test_ref_length_unequal_len(self, universe):
reference = MDAnalysis.Universe(PSF, DCD)
Expand Down
30 changes: 30 additions & 0 deletions testsuite/MDAnalysisTests/core/test_atomselections.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,36 @@ def test_global(self, universe):
"resname LYS and name NZ and around 4 backbone")
ag2 = ag.select_atoms("around 4 global backbone")
assert_equal(ag2.indices, ag1.indices)

def test_returns_equal_AtomGroup_copy(self, universe):
ag = universe.select_atoms('protein')
assert ag == universe.select_atoms(ag)
assert ag is not universe.select_atoms(ag)

def test_returns_equal_UpdatingAtomGroup_copy(self, universe):
ag = universe.select_atoms('resid 100')
uag = universe.select_atoms(ag, updating=True)
no_uag = universe.select_atoms(uag)
assert ag == uag
assert no_uag == uag
assert isinstance(uag, mda.core.groups.UpdatingAtomGroup)
assert isinstance(no_uag, mda.core.groups.AtomGroup)
assert not isinstance(no_uag, mda.core.groups.UpdatingAtomGroup)

def test_returns_atomgroup_intersection(self, universe):
g1 = universe.select_atoms('resid 1:100')
g2 = universe.select_atoms('name CA')
g3 = universe.select_atoms('name O')

ag1 = g1.select_atoms(g2)
ag2 = g2.select_atoms(g1)
ag3 = g1.select_atoms(g2, g3)

assert ag1 == ag2
assert ag1 == (g1 & g2)
assert ag3 == ((g1 & g2) + (g1 & g3))




class TestSelectionsAMBER(object):
Expand Down