Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,47 @@ def download_is_error(monkeypatch):
"""Prevent downloading by raising an error when it's attempted."""
import pooch
monkeypatch.setattr(pooch, 'retrieve', _fail)
yield


@pytest.fixture()
def fake_retrieve(monkeypatch, download_is_error):
"""Monkeypatch pooch.retrieve to avoid downloading (just touch files)."""
import pooch
my_func = _FakeFetch()
monkeypatch.setattr(pooch, 'retrieve', my_func)
monkeypatch.setattr(pooch, 'create', my_func)
yield my_func


class _FakeFetch:

def __init__(self):
self.call_args_list = list()

@property
def call_count(self):
return len(self.call_args_list)

# Wrapper for pooch.retrieve(...) and pooch.create(...)
def __call__(self, *args, **kwargs):
assert 'path' in kwargs
if 'fname' in kwargs: # pooch.retrieve(...)
self.call_args_list.append((args, kwargs))
path = Path(kwargs['path'], kwargs['fname'])
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text('test')
return path
else: # pooch.create(...) has been called
self.path = kwargs['path']
return self

# Wrappers for Pooch instances (e.g., in eegbci we pooch.create)
def fetch(self, fname):
self(path=self.path, fname=fname)

def load_registry(self, registry):
assert Path(registry).exists(), registry


# We can't use monkeypatch because its scope (function-level) conflicts with
Expand Down
5 changes: 3 additions & 2 deletions mne/datasets/eegbci/eegbci.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,18 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None,
for run in runs:
file_part = f'S{subject:03d}/S{subject:03d}R{run:02d}.edf'
destination = Path(base_path, file_part)
data_paths.append(destination)
if destination.exists():
if force_update:
destination.unlink()
else:
continue
if sz == 0: # log once
logger.info('Downloading EEGBCI data')
data_paths.append(fetcher.fetch(file_part))
fetcher.fetch(file_part)
# update path in config if desired
_do_path_update(path, update_path, config_key, name)
sz += destination.stat().st_size
_do_path_update(path, update_path, config_key, name)
if sz > 0:
_log_time_size(t0, sz)
return data_paths
Expand Down
14 changes: 14 additions & 0 deletions mne/datasets/eegbci/tests/test_eegbci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Authors: Eric Larson <larson.eric.d@gmail.com>
#
# License: BSD Style.

from mne.datasets import eegbci


def test_eegbci_download(tmp_path, fake_retrieve):
"""Test Sleep Physionet URL handling."""
for subj in range(4):
fnames = eegbci.load_data(
subj + 1, runs=[3], path=tmp_path, update_path=False)
assert len(fnames) == 1, subj
assert fake_retrieve.call_count == 4
33 changes: 4 additions & 29 deletions mne/datasets/sleep_physionet/tests/test_physionet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path
import pytest

import pooch

from mne.utils import requires_good_network
from mne.utils import requires_pandas, requires_version
Expand All @@ -23,20 +22,6 @@ def physionet_tmpdir(tmp_path_factory):
return str(tmp_path_factory.mktemp('physionet_files'))


class _FakeFetch:

def __init__(self):
self.call_args_list = list()

def __call__(self, *args, **kwargs):
self.call_args_list.append((args, kwargs))
Path(kwargs['path'], kwargs['fname']).write_text('test')

@property
def call_count(self):
return len(self.call_args_list)


def _keep_basename_only(paths):
return [Path(p).name for p in paths]

Expand Down Expand Up @@ -119,15 +104,8 @@ def test_sleep_physionet_age_missing_recordings(physionet_tmpdir, subject,
assert paths == []


def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error):
def test_sleep_physionet_age(physionet_tmpdir, fake_retrieve):
"""Test Sleep Physionet URL handling."""
# check download_is_error patching
with pytest.raises(AssertionError, match='Test should not download'):
age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir)
# then patch
my_func = _FakeFetch()
monkeypatch.setattr(pooch, 'retrieve', my_func)

paths = age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir)
assert _keep_basename_only(paths[0]) == \
['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf']
Expand Down Expand Up @@ -161,7 +139,7 @@ def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error):
'hash': '386230188a3552b1fc90bba0fb7476ceaca174b6'},
)
base_path = age.data_path(path=physionet_tmpdir)
_check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path)
_check_mocked_function_calls(fake_retrieve, EXPECTED_CALLS, base_path)


@pytest.mark.xfail(strict=False)
Expand All @@ -180,11 +158,8 @@ def test_run_update_temazepam_records(tmp_path):
data, pd.read_csv(TEMAZEPAM_SLEEP_RECORDS))


def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch):
def test_sleep_physionet_temazepam(physionet_tmpdir, fake_retrieve):
"""Test Sleep Physionet URL handling."""
my_func = _FakeFetch()
monkeypatch.setattr(pooch, 'retrieve', my_func)

paths = temazepam.fetch_data(subjects=[0], path=physionet_tmpdir)
assert _keep_basename_only(paths[0]) == \
['ST7011J0-PSG.edf', 'ST7011JP-Hypnogram.edf']
Expand All @@ -195,7 +170,7 @@ def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch):
{'name': 'ST7011JP-Hypnogram.edf',
'hash': 'ff28e5e01296cefed49ae0c27cfb3ebc42e710bf'})
base_path = temazepam.data_path(path=physionet_tmpdir)
_check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path)
_check_mocked_function_calls(fake_retrieve, EXPECTED_CALLS, base_path)

with pytest.raises(
ValueError, match='This dataset contains subjects 0 to 21'):
Expand Down
1 change: 1 addition & 0 deletions tutorials/preprocessing/40_artifact_correction_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
sample_data_raw_file = os.path.join(sample_data_folder, 'MEG', 'sample',
'sample_audvis_filt-0-40_raw.fif')
raw = mne.io.read_raw_fif(sample_data_raw_file)

# Here we'll crop to 60 seconds and drop gradiometer channels for speed
raw.crop(tmax=60.).pick_types(meg='mag', eeg=True, stim=True, eog=True)
raw.load_data()
Expand Down