From a00840e9f033bc614c5b3830243c5ee34ac34ef1 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Wed, 14 Nov 2018 18:16:06 -0600 Subject: [PATCH 1/8] added experimental pickling support to Universe --- package/MDAnalysis/core/universe.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/package/MDAnalysis/core/universe.py b/package/MDAnalysis/core/universe.py index f869ffad945..b8ec883676d 100644 --- a/package/MDAnalysis/core/universe.py +++ b/package/MDAnalysis/core/universe.py @@ -768,11 +768,15 @@ def __repr__(self): return "".format( n_atoms=len(self.atoms)) - def __getstate__(self): - raise NotImplementedError + @classmethod + def _unpickle_U(cls, top, traj, anchor): + u = cls(top, anchor_name=anchor) + u.load_new(traj) + + return u - def __setstate__(self, state): - raise NotImplementedError + def __reduce__(self): + return (self._unpickle_U, (self._topology, self.trajectory.filename, self.anchor_name)) # Properties @property From 133072f90c585f4be47e1a64141457bd1b4dd5b1 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Sat, 17 Nov 2018 16:10:08 -0600 Subject: [PATCH 2/8] fixed pickle test --- testsuite/MDAnalysisTests/core/test_universe.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/testsuite/MDAnalysisTests/core/test_universe.py b/testsuite/MDAnalysisTests/core/test_universe.py index 20d5f2a3293..05663c5af6b 100644 --- a/testsuite/MDAnalysisTests/core/test_universe.py +++ b/testsuite/MDAnalysisTests/core/test_universe.py @@ -283,10 +283,14 @@ def test_load_multiple_args(self): assert_equal(len(u.atoms), 3341, "Loading universe failed somehow") assert_equal(u.trajectory.n_frames, 2 * ref.trajectory.n_frames) - def test_pickle_raises_NotImplementedError(self): + def test_pickle(self): u = mda.Universe(PSF, DCD) - with pytest.raises(NotImplementedError): - cPickle.dumps(u, protocol = cPickle.HIGHEST_PROTOCOL) + + s = cPickle.dumps(u, protocol = cPickle.HIGHEST_PROTOCOL) + + new_u = 
cPickle.loads(s) + + assert_equal(u.atoms.names, new_u.atoms.names) def test_set_dimensions(self): u = mda.Universe(PSF, DCD) From aa595f58500cf1a30afba9b6de0b658a63150eb0 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Wed, 21 Nov 2018 13:50:53 -0600 Subject: [PATCH 3/8] wip of reader pickling --- package/MDAnalysis/coordinates/base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index d29b741fd75..87fe76be39d 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -2078,6 +2078,20 @@ def __init__(self, filename, convert_units=None, **kwargs): self._ts_kwargs = ts_kwargs + @classmethod + def _unpickle_Reader(cls, filename, timestep, auxs, trans): + new_R = cls(filename) + new_R.add_transformations(trans) + for auxname, auxdata in auxs.items(): + new_R.add_aux(auxname, auxdata) + + def __reduce__(self): + return (self._unpickle_Reader, + (self.filename, self.ts, self._auxs, self._transformations)) + + def __len__(self): + return self.n_frames + def copy(self): """Return independent copy of this Reader. 
From ba1a13876001e4472914b463432fa3e313bfff24 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Sat, 16 Feb 2019 10:51:59 -0600 Subject: [PATCH 4/8] simplified serialisation support more --- package/MDAnalysis/coordinates/base.py | 11 ----------- package/MDAnalysis/core/universe.py | 15 +++++++++++---- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index 87fe76be39d..0a08e7b35d5 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -2078,17 +2078,6 @@ def __init__(self, filename, convert_units=None, **kwargs): self._ts_kwargs = ts_kwargs - @classmethod - def _unpickle_Reader(cls, filename, timestep, auxs, trans): - new_R = cls(filename) - new_R.add_transformations(trans) - for auxname, auxdata in auxs.items(): - new_R.add_aux(auxname, auxdata) - - def __reduce__(self): - return (self._unpickle_Reader, - (self.filename, self.ts, self._auxs, self._transformations)) - def __len__(self): return self.n_frames diff --git a/package/MDAnalysis/core/universe.py b/package/MDAnalysis/core/universe.py index b8ec883676d..27e2b7946a1 100644 --- a/package/MDAnalysis/core/universe.py +++ b/package/MDAnalysis/core/universe.py @@ -738,7 +738,7 @@ def _gen_anchor_hash(self): return self._anchor_uuid except AttributeError: # store this so we can later recall it if needed - self._anchor_uuid = uuid.uuid4() + self._anchor_uuid = str(uuid.uuid4()) return self._anchor_uuid @property @@ -770,13 +770,20 @@ def __repr__(self): @classmethod def _unpickle_U(cls, top, traj, anchor): - u = cls(top, anchor_name=anchor) - u.load_new(traj) + """Special method used by __reduce__ to deserialise a Universe""" + # top is a Topology object at this point, but Universe can handle that + u = cls(top) + u.anchor_name = anchor + # maybe this is None, but that's still cool + u.trajectory = traj return u def __reduce__(self): - return (self._unpickle_U, 
(self._topology, self.trajectory.filename, self.anchor_name)) + # Can't quite use __setstate__/__getstate__ so go via __reduce__ + # Universe's two "legs" of topology and traj both serialise themselves + # the only other state held in Universe is anchor name? + return (self._unpickle_U, (self._topology, self._trajectory, self.anchor_name)) # Properties @property From caa588d1394df57d1d3f662be02dc66ae3b70ac1 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Sat, 16 Feb 2019 11:14:30 -0600 Subject: [PATCH 5/8] make AuxReader not serialise (until tested) --- package/MDAnalysis/auxiliary/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/package/MDAnalysis/auxiliary/base.py b/package/MDAnalysis/auxiliary/base.py index 50c3c3bf7c9..e769cfde883 100644 --- a/package/MDAnalysis/auxiliary/base.py +++ b/package/MDAnalysis/auxiliary/base.py @@ -307,6 +307,10 @@ def __init__(self, represent_ts_as='closest', auxname=None, cutoff=-1, self.auxstep._dt = self.time - self.initial_time self.rewind() + def __getstate__(self): + # probably works fine, but someone needs to write tests to confirm + return NotImplementedError + def copy(self): raise NotImplementedError("Copy not implemented for AuxReader") From e21361d509457ade46cda96c04083ae3c16101a4 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Sat, 16 Feb 2019 11:15:07 -0600 Subject: [PATCH 6/8] add tests for multiprocessing --- .../MDAnalysisTests/parallelism/__init__.py | 0 .../parallelism/test_multiprocessing.py | 38 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 testsuite/MDAnalysisTests/parallelism/__init__.py create mode 100644 testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py diff --git a/testsuite/MDAnalysisTests/parallelism/__init__.py b/testsuite/MDAnalysisTests/parallelism/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py
b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py new file mode 100644 index 00000000000..dfcf57ab3f4 --- /dev/null +++ b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py @@ -0,0 +1,38 @@ +"""Test that MDAnalysis plays nicely with multiprocessing + +""" +import multiprocessing +import numpy as np +import pytest + +import MDAnalysis as mda +from MDAnalysisTests.datafiles import ( + PSF, DCD +) + +from numpy.testing import assert_equal + + +@pytest.fixture +def u(): + return mda.Universe(PSF, DCD) + + +def cog(u, ag, frame_id): + u.trajectory[frame_id] + + return ag.center_of_geometry() + + +def test_multiprocess_COM(u): + ag = u.atoms[10:20] + + ref = np.array([cog(u, ag, i) + for i in range(4)]) + + p = multiprocessing.Pool(2) + + res = np.array([p.apply(cog, args=(u, ag, i)) + for i in range(4)]) + + assert_equal(ref, res) From 99bdd08b4cd8ffa329cbe1c75969f3f03f1b94bf Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Sat, 16 Feb 2019 11:24:45 -0600 Subject: [PATCH 7/8] added more formats to multiprocessing tests and broke everything --- .../parallelism/test_multiprocessing.py | 44 ++++++++++++++++--- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py index dfcf57ab3f4..7705e8a7823 100644 --- a/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py +++ b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py @@ -7,32 +7,62 @@ import MDAnalysis as mda from MDAnalysisTests.datafiles import ( - PSF, DCD + PSF, DCD, + GRO, XTC, + PDB, + XYZ, ) from numpy.testing import assert_equal -@pytest.fixture -def u(): - return mda.Universe(PSF, DCD) - +@pytest.fixture(params=[ + (PSF, DCD), + (GRO, XTC), + (PDB,), + (XYZ,), +]) +def u(request): + if len(request.param) == 1: + f = request.param + return mda.Universe(f) + else: + top, trj = request.param + return mda.Universe(top, trj) +# Define target 
functions here +# inside test functions doesn't work def cog(u, ag, frame_id): u.trajectory[frame_id] return ag.center_of_geometry() -def test_multiprocess_COM(u): +def getnames(u, ix): + # Check topology stuff works + return u.atoms[ix].name + + +def test_multiprocess_COG(u): ag = u.atoms[10:20] ref = np.array([cog(u, ag, i) for i in range(4)]) p = multiprocessing.Pool(2) - res = np.array([p.apply(cog, args=(u, ag, i)) for i in range(4)]) + p.close() + assert_equal(ref, res) + + +def test_multiprocess_names(u): + ref = [getnames(u, i) + for i in range(10)] + + p = multiprocessing.Pool(2) + res = [p.apply(getnames, args=(u, i)) + for i in range(10)] + p.close() assert_equal(ref, res) From f1a7d5aec969d30adc31b719d637b1bcfe386024 Mon Sep 17 00:00:00 2001 From: richardjgowers Date: Wed, 20 Feb 2019 16:33:55 -0600 Subject: [PATCH 8/8] pickling all readers now works... --- package/MDAnalysis/coordinates/DLPoly.py | 37 ++--- package/MDAnalysis/coordinates/GMS.py | 22 +-- package/MDAnalysis/coordinates/PDB.py | 12 +- package/MDAnalysis/coordinates/TRJ.py | 144 ++++++++++-------- package/MDAnalysis/coordinates/TRZ.py | 40 ++--- package/MDAnalysis/coordinates/XYZ.py | 20 +-- package/MDAnalysis/coordinates/base.py | 38 ++++- package/MDAnalysis/coordinates/chain.py | 12 +- .../parallelism/test_multiprocessing.py | 73 ++++++++- 9 files changed, 261 insertions(+), 137 deletions(-) diff --git a/package/MDAnalysis/coordinates/DLPoly.py b/package/MDAnalysis/coordinates/DLPoly.py index 5e322e4ffc1..7901d05f59c 100644 --- a/package/MDAnalysis/coordinates/DLPoly.py +++ b/package/MDAnalysis/coordinates/DLPoly.py @@ -37,6 +37,7 @@ from . import base from . import core +from ..lib import util _DLPOLY_UNITS = {'length': 'Angstrom', 'velocity': 'Angstrom/ps', 'time': 'ps'} @@ -141,7 +142,7 @@ def _read_first_frame(self): ts.frame = 0 -class HistoryReader(base.ReaderBase): +class HistoryReader(base.ReaderBase, base._AsciiPickle): """Reads DLPoly format HISTORY files .. 
versionadded:: 0.11.0 @@ -154,9 +155,9 @@ def __init__(self, filename, **kwargs): super(HistoryReader, self).__init__(filename, **kwargs) # "private" file handle - self._file = open(self.filename, 'r') - self.title = self._file.readline().strip() - self._levcfg, self._imcon, self.n_atoms = np.int64(self._file.readline().split()[:3]) + self._f = util.anyopen(self.filename, 'r') + self.title = self._f.readline().strip() + self._levcfg, self._imcon, self.n_atoms = np.int64(self._f.readline().split()[:3]) self._has_vels = True if self._levcfg > 0 else False self._has_forces = True if self._levcfg == 2 else False @@ -170,20 +171,20 @@ def _read_next_timestep(self, ts=None): if ts is None: ts = self.ts - line = self._file.readline() # timestep line + line = self._f.readline() # timestep line if not line.startswith('timestep'): raise IOError if not self._imcon == 0: - ts._unitcell[0] = self._file.readline().split() - ts._unitcell[1] = self._file.readline().split() - ts._unitcell[2] = self._file.readline().split() + ts._unitcell[0] = self._f.readline().split() + ts._unitcell[1] = self._f.readline().split() + ts._unitcell[2] = self._f.readline().split() # If ids are given, put them in here # and later sort by them ids = [] for i in range(self.n_atoms): - line = self._file.readline().strip() # atom info line + line = self._f.readline().strip() # atom info line try: idx = int(line.split()[1]) except IndexError: @@ -192,11 +193,11 @@ def _read_next_timestep(self, ts=None): ids.append(idx) # Read in this order for now, then later reorder in place - ts._pos[i] = self._file.readline().split() + ts._pos[i] = self._f.readline().split() if self._has_vels: - ts._velocities[i] = self._file.readline().split() + ts._velocities[i] = self._f.readline().split() if self._has_forces: - ts._forces[i] = self._file.readline().split() + ts._forces[i] = self._f.readline().split() if ids: ids = np.array(ids) @@ -214,7 +215,7 @@ def _read_next_timestep(self, ts=None): def _read_frame(self, frame): 
"""frame is 0 based, error checking is done in base.getitem""" - self._file.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in read_next_frame return self._read_next_timestep() @@ -234,7 +235,7 @@ def _read_n_frames(self): """ offsets = self._offsets = [] - with open(self.filename, 'r') as f: + with util.anyopen(self.filename, 'r') as f: n_frames = 0 f.readline() @@ -262,10 +263,10 @@ def _read_n_frames(self): def _reopen(self): self.close() - self._file = open(self.filename, 'r') - self._file.readline() # header is 2 lines - self._file.readline() + self._f = util.anyopen(self.filename, 'r') + self._f.readline() # header is 2 lines + self._f.readline() self.ts.frame = -1 def close(self): - self._file.close() + self._f.close() diff --git a/package/MDAnalysis/coordinates/GMS.py b/package/MDAnalysis/coordinates/GMS.py index 46be3a34c09..36251bfa086 100644 --- a/package/MDAnalysis/coordinates/GMS.py +++ b/package/MDAnalysis/coordinates/GMS.py @@ -47,7 +47,7 @@ import MDAnalysis.lib.util as util -class GMSReader(base.ReaderBase): +class GMSReader(base.ReaderBase, base._AsciiPickle): """Reads from an GAMESS output file :Data: @@ -82,7 +82,7 @@ def __init__(self, outfilename, **kwargs): super(GMSReader, self).__init__(outfilename, **kwargs) # the filename has been parsed to be either b(g)zipped or not - self.outfile = util.anyopen(self.filename) + self._f = util.anyopen(self.filename) # note that, like for xtc and trr files, _n_atoms and _n_frames are used quasi-private variables # to prevent the properties being recalculated @@ -177,7 +177,7 @@ def _read_out_n_frames(self): return len(offsets) def _read_frame(self, frame): - self.outfile.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in _read_next return self._read_next_timestep() @@ -186,7 +186,7 @@ def _read_next_timestep(self, ts=None): if ts is None: ts = self.ts # check that the outfile object exists; if not 
reopen the trajectory - if self.outfile is None: + if self._f is None: self.open_trajectory() x = [] y = [] @@ -195,7 +195,7 @@ def _read_next_timestep(self, ts=None): flag = 0 counter = 0 - for line in self.outfile: + for line in self._f: if self.runtyp == 'optimize': if (flag == 0) and (re.match(r'^.NSERCH=.*', line) is not None): flag = 1 @@ -246,22 +246,22 @@ def _reopen(self): self.open_trajectory() def open_trajectory(self): - if self.outfile is not None: + if self._f is not None: raise IOError(errno.EALREADY, 'GMS file already opened', self.filename) if not os.path.exists(self.filename): # must check; otherwise might segmentation fault raise IOError(errno.ENOENT, 'GMS file not found', self.filename) - self.outfile = util.anyopen(self.filename) + self._f = util.anyopen(self.filename) # reset ts ts = self.ts ts.frame = -1 - return self.outfile + return self._f def close(self): """Close out trajectory file if it was open.""" - if self.outfile is None: + if self._f is None: return - self.outfile.close() - self.outfile = None + self._f.close() + self._f = None diff --git a/package/MDAnalysis/coordinates/PDB.py b/package/MDAnalysis/coordinates/PDB.py index 4506ebc7cce..534a4a189fd 100644 --- a/package/MDAnalysis/coordinates/PDB.py +++ b/package/MDAnalysis/coordinates/PDB.py @@ -168,7 +168,7 @@ # Pairs of residue name / atom name in use to deduce PDB formatted atom names Pair = collections.namedtuple('Atom', 'resname name') -class PDBReader(base.ReaderBase): +class PDBReader(base.ReaderBase, base._BAsciiPickle): """PDBReader that reads a `PDB-formatted`_ file, no frills. 
The following *PDB records* are parsed (see `PDB coordinate section`_ for @@ -277,7 +277,7 @@ def __init__(self, filename, **kwargs): if isinstance(filename, util.NamedStream) and isinstance(filename.stream, StringIO): filename.stream = BytesIO(filename.stream.getvalue().encode()) - pdbfile = self._pdbfile = util.anyopen(filename, 'rb') + pdbfile = self._f = util.anyopen(filename, 'rb') line = "magical" while line: @@ -345,7 +345,7 @@ def _reopen(self): # Pretend the current TS is -1 (in 0 based) so "next" is the # 0th frame self.close() - self._pdbfile = util.anyopen(self.filename, 'rb') + self._f = util.anyopen(self.filename, 'rb') self.ts.frame = -1 def _read_next_timestep(self, ts=None): @@ -371,8 +371,8 @@ def _read_frame(self, frame): occupancy = np.ones(self.n_atoms) # Seek to start and read until start of next frame - self._pdbfile.seek(start) - chunk = self._pdbfile.read(stop - start).decode() + self._f.seek(start) + chunk = self._f.read(stop - start).decode() tmp_buf = [] for line in chunk.splitlines(): @@ -411,7 +411,7 @@ def _read_frame(self, frame): return self.ts def close(self): - self._pdbfile.close() + self._f.close() class PDBWriter(base.WriterBase): diff --git a/package/MDAnalysis/coordinates/TRJ.py b/package/MDAnalysis/coordinates/TRJ.py index 867d38ede20..8a16fcba713 100644 --- a/package/MDAnalysis/coordinates/TRJ.py +++ b/package/MDAnalysis/coordinates/TRJ.py @@ -188,7 +188,7 @@ class Timestep(base.Timestep): order = 'C' -class TRJReader(base.ReaderBase): +class TRJReader(base.ReaderBase, base._AsciiPickle): """AMBER trajectory reader. Reads the ASCII formatted `AMBER TRJ format`_. Periodic box information @@ -223,7 +223,7 @@ def __init__(self, filename, n_atoms=None, **kwargs): self._n_atoms = n_atoms self._n_frames = None - self.trjfile = None # have _read_next_timestep() open it properly! + self._f = None # have _read_next_timestep() open it properly! 
self.ts = self._Timestep(self.n_atoms, **self._ts_kwargs) # FORMAT(10F8.3) (X(i), Y(i), Z(i), i=1,NATOM) @@ -248,22 +248,22 @@ def __init__(self, filename, n_atoms=None, **kwargs): self._read_next_timestep() def _read_frame(self, frame): - if self.trjfile is None: + if self._f is None: self.open_trajectory() - self.trjfile.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in _read_next return self._read_next_timestep() def _read_next_timestep(self): # FORMAT(10F8.3) (X(i), Y(i), Z(i), i=1,NATOM) ts = self.ts - if self.trjfile is None: + if self._f is None: self.open_trajectory() # Read coordinat frame: # coordinates = numpy.zeros(3*self.n_atoms, dtype=np.float32) _coords = [] - for number, line in enumerate(self.trjfile): + for number, line in enumerate(self._f): try: _coords.extend(self.default_line_parser.read(line)) except ValueError: @@ -278,7 +278,7 @@ def _read_next_timestep(self): # Read box information if self.periodic: - line = next(self.trjfile) + line = next(self._f) box = self.box_line_parser.read(line) ts._unitcell[:3] = np.array(box, dtype=np.float32) ts._unitcell[3:] = [90., 90., 90.] # assumed @@ -325,7 +325,7 @@ def _detect_amber_box(self): self._read_next_timestep() ts = self.ts # TODO: what do we do with 1-frame trajectories? Try..except EOFError? 
- line = next(self.trjfile) + line = next(self._f) nentries = self.default_line_parser.number_of_matches(line) if nentries == 3: self.periodic = True @@ -376,8 +376,8 @@ def _reopen(self): def open_trajectory(self): """Open the trajectory for reading and load first frame.""" - self.trjfile = util.anyopen(self.filename) - self.header = self.trjfile.readline() # ignore first line + self._f = util.anyopen(self.filename) + self.header = self._f.readline() # ignore first line if len(self.header.rstrip()) > 80: # Chimera uses this check raise OSError( @@ -387,14 +387,14 @@ def open_trajectory(self): ts = self.ts ts.frame = -1 - return self.trjfile + return self._f def close(self): """Close trj trajectory file if it was open.""" - if self.trjfile is None: + if self._f is None: return - self.trjfile.close() - self.trjfile = None + self._f.close() + self._f = None class NCDFReader(base.ReaderBase): @@ -465,36 +465,28 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): super(NCDFReader, self).__init__(filename, **kwargs) - self.trjfile = scipy.io.netcdf.netcdf_file(self.filename, - mmap=self._mmap) + self._f = scipy.io.netcdf.netcdf_file(self.filename, + mmap=self._mmap) - if not ('AMBER' in self.trjfile.Conventions.decode('utf-8').split(',') or - 'AMBER' in self.trjfile.Conventions.decode('utf-8').split()): + if not ('AMBER' in self._f.Conventions.decode('utf-8').split(',') or + 'AMBER' in self._f.Conventions.decode('utf-8').split()): errmsg = ("NCDF trajectory {0} does not conform to AMBER " "specifications, http://ambermd.org/netcdf/nctraj.xhtml " "('AMBER' must be one of the tokens in attribute " "Conventions)".format(self.filename)) logger.fatal(errmsg) raise TypeError(errmsg) - if not self.trjfile.ConventionVersion.decode('utf-8') == self.version: + if not self._f.ConventionVersion.decode('utf-8') == self.version: wmsg = ("NCDF trajectory format is {0!s} but the reader " "implements format {1!s}".format( - self.trjfile.ConventionVersion, self.version)) 
+ self._f.ConventionVersion, self.version)) warnings.warn(wmsg) logger.warning(wmsg) - self.n_atoms = self.trjfile.dimensions['atom'] - self.n_frames = self.trjfile.dimensions['frame'] - # example trajectory when read with scipy.io.netcdf has - # dimensions['frame'] == None (indicating a record dimension that can - # grow) whereas if read with netCDF4 I get len(dimensions['frame']) == - # 10: in any case, we need to get the number of frames from somewhere - # such as the time variable: - if self.n_frames is None: - self.n_frames = self.trjfile.variables['time'].shape[0] + self.n_atoms = self._f.dimensions['atom'] try: - self.remarks = self.trjfile.title + self.remarks = self._f.title except AttributeError: self.remarks = "" # other metadata (*= requd): @@ -505,27 +497,27 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # checks for not-implemented features (other units would need to be # hacked into MDAnalysis.units) - if self.trjfile.variables['time'].units.decode('utf-8') != "picosecond": + if self._f.variables['time'].units.decode('utf-8') != "picosecond": raise NotImplementedError( "NETCDFReader currently assumes that the trajectory was written " "with a time unit of picoseconds and not {0}.".format( - self.trjfile.variables['time'].units)) - if self.trjfile.variables['coordinates'].units.decode('utf-8') != "angstrom": + self._f.variables['time'].units)) + if self._f.variables['coordinates'].units.decode('utf-8') != "angstrom": raise NotImplementedError( "NETCDFReader currently assumes that the trajectory was written " "with a length unit of Angstroem and not {0}.".format( - self.trjfile.variables['coordinates'].units)) - if hasattr(self.trjfile.variables['coordinates'], 'scale_factor'): + self._f.variables['coordinates'].units)) + if hasattr(self._f.variables['coordinates'], 'scale_factor'): raise NotImplementedError("scale_factors are not implemented") if n_atoms is not None and n_atoms != self.n_atoms: raise ValueError("Supplied n_atoms 
({0}) != natom from ncdf ({1}). " "Note: n_atoms can be None and then the ncdf value " "is used!".format(n_atoms, self.n_atoms)) - self.has_velocities = 'velocities' in self.trjfile.variables - self.has_forces = 'forces' in self.trjfile.variables + self.has_velocities = 'velocities' in self._f.variables + self.has_forces = 'forces' in self._f.variables - self.periodic = 'cell_lengths' in self.trjfile.variables + self.periodic = 'cell_lengths' in self._f.variables self._current_frame = 0 self.ts = self._Timestep(self.n_atoms, @@ -537,6 +529,30 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # load first data frame self._read_frame(0) + def __getstate__(self): + state = self.__dict__.copy() + del state['_f'] + + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._f = scipy.io.netcdf.netcdf_file(self.filename, + mmap=self._mmap) + + @property + def n_frames(self): + n_frames = self._f.dimensions['frame'] + # example trajectory when read with scipy.io.netcdf has + # dimensions['frame'] == None (indicating a record dimension that can + # grow) whereas if read with netCDF4 I get len(dimensions['frame']) == + # 10: in any case, we need to get the number of frames from somewhere + # such as the time variable: + if n_frames is None: + n_frames = self._f.variables['time'].shape[0] + + return n_frames + @staticmethod def parse_n_atoms(filename, **kwargs): with scipy.io.netcdf.netcdf_file(filename, mmap=None) as f: @@ -546,7 +562,7 @@ def parse_n_atoms(filename, **kwargs): def _read_frame(self, frame): ts = self.ts - if self.trjfile is None: + if self._f is None: raise IOError("Trajectory is closed") if np.dtype(type(frame)) != np.dtype(int): # convention... 
for netcdf could also be a slice @@ -554,16 +570,16 @@ def _read_frame(self, frame): if frame >= self.n_frames or frame < 0: raise IndexError("frame index must be 0 <= frame < {0}".format( self.n_frames)) - # note: self.trjfile.variables['coordinates'].shape == (frames, n_atoms, 3) - ts._pos[:] = self.trjfile.variables['coordinates'][frame] - ts.time = self.trjfile.variables['time'][frame] + # note: self._f.variables['coordinates'].shape == (frames, n_atoms, 3) + ts._pos[:] = self._f.variables['coordinates'][frame] + ts.time = self._f.variables['time'][frame] if self.has_velocities: - ts._velocities[:] = self.trjfile.variables['velocities'][frame] + ts._velocities[:] = self._f.variables['velocities'][frame] if self.has_forces: - ts._forces[:] = self.trjfile.variables['forces'][frame] + ts._forces[:] = self._f.variables['forces'][frame] if self.periodic: - ts._unitcell[:3] = self.trjfile.variables['cell_lengths'][frame] - ts._unitcell[3:] = self.trjfile.variables['cell_angles'][frame] + ts._unitcell[:3] = self._f.variables['cell_lengths'][frame] + ts._unitcell[3:] = self._f.variables['cell_angles'][frame] if self.convert_units: self.convert_pos_from_native(ts._pos) # in-place ! self.convert_time_from_native( @@ -592,8 +608,8 @@ def _read_next_timestep(self, ts=None): raise IOError def _get_dt(self): - t1 = self.trjfile.variables['time'][1] - t0 = self.trjfile.variables['time'][0] + t1 = self._f.variables['time'][1] + t0 = self._f.variables['time'][0] return t1 - t0 def close(self): @@ -605,9 +621,9 @@ def close(self): before the file can be closed. """ - if self.trjfile is not None: - self.trjfile.close() - self.trjfile = None + if self._f is not None: + self._f.close() + self._f = None def Writer(self, filename, **kwargs): """Returns a NCDFWriter for `filename` with the same parameters as this NCDF. @@ -763,7 +779,7 @@ def __init__(self, self.ts = None # when/why would this be assigned?? 
self._first_frame = True # signals to open trajectory - self.trjfile = None # open on first write with _init_netcdf() + self._f = None # open on first write with _init_netcdf() self.periodic = None # detect on first write self.has_velocities = kwargs.get('velocities', False) self.has_forces = kwargs.get('forces', False) @@ -862,7 +878,7 @@ def _init_netcdf(self, periodic=True): ncfile.sync() self._first_frame = False - self.trjfile = ncfile + self._f = ncfile def is_periodic(self, ts=None): """Test if `Timestep` contains a periodic trajectory. @@ -902,7 +918,7 @@ def write_next_timestep(self, ts=None): raise IOError( "NCDFWriter: Timestep does not have the correct number of atoms") - if self.trjfile is None: + if self._f is None: # first time step: analyze data and open trajectory accordingly self._init_netcdf(periodic=self.is_periodic(ts)) @@ -938,12 +954,12 @@ def _write_next_timestep(self, ts): unitcell = self.convert_dimensions_to_unitcell(ts) # write step - self.trjfile.variables['coordinates'][self.curr_frame, :, :] = pos - self.trjfile.variables['time'][self.curr_frame] = time + self._f.variables['coordinates'][self.curr_frame, :, :] = pos + self._f.variables['time'][self.curr_frame] = time if self.periodic: - self.trjfile.variables['cell_lengths'][ + self._f.variables['cell_lengths'][ self.curr_frame, :] = unitcell[:3] - self.trjfile.variables['cell_angles'][ + self._f.variables['cell_angles'][ self.curr_frame, :] = unitcell[3:] if self.has_velocities: @@ -951,19 +967,19 @@ def _write_next_timestep(self, ts): if self.convert_units: velocities = self.convert_velocities_to_native( velocities, inplace=False) - self.trjfile.variables['velocities'][self.curr_frame, :, :] = velocities + self._f.variables['velocities'][self.curr_frame, :, :] = velocities if self.has_forces: forces = ts._forces if self.convert_units: forces = self.convert_forces_to_native( forces, inplace=False) - self.trjfile.variables['forces'][self.curr_frame, :, :] = forces + 
self._f.variables['forces'][self.curr_frame, :, :] = forces - self.trjfile.sync() + self._f.sync() self.curr_frame += 1 def close(self): - if self.trjfile is not None: - self.trjfile.close() - self.trjfile = None + if self._f is not None: + self._f.close() + self._f = None diff --git a/package/MDAnalysis/coordinates/TRZ.py b/package/MDAnalysis/coordinates/TRZ.py index a469ee09aba..80123e41d45 100644 --- a/package/MDAnalysis/coordinates/TRZ.py +++ b/package/MDAnalysis/coordinates/TRZ.py @@ -126,7 +126,7 @@ def dimensions(self, box): self._unitcell[:] = triclinic_vectors(box).reshape(9) -class TRZReader(base.ReaderBase): +class TRZReader(base.ReaderBase, base._BAsciiPickle): """Reads an IBIsCO or YASP trajectory file Attributes @@ -170,7 +170,7 @@ def __init__(self, trzfilename, n_atoms=None, **kwargs): if n_atoms is None: raise ValueError('TRZReader requires the n_atoms keyword') - self.trzfile = util.anyopen(self.filename, 'rb') + self._f = util.anyopen(self.filename, 'rb') self._cache = dict() self._n_atoms = n_atoms @@ -234,7 +234,7 @@ def _read_trz_header(self): ('p2', '<2i4'), ('force', '