Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 56 additions & 133 deletions zarr/tests/test_convenience.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import atexit
import os
import tempfile
import unittest
from numbers import Integral
Expand Down Expand Up @@ -456,67 +455,67 @@ def test_copy_all():
assert destination_group.subgroup.attrs["info"] == "sub attrs"


# noinspection PyAttributeOutsideInit
class TestCopy(unittest.TestCase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.source_h5py = False
self.dest_h5py = False
self.new_source = group
self.new_dest = group

def setUp(self):
source = self.new_source()
foo = source.create_group('foo')
foo.attrs['experiment'] = 'weird science'
baz = foo.create_dataset('bar/baz', data=np.arange(100), chunks=(50,))
baz.attrs['units'] = 'metres'
if self.source_h5py:
extra_kws = dict(compression='gzip', compression_opts=3, fillvalue=84,
shuffle=True, fletcher32=True)
class TestCopy:
@pytest.fixture(params=[False, True], ids=['zarr', 'hdf5'])
def source(self, request, tmpdir):
def prep_source(source):
foo = source.create_group('foo')
foo.attrs['experiment'] = 'weird science'
baz = foo.create_dataset('bar/baz', data=np.arange(100), chunks=(50,))
baz.attrs['units'] = 'metres'
if request.param:
extra_kws = dict(compression='gzip', compression_opts=3, fillvalue=84,
shuffle=True, fletcher32=True)
else:
extra_kws = dict(compressor=Zlib(3), order='F', fill_value=42, filters=[Adler32()])
source.create_dataset('spam', data=np.arange(100, 200).reshape(20, 5),
chunks=(10, 2), dtype='i2', **extra_kws)
return source

if request.param:
h5py = pytest.importorskip('h5py')
fn = tmpdir.join('source.h5')
with h5py.File(str(fn), mode='w') as h5f:
yield prep_source(h5f)
else:
extra_kws = dict(compressor=Zlib(3), order='F', fill_value=42,
filters=[Adler32()])
source.create_dataset('spam', data=np.arange(100, 200).reshape(20, 5),
chunks=(10, 2), dtype='i2', **extra_kws)
self.source = source

def test_copy_array(self):
source = self.source
dest = self.new_dest()
yield prep_source(group())

@pytest.fixture(params=[False, True], ids=['zarr', 'hdf5'])
def dest(self, request, tmpdir):
if request.param:
h5py = pytest.importorskip('h5py')
fn = tmpdir.join('dest.h5')
with h5py.File(str(fn), mode='w') as h5f:
yield h5f
else:
yield group()

def test_copy_array(self, source, dest):
# copy array with default options
copy(source['foo/bar/baz'], dest)
check_copied_array(source['foo/bar/baz'], dest['baz'])
copy(source['spam'], dest)
check_copied_array(source['spam'], dest['spam'])

def test_copy_bad_dest(self):
source = self.source

def test_copy_bad_dest(self, source, dest):
# try to copy to an array, dest must be a group
dest = self.new_dest().create_dataset('eggs', shape=(100,))
dest = dest.create_dataset('eggs', shape=(100,))
with pytest.raises(ValueError):
copy(source['foo/bar/baz'], dest)

def test_copy_array_name(self):
source = self.source
dest = self.new_dest()

def test_copy_array_name(self, source, dest):
# copy array with name
copy(source['foo/bar/baz'], dest, name='qux')
assert 'baz' not in dest
check_copied_array(source['foo/bar/baz'], dest['qux'])

def test_copy_array_create_options(self):
source = self.source
dest = self.new_dest()
def test_copy_array_create_options(self, source, dest):
dest_h5py = dest.__module__.startswith('h5py.')

# copy array, provide creation options
compressor = Zlib(9)
create_kws = dict(chunks=(10,))
if self.dest_h5py:
if dest_h5py:
create_kws.update(compression='gzip', compression_opts=9,
shuffle=True, fletcher32=True, fillvalue=42)
else:
Expand All @@ -526,10 +525,7 @@ def test_copy_array_create_options(self):
check_copied_array(source['foo/bar/baz'], dest['baz'],
without_attrs=True, expect_props=create_kws)

def test_copy_array_exists_array(self):
source = self.source
dest = self.new_dest()

def test_copy_array_exists_array(self, source, dest):
# copy array, dest array in the way
dest.create_dataset('baz', shape=(10,))

Expand All @@ -554,10 +550,7 @@ def test_copy_array_exists_array(self):
with pytest.raises(ValueError):
copy(source['foo/bar/baz'], dest, if_exists='foobar')

def test_copy_array_exists_group(self):
source = self.source
dest = self.new_dest()

def test_copy_array_exists_group(self, source, dest):
# copy array, dest group in the way
dest.create_group('baz')

Expand All @@ -577,13 +570,13 @@ def test_copy_array_exists_group(self):
copy(source['foo/bar/baz'], dest, if_exists='replace')
check_copied_array(source['foo/bar/baz'], dest['baz'])

def test_copy_array_skip_initialized(self):
source = self.source
dest = self.new_dest()
def test_copy_array_skip_initialized(self, source, dest):
dest_h5py = dest.__module__.startswith('h5py.')

dest.create_dataset('baz', shape=(100,), chunks=(10,), dtype='i8')
assert not np.all(source['foo/bar/baz'][:] == dest['baz'][:])

if self.dest_h5py:
if dest_h5py:
with pytest.raises(ValueError):
# not available with copy to h5py
copy(source['foo/bar/baz'], dest, if_exists='skip_initialized')
Expand All @@ -599,55 +592,37 @@ def test_copy_array_skip_initialized(self):
assert_array_equal(np.arange(100, 200), dest['baz'][:])
assert not np.all(source['foo/bar/baz'][:] == dest['baz'][:])

def test_copy_group(self):
source = self.source
dest = self.new_dest()

def test_copy_group(self, source, dest):
# copy group, default options
copy(source['foo'], dest)
check_copied_group(source['foo'], dest['foo'])

def test_copy_group_no_name(self):
source = self.source
dest = self.new_dest()

def test_copy_group_no_name(self, source, dest):
with pytest.raises(TypeError):
# need a name if copy root
copy(source, dest)

copy(source, dest, name='root')
check_copied_group(source, dest['root'])

def test_copy_group_options(self):
source = self.source
dest = self.new_dest()

def test_copy_group_options(self, source, dest):
# copy group, non-default options
copy(source['foo'], dest, name='qux', without_attrs=True)
assert 'foo' not in dest
check_copied_group(source['foo'], dest['qux'], without_attrs=True)

def test_copy_group_shallow(self):
source = self.source
dest = self.new_dest()

def test_copy_group_shallow(self, source, dest):
# copy group, shallow
copy(source, dest, name='eggs', shallow=True)
check_copied_group(source, dest['eggs'], shallow=True)

def test_copy_group_exists_group(self):
source = self.source
dest = self.new_dest()

def test_copy_group_exists_group(self, source, dest):
# copy group, dest groups exist
dest.create_group('foo/bar')
copy(source['foo'], dest)
check_copied_group(source['foo'], dest['foo'])

def test_copy_group_exists_array(self):
source = self.source
dest = self.new_dest()

def test_copy_group_exists_array(self, source, dest):
# copy group, dest array in the way
dest.create_dataset('foo/bar', shape=(10,))

Expand All @@ -667,10 +642,7 @@ def test_copy_group_exists_array(self):
copy(source['foo'], dest, if_exists='replace')
check_copied_group(source['foo'], dest['foo'])

def test_copy_group_dry_run(self):
source = self.source
dest = self.new_dest()

def test_copy_group_dry_run(self, source, dest):
# dry run, empty destination
n_copied, n_skipped, n_bytes_copied = \
copy(source['foo'], dest, dry_run=True, return_stats=True)
Expand Down Expand Up @@ -710,67 +682,18 @@ def test_copy_group_dry_run(self):
assert 0 == n_bytes_copied
assert_array_equal(baz, dest['foo/bar/baz'])

def test_logging(self):
source = self.source
dest = self.new_dest()

def test_logging(self, source, dest, tmpdir):
# callable log
copy(source['foo'], dest, dry_run=True, log=print)

# file name
fn = tempfile.mktemp()
atexit.register(os.remove, fn)
fn = str(tmpdir.join('log_name'))
copy(source['foo'], dest, dry_run=True, log=fn)

# file
with tempfile.TemporaryFile(mode='w') as f:
with tmpdir.join('log_file').open(mode='w') as f:
copy(source['foo'], dest, dry_run=True, log=f)

# bad option
with pytest.raises(TypeError):
copy(source['foo'], dest, dry_run=True, log=True)


try:
import h5py
except ImportError: # pragma: no cover
h5py = None


def temp_h5f():
h5py = pytest.importorskip("h5py")
fn = tempfile.mktemp()
atexit.register(os.remove, fn)
h5f = h5py.File(fn, mode='w')
atexit.register(lambda v: v.close(), h5f)
return h5f


class TestCopyHDF5ToZarr(TestCopy):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.source_h5py = True
self.dest_h5py = False
self.new_source = temp_h5f
self.new_dest = group


class TestCopyZarrToHDF5(TestCopy):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.source_h5py = False
self.dest_h5py = True
self.new_source = group
self.new_dest = temp_h5f


class TestCopyHDF5ToHDF5(TestCopy):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.source_h5py = True
self.dest_h5py = True
self.new_source = temp_h5f
self.new_dest = temp_h5f