Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ authors.py
*.DS_Store
# ignore files from tests
.hypothesis/
# ignore results from asv
benchmarks/results

# duecredit
.duecredit.p
Expand Down
10 changes: 10 additions & 0 deletions benchmarks/benchmarks/ag_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,13 @@ def setup(self, universe_type):

def time_find_fragments(self, universe_type):
frags = self.u.atoms.fragments


class FragmentCaching(FragmentFinding):
"""Test how quickly we find cached fragments"""
def setup(self, universe_type):
super(FragmentCaching, self).setup(universe_type)
frags = self.u.atoms.fragments # Priming the cache

def time_find_cached_fragments(self, universe_type):
frags = self.u.atoms.fragments
5 changes: 4 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ Fixes
* Fix syntax warning over comparison of literals using is (Issue #3066)

Enhancements
* Caches can now undergo central validation at the Universe level, opening
the door to more useful caching. Already applied to fragment caching
(Issue #2376, PR #3135)
* Code for operations on compounds refactored, centralized and optimized for
performance (Issue #3000)
performance (Issue #3000, PR #3005)
* Added automatic selection class generation for TopologyAttrs,
FloatRangeSelection, and BoolSelection (Issues #2925, #2875; PR #2927)
* Added 'to' operator, negatives, scientific notation, and arbitrary
Expand Down
2 changes: 2 additions & 0 deletions package/MDAnalysis/core/topologyattrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,6 +2404,7 @@ def fragindex(self):
"""
return self.universe._fragdict[self.ix].ix

@cached('fragindices', universe_validation=True)
def fragindices(self):
r"""The
:class:`fragment indices<MDAnalysis.core.topologyattrs.Bonds.fragindex>`
Expand Down Expand Up @@ -2437,6 +2438,7 @@ def fragment(self):
"""
return self.universe._fragdict[self.ix].fragment

@cached('fragments', universe_validation=True)
def fragments(self):
"""Read-only :class:`tuple` of
:class:`fragments<MDAnalysis.core.topologyattrs.Bonds.fragment>`.
Expand Down
8 changes: 7 additions & 1 deletion package/MDAnalysis/core/universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def __init__(self, topology=None, *coordinates, all_coordinates=False,
in_memory_step=1, **kwargs):

self._trajectory = None # managed attribute holding Reader
self._cache = {}
self._cache = {'_valid': {}}
self.atoms = None
self.residues = None
self.segments = None
Expand Down Expand Up @@ -1002,7 +1002,10 @@ def add_bonds(self, values, types=None, guessed=False, order=None):
"""
self._add_topology_objects('bonds', values, types=types,
guessed=guessed, order=order)
# Invalidate bond-related caches
self._cache.pop('fragments', None)
self._cache['_valid'].pop('fragments', None)
self._cache['_valid'].pop('fragindices', None)

def add_angles(self, values, types=None, guessed=False):
"""Add new Angles to this Universe.
Expand Down Expand Up @@ -1139,7 +1142,10 @@ def delete_bonds(self, values):
.. versionadded:: 1.0.0
"""
self._delete_topology_objects('bonds', values)
# Invalidate bond-related caches
self._cache.pop('fragments', None)
self._cache['_valid'].pop('fragments', None)
self._cache['_valid'].pop('fragindices', None)

def delete_angles(self, values):
"""Delete Angles from this Universe.
Expand Down
50 changes: 45 additions & 5 deletions package/MDAnalysis/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
import functools
from functools import wraps
import textwrap
import weakref

import mmtf
import numpy as np
Expand Down Expand Up @@ -1496,10 +1497,18 @@ def conv_float(s):
return s


def cached(key):
# A dummy, empty, cheaply-hashable object class to use with weakref caching.
# (class object doesn't allow weakrefs to its instances, but user-defined
# classes do)
class _CacheKey:
pass


def cached(key, universe_validation=False):
"""Cache a property within a class.

Requires the Class to have a cache dict called ``_cache``.
Requires the Class to have a cache dict :attr:`_cache` and, with
`universe_validation`, a :attr:`universe` with a cache dict :attr:`_cache`.

Example
-------
Expand All @@ -1513,23 +1522,54 @@ class A(object):
@property
@cached('keyname')
def size(self):
# This code gets ran only if the lookup of keyname fails
# After this code has been ran once, the result is stored in
# This code gets run only if the lookup of keyname fails
# After this code has been run once, the result is stored in
# _cache with the key: 'keyname'
size = 10.0
return 10.0

@property
@cached('keyname', universe_validation=True)
def othersize(self):
# This code gets run only if the lookup
# id(self) is not in the validation set under
# self.universe._cache['_valid']['keyname']
# After this code has been run once, id(self) is added to that
# set. The validation set can be centrally invalidated at the
# universe level (say, if a topology change invalidates specific
# caches).
return 20.0


.. versionadded:: 0.9.0

.. versionchanged::2.0.0
Added the `universe_validation` keyword.
"""

def cached_lookup(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
if universe_validation: # Universe-level cache validation
u_cache = self.universe._cache.setdefault('_valid', dict())
# A WeakSet is used so that keys from out-of-scope/deleted
# objects don't clutter it.
valid_caches = u_cache.setdefault(key, weakref.WeakSet())
try:
if self._cache_key not in valid_caches:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a comment, one day it might be worthwhile to put this self._cache_key setting behaviour into Group.__hash__ so that every AtomGroup natively hashes quickly/cleanly/weakly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea; that seems super clean! I can implement it right away if you're ok with having both functionalities added in the same PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went over and checked and we already have a __hash__ defined, and it looks like it's actually hashing ix which isn't going to be fast, but technically necessary for two equivalent AGs to hash identically. For this lookup I'd rather we quickly hash the AG and maybe have duplicate caches

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yea. I hadn't checked __hash__ but I had initially tried to use the ag itself as the key, and got the same slow performance (I just assumed the entire object was being hashed instead of just _ix, which probably boils down to a similar time penalty).

raise KeyError
except AttributeError: # No _cache_key yet
# Must create a reference key for the validation set.
# self could be used itself as a weakref but set()
# requires hashing it, which can be slow for AGs. Using
# id(self) fails because ints can't be weak-referenced.
self._cache_key = _CacheKey()
raise KeyError
return self._cache[key]
except KeyError:
self._cache[key] = ret = func(self, *args, **kwargs)
if universe_validation:
valid_caches.add(self._cache_key)
return ret

return wrapper
Expand Down
18 changes: 18 additions & 0 deletions testsuite/MDAnalysisTests/core/test_fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,24 @@ def test_atom_fragment_nobonds_NDE(self):
with pytest.raises(NoDataError):
getattr(u.atoms[10], 'fragindex')

def test_atomgroup_fragment_cache_invalidation_bond_making(self):
u = case1()
fgs = u.atoms.fragments
assert fgs is u.atoms._cache['fragments']
assert u.atoms._cache_key in u._cache['_valid']['fragments']
u.add_bonds((fgs[0][-1] + fgs[1][0],)) # should trigger invalidation
assert 'fragments' not in u._cache['_valid']
assert len(fgs) > len(u.atoms.fragments) # recomputed

def test_atomgroup_fragment_cache_invalidation_bond_breaking(self):
u = case1()
fgs = u.atoms.fragments
assert fgs is u.atoms._cache['fragments']
assert u.atoms._cache_key in u._cache['_valid']['fragments']
u.delete_bonds((u.atoms.bonds[3],)) # should trigger invalidation
assert 'fragments' not in u._cache['_valid']
assert len(fgs) < len(u.atoms.fragments) # recomputed


def test_tpr_fragments():
ag = mda.Universe(TPR, XTC).atoms
Expand Down
39 changes: 39 additions & 0 deletions testsuite/MDAnalysisTests/lib/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,11 @@ def __init__(self):
self.ref3 = 3.0
self.ref4 = 4.0
self.ref5 = 5.0
self.ref6 = 6.0
# For universe-validated caches
# One-line lambda-like class
self.universe = type('Universe', (), dict())()
self.universe._cache = {'_valid': {}}

@cached('val1')
def val1(self):
Expand Down Expand Up @@ -742,6 +747,12 @@ def val5(self, n, s=None):
def _init_val_5(self, n, s=None):
return n * s

# Property decorator and universally-validated cache
@property
@cached('val6', universe_validation=True)
def val6(self):
return self.ref5 + 1.0

# These are designed to mimic the AG and Universe cache methods
def _clear_caches(self, *args):
if len(args) == 0:
Expand Down Expand Up @@ -836,6 +847,34 @@ def test_val5_kwargs(self, obj):

assert obj.val5(5, s='!!!') == 5 * 'abc'

# property decorator, with universe validation
def test_val6_universe_validation(self, obj):
obj._clear_caches()
assert not hasattr(obj, '_cache_key')
assert 'val6' not in obj._cache
assert 'val6' not in obj.universe._cache['_valid']

ret = obj.val6 # Trigger caching
assert obj.val6 == obj.ref6
assert ret is obj.val6
assert 'val6' in obj._cache
assert 'val6' in obj.universe._cache['_valid']
assert obj._cache_key in obj.universe._cache['_valid']['val6']
assert obj._cache['val6'] is ret

# Invalidate cache at universe level
obj.universe._cache['_valid']['val6'].clear()
ret2 = obj.val6
assert ret2 is obj.val6
assert ret2 is not ret

# Clear obj cache and access again
obj._clear_caches()
ret3 = obj.val6
assert ret3 is obj.val6
assert ret3 is not ret2
assert ret3 is not ret


class TestConvFloat(object):
@pytest.mark.parametrize('s, output', [
Expand Down