diff --git a/docs/api/hierarchy.rst b/docs/api/hierarchy.rst
index 73db5bbc34..799657c8d0 100644
--- a/docs/api/hierarchy.rst
+++ b/docs/api/hierarchy.rst
@@ -15,6 +15,10 @@ Groups (``zarr.hierarchy``)
     .. automethod:: groups
     .. automethod:: array_keys
     .. automethod:: arrays
+    .. automethod:: visit
+    .. automethod:: visitkeys
+    .. automethod:: visitvalues
+    .. automethod:: visititems
     .. automethod:: create_group
     .. automethod:: require_group
     .. automethod:: create_groups
diff --git a/zarr/hierarchy.py b/zarr/hierarchy.py
index 9c92f4b122..fbf853bcf8 100644
--- a/zarr/hierarchy.py
+++ b/zarr/hierarchy.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, print_function, division
 from collections import MutableMapping
+from itertools import islice
 
 
 import numpy as np
@@ -55,6 +56,10 @@ class Group(MutableMapping):
     groups
     array_keys
     arrays
+    visit
+    visitkeys
+    visitvalues
+    visititems
     create_group
     require_group
     create_groups
@@ -414,6 +419,178 @@ def arrays(self):
                                  chunk_store=self._chunk_store,
                                  synchronizer=self._synchronizer)
 
+    def visitvalues(self, func):
+        """Run ``func`` on each object.
+
+        Parameters
+        ----------
+        func : callable
+            Called as ``func(obj)`` for every object below this group.
+            Members are visited in sorted key order, each one before its
+            own members. This group itself is not visited.
+
+        Returns
+        -------
+        value : object
+            The first value returned by ``func`` that is not ``None``, or
+            ``None`` if ``func`` never returns such a value.
+
+        Notes
+        -----
+        If ``func`` returns ``None`` (or doesn't return), iteration
+        continues. However, if ``func`` returns anything else, iteration
+        ceases and that value is returned.
+
+        Examples
+        --------
+        >>> import zarr
+        >>> g1 = zarr.group()
+        >>> g2 = g1.create_group('foo')
+        >>> g3 = g1.create_group('bar')
+        >>> g4 = g3.create_group('baz')
+        >>> g5 = g3.create_group('quux')
+        >>> def print_visitor(obj):
+        ...     print(obj)
+        >>> g1.visitvalues(print_visitor)
+        Group(/bar, 2)
+          groups: 2; baz, quux
+          store: DictStore
+        Group(/bar/baz, 0)
+          store: DictStore
+        Group(/bar/quux, 0)
+          store: DictStore
+        Group(/foo, 0)
+          store: DictStore
+        >>> g3.visitvalues(print_visitor)
+        Group(/bar/baz, 0)
+          store: DictStore
+        Group(/bar/quux, 0)
+          store: DictStore
+
+        """
+
+        def _visit(obj):
+            # pre-order traversal: the object itself, then its descendants
+            yield obj
+
+            # objects that have no ``keys`` method are treated as leaves
+            keys = sorted(getattr(obj, "keys", lambda: [])())
+            for each_key in keys:
+                for each_obj in _visit(obj[each_key]):
+                    yield each_obj
+
+        # the first item yielded is this group itself, which is skipped
+        for each_obj in islice(_visit(self), 1, None):
+            value = func(each_obj)
+            if value is not None:
+                return value
+
+    def visit(self, func):
+        """Run ``func`` on each object's path.
+
+        Parameters
+        ----------
+        func : callable
+            Called as ``func(name)`` for every object below this group,
+            where ``name`` is the object's path relative to this group.
+
+        Returns
+        -------
+        value : object
+            The first value returned by ``func`` that is not ``None``, or
+            ``None`` if ``func`` never returns such a value.
+
+        Notes
+        -----
+        If ``func`` returns ``None`` (or doesn't return), iteration
+        continues. However, if ``func`` returns anything else, iteration
+        ceases and that value is returned.
+
+        Examples
+        --------
+        >>> import zarr
+        >>> g1 = zarr.group()
+        >>> g2 = g1.create_group('foo')
+        >>> g3 = g1.create_group('bar')
+        >>> g4 = g3.create_group('baz')
+        >>> g5 = g3.create_group('quux')
+        >>> def print_visitor(name):
+        ...     print(name)
+        >>> g1.visit(print_visitor)
+        bar
+        bar/baz
+        bar/quux
+        foo
+        >>> g3.visit(print_visitor)
+        baz
+        quux
+
+        """
+
+        # strip this group's own name so that paths are relative to it
+        base_len = len(self.name)
+        return self.visitvalues(lambda o: func(o.name[base_len:].lstrip("/")))
+
+    def visitkeys(self, func):
+        """An alias for :py:meth:`~Group.visit`.
+        """
+
+        return self.visit(func)
+
+    def visititems(self, func):
+        """Run ``func`` on each object's path and the object itself.
+
+        Parameters
+        ----------
+        func : callable
+            Called as ``func(name, obj)`` for every object below this
+            group, where ``name`` is the object's path relative to this
+            group.
+
+        Returns
+        -------
+        value : object
+            The first value returned by ``func`` that is not ``None``, or
+            ``None`` if ``func`` never returns such a value.
+
+        Notes
+        -----
+        If ``func`` returns ``None`` (or doesn't return), iteration
+        continues. However, if ``func`` returns anything else, iteration
+        ceases and that value is returned.
+
+        Examples
+        --------
+        >>> import zarr
+        >>> g1 = zarr.group()
+        >>> g2 = g1.create_group('foo')
+        >>> g3 = g1.create_group('bar')
+        >>> g4 = g3.create_group('baz')
+        >>> g5 = g3.create_group('quux')
+        >>> def print_visitor(name, obj):
+        ...     print((name, obj))
+        >>> g1.visititems(print_visitor)
+        ('bar', Group(/bar, 2)
+          groups: 2; baz, quux
+          store: DictStore)
+        ('bar/baz', Group(/bar/baz, 0)
+          store: DictStore)
+        ('bar/quux', Group(/bar/quux, 0)
+          store: DictStore)
+        ('foo', Group(/foo, 0)
+          store: DictStore)
+        >>> g3.visititems(print_visitor)
+        ('baz', Group(/bar/baz, 0)
+          store: DictStore)
+        ('quux', Group(/bar/quux, 0)
+          store: DictStore)
+
+        """
+
+        # strip this group's own name so that paths are relative to it
+        base_len = len(self.name)
+        return self.visitvalues(lambda o: func(o.name[base_len:].lstrip("/"), o))
+
     def _write_op(self, f, *args, **kwargs):
 
         # guard condition
diff --git a/zarr/tests/test_hierarchy.py b/zarr/tests/test_hierarchy.py
index a13fb29e05..ee3a7e2eae 100644
--- a/zarr/tests/test_hierarchy.py
+++ b/zarr/tests/test_hierarchy.py
@@ -473,6 +473,126 @@ def test_getitem_contains_iterators(self):
         eq('baz', arrays[0][0])
         eq(g1['foo']['baz'], arrays[0][1])
 
+        # visitor collection tests
+        items = []
+
+        def visitor2(obj):
+            items.append(obj.path)
+
+        def visitor3(name, obj=None):
+            items.append(name)
+
+        def visitor4(name, obj):
+            items.append((name, obj))
+
+        del items[:]
+        g1.visitvalues(visitor2)
+        eq([
+            "a",
+            "a/b",
+            "a/b/c",
+            "foo",
+            "foo/bar",
+            "foo/baz",
+        ], items)
+
+        del items[:]
+        g1["foo"].visitvalues(visitor2)
+        eq([
+            "foo/bar",
+            "foo/baz",
+        ], items)
+
+        del items[:]
+        g1.visit(visitor3)
+        eq([
+            "a",
+            "a/b",
+            "a/b/c",
+            "foo",
+            "foo/bar",
+            "foo/baz",
+        ], items)
+
+        del items[:]
+        g1["foo"].visit(visitor3)
+        eq([
+            "bar",
+            "baz",
+        ], items)
+
+        del items[:]
+        g1.visitkeys(visitor3)
+        eq([
+            "a",
+            "a/b",
+            "a/b/c",
+            "foo",
+            "foo/bar",
+            "foo/baz",
+        ], items)
+
+        del items[:]
+        g1["foo"].visitkeys(visitor3)
+        eq([
+            "bar",
+            "baz",
+        ], items)
+
+        del items[:]
+        g1.visititems(visitor3)
+        eq([
+            "a",
+            "a/b",
+            "a/b/c",
+            "foo",
+            "foo/bar",
+            "foo/baz",
+        ], items)
+
+        del items[:]
+        g1["foo"].visititems(visitor3)
+        eq([
+            "bar",
+            "baz",
+        ], items)
+
+        del items[:]
+        g1.visititems(visitor4)
+        eq(6, len(items))
+        for n, o in items:
+            eq(g1[n], o)
+
+        del items[:]
+        g1["foo"].visititems(visitor4)
+        eq(2, len(items))
+        for n, o in items:
+            eq(g1["foo"][n], o)
+
+        # visitor filter tests
+        # visitor0 never matches, so every traversal runs to the end
+        def visitor0(val, *args):
+            name = getattr(val, "path", val)
+
+            if name == "a/b/c/d":
+                return True  # pragma: no cover
+
+        # visitor1 matches "a/b/c", so traversal of g1 stops early
+        def visitor1(val, *args):
+            name = getattr(val, "path", val)
+
+            if name == "a/b/c":
+                return True
+
+        eq(None, g1.visit(visitor0))
+        eq(None, g1.visitkeys(visitor0))
+        eq(None, g1.visitvalues(visitor0))
+        eq(None, g1.visititems(visitor0))
+        eq(True, g1.visit(visitor1))
+        eq(True, g1.visitkeys(visitor1))
+        eq(True, g1.visitvalues(visitor1))
+        eq(True, g1.visititems(visitor1))
+
     def test_empty_getitem_contains_iterators(self):
         # setup
         g = self.create_group()