Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ The rules for this file:

------------------------------------------------------------------------------
??/??/?? richardjgowers, IAlibay, hmacdope, orbeckst, cbouy, lilyminium,
daveminh, jbarnoud
daveminh, jbarnoud, yuxuanzhuang



* 2.0.0
Expand All @@ -25,6 +26,8 @@ Fixes
* TOPParser no longer guesses elements when missing atomic number records
(Issues #2449, #2651)
* Testsuite does not any more matplotlib.use('agg') (#2191)
* In ChainReader, read_frame does not trigger change of iterating position.
(Issue #2723, PR #2815)

Enhancements
* Added the RDKitParser which creates a `core.topology.Topology` object from
Expand Down
62 changes: 29 additions & 33 deletions package/MDAnalysis/coordinates/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@
.. automethod:: _get
.. automethod:: _get_same
.. automethod:: _read_frame
.. automethod:: _chained_iterator

"""
import warnings

import os.path
import itertools
import bisect
import copy

Expand Down Expand Up @@ -141,15 +139,15 @@ def filter_times(times, dt):

def check_allowed_filetypes(readers, allowed):
"""
Make a check that all readers have the same filetype and are of the
Make a check that all readers have the same filetype and are of the
allowed files types. Throws Exception on failure.

Parameters
----------
readers : list of MDA readers
allowed : list of allowed formats
"""
classname = type(readers[0])
classname = type(readers[0])
only_one_reader = np.all([isinstance(r, classname) for r in readers])
if not only_one_reader:
readernames = [type(r) for r in readers]
Expand All @@ -158,7 +156,7 @@ def check_allowed_filetypes(readers, allowed):
"Found: {}".format(readernames))
if readers[0].format not in allowed:
raise NotImplementedError("ChainReader: continuous=True only "
"supported for formats: {}".format(allowed))
"supported for formats: {}".format(allowed))


class ChainReader(base.ProtoReader):
Expand Down Expand Up @@ -263,7 +261,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
kwargs['dt'] = dt
self.readers = [core.reader(filename, **kwargs)
for filename in filenames]
self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn for fn in filenames])
self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn
for fn in filenames])
# pointer to "active" trajectory index into self.readers
self.__active_reader_index = 0

Expand All @@ -290,9 +289,6 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
self.dts = np.array(self._get('dt'))
self.total_times = self.dts * n_frames

#: source for trajectories frame (fakes trajectory)
self.__chained_trajectories_iter = None

# calculate new start_frames to have a time continuous trajectory.
if continuous:
check_allowed_filetypes(self.readers, ['XTC', 'TRR'])
Expand Down Expand Up @@ -346,7 +342,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
# check for interleaving
r1[1]
if r1_start_time < start_time < r1.time:
raise RuntimeError("ChainReader: Interleaving not supported with continuous=True.")
raise RuntimeError("ChainReader: Interleaving not supported "
"with continuous=True.")

# find end where trajectory was restarted from
for ts in r1[::-1]:
Expand Down Expand Up @@ -439,8 +436,6 @@ def copy(self):
new.ts = self.ts.copy()
return new



# attributes that can change with the current reader
@property
def filename(self):
Expand Down Expand Up @@ -561,49 +556,39 @@ def _read_frame(self, frame):
# update Timestep
self.ts = self.active_reader.ts
self.ts.frame = frame # continuous frames, 0-based
self.__current_frame = frame
return self.ts

def _chained_iterator(self):
"""Iterator that presents itself as a chained trajectory."""
self._rewind() # must rewind all readers
for i in range(self.n_frames):
j, f = self._get_local_frame(i)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = i
yield self.ts

def _read_next_timestep(self, ts=None):
self.ts = next(self.__chained_trajectories_iter)
return self.ts
if ts is None:
ts = self.ts
ts = self.__next__()
return ts

def rewind(self):
"""Set current frame to the beginning."""
self._rewind()
self.__chained_trajectories_iter = self._chained_iterator()
# set time step for frame 1
self.ts = next(self.__chained_trajectories_iter)

def _rewind(self):
"""Internal method: Rewind trajectories themselves and trj pointer."""
self.__current_frame = -1
self._apply('rewind')
self.__activate_reader(0)
self.__next__()

def close(self):
self._apply('close')

def __iter__(self):
"""Generator for all frames, starting at frame 1."""
self._rewind()
"""Generator for all frames, starting at frame 0."""
self.__current_frame = -1
# start from first frame
self.__chained_trajectories_iter = self._chained_iterator()
for ts in self.__chained_trajectories_iter:
yield ts
return self

def __repr__(self):
if len(self.filenames) > 3:
fnames = "{fname} and {nfanmes} more".format(
fname=os.path.basename(self.filenames[0]),
fname=os.path.basename(self.filenames[0]),
nfanmes=len(self.filenames) - 1)
else:
fnames = ", ".join([os.path.basename(fn) for fn in self.filenames])
Expand Down Expand Up @@ -656,3 +641,14 @@ def _apply_transformations(self, ts):
# to avoid applying the same transformations multiple times on each frame

return ts

def __next__(self):
if self.__current_frame < self.n_frames - 1:
j, f = self._get_local_frame(self.__current_frame + 1)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = self.__current_frame + 1
self.__current_frame += 1
return self.ts
else:
raise StopIteration()
5 changes: 5 additions & 0 deletions testsuite/MDAnalysisTests/coordinates/test_chainreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def test_frame_numbering(self, universe):
universe.trajectory[98] # index is 0-based and frames are 0-based
assert_equal(universe.trajectory.frame, 98, "wrong frame number")

def test_next_after_frame_numbering(self, universe):
universe.trajectory[98] # index is 0-based and frames are 0-based
universe.trajectory.next()
assert_equal(universe.trajectory.frame, 99, "wrong frame number")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this failed previously? Wow.


def test_frame(self, universe):
universe.trajectory[0]
coord0 = universe.atoms.positions.copy()
Expand Down