From 030aa6aba1842e1792a8a17771475a8ed9825887 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 19 Apr 2023 10:48:22 -0400 Subject: [PATCH 1/2] BUG: Fix bug with paths --- mne/datasets/eegbci/eegbci.py | 5 +++-- tutorials/preprocessing/40_artifact_correction_ica.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index e89ae089fcc..fd2b0a71e24 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -203,6 +203,7 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, for run in runs: file_part = f'S{subject:03d}/S{subject:03d}R{run:02d}.edf' destination = Path(base_path, file_part) + data_paths.append(destination) if destination.exists(): if force_update: destination.unlink() @@ -210,10 +211,10 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, continue if sz == 0: # log once logger.info('Downloading EEGBCI data') - data_paths.append(fetcher.fetch(file_part)) + fetcher.fetch(file_part) # update path in config if desired - _do_path_update(path, update_path, config_key, name) sz += destination.stat().st_size + _do_path_update(path, update_path, config_key, name) if sz > 0: _log_time_size(t0, sz) return data_paths diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 51e353dcae7..b4aae956300 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -29,6 +29,7 @@ sample_data_raw_file = os.path.join(sample_data_folder, 'MEG', 'sample', 'sample_audvis_filt-0-40_raw.fif') raw = mne.io.read_raw_fif(sample_data_raw_file) + # Here we'll crop to 60 seconds and drop gradiometer channels for speed raw.crop(tmax=60.).pick_types(meg='mag', eeg=True, stim=True, eog=True) raw.load_data() From a86d84001c7ff5b55099c11600c5befd9bbbbc31 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 19 Apr 2023 13:49:23 -0400 Subject: [PATCH 2/2] TST: Add test --- mne/conftest.py | 41 +++++++++++++++++++ mne/datasets/eegbci/tests/test_eegbci.py | 14 +++++++ .../sleep_physionet/tests/test_physionet.py | 33 ++------------- 3 files changed, 59 insertions(+), 29 deletions(-) create mode 100644 mne/datasets/eegbci/tests/test_eegbci.py diff --git a/mne/conftest.py b/mne/conftest.py index 05d433cacfc..a4b261db704 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -713,6 +713,47 @@ def download_is_error(monkeypatch): """Prevent downloading by raising an error when it's attempted.""" import pooch monkeypatch.setattr(pooch, 'retrieve', _fail) + yield + + +@pytest.fixture() +def fake_retrieve(monkeypatch, download_is_error): + """Monkeypatch pooch.retrieve to avoid downloading (just touch files).""" + import pooch + my_func = _FakeFetch() + monkeypatch.setattr(pooch, 'retrieve', my_func) + monkeypatch.setattr(pooch, 'create', my_func) + yield my_func + + +class _FakeFetch: + + def __init__(self): + self.call_args_list = list() + + @property + def call_count(self): + return len(self.call_args_list) + + # Wrapper for pooch.retrieve(...) and pooch.create(...) + def __call__(self, *args, **kwargs): + assert 'path' in kwargs + if 'fname' in kwargs: # pooch.retrieve(...) + self.call_args_list.append((args, kwargs)) + path = Path(kwargs['path'], kwargs['fname']) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text('test') + return path + else: # pooch.create(...) has been called + self.path = kwargs['path'] + return self + + # Wrappers for Pooch instances (e.g., in eegbci we pooch.create) + def fetch(self, fname): + self(path=self.path, fname=fname) + + def load_registry(self, registry): + assert Path(registry).exists(), registry # We can't use monkeypatch because its scope (function-level) conflicts with diff --git a/mne/datasets/eegbci/tests/test_eegbci.py b/mne/datasets/eegbci/tests/test_eegbci.py new file mode 100644 index 00000000000..e60988ff36c --- /dev/null +++ b/mne/datasets/eegbci/tests/test_eegbci.py @@ -0,0 +1,14 @@ +# Authors: Eric Larson +# +# License: BSD Style. + +from mne.datasets import eegbci + + +def test_eegbci_download(tmp_path, fake_retrieve): + """Test Sleep Physionet URL handling.""" + for subj in range(4): + fnames = eegbci.load_data( + subj + 1, runs=[3], path=tmp_path, update_path=False) + assert len(fnames) == 1, subj + assert fake_retrieve.call_count == 4 diff --git a/mne/datasets/sleep_physionet/tests/test_physionet.py b/mne/datasets/sleep_physionet/tests/test_physionet.py index 549963cb73f..ad400505d73 100644 --- a/mne/datasets/sleep_physionet/tests/test_physionet.py +++ b/mne/datasets/sleep_physionet/tests/test_physionet.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -import pooch from mne.utils import requires_good_network from mne.utils import requires_pandas, requires_version @@ -23,20 +22,6 @@ def physionet_tmpdir(tmp_path_factory): return str(tmp_path_factory.mktemp('physionet_files')) -class _FakeFetch: - - def __init__(self): - self.call_args_list = list() - - def __call__(self, *args, **kwargs): - self.call_args_list.append((args, kwargs)) - Path(kwargs['path'], kwargs['fname']).write_text('test') - - @property - def call_count(self): - return len(self.call_args_list) - - def _keep_basename_only(paths): return [Path(p).name for p in paths] @@ -119,15 +104,8 @@ def test_sleep_physionet_age_missing_recordings(physionet_tmpdir, subject, assert paths == [] -def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error): +def test_sleep_physionet_age(physionet_tmpdir, fake_retrieve): """Test Sleep Physionet URL handling.""" - # check download_is_error patching - with pytest.raises(AssertionError, match='Test should not download'): - age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir) - # then patch - my_func = _FakeFetch() - monkeypatch.setattr(pooch, 'retrieve', my_func) - paths = age.fetch_data(subjects=[0], recording=[1], path=physionet_tmpdir) assert _keep_basename_only(paths[0]) == \ ['SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf'] @@ -161,7 +139,7 @@ def test_sleep_physionet_age(physionet_tmpdir, monkeypatch, download_is_error): 'hash': '386230188a3552b1fc90bba0fb7476ceaca174b6'}, ) base_path = age.data_path(path=physionet_tmpdir) - _check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path) + _check_mocked_function_calls(fake_retrieve, EXPECTED_CALLS, base_path) @pytest.mark.xfail(strict=False) @@ -180,11 +158,8 @@ def test_run_update_temazepam_records(tmp_path): data, pd.read_csv(TEMAZEPAM_SLEEP_RECORDS)) -def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch): +def test_sleep_physionet_temazepam(physionet_tmpdir, fake_retrieve): """Test Sleep Physionet URL handling.""" - my_func = _FakeFetch() - monkeypatch.setattr(pooch, 'retrieve', my_func) - paths = temazepam.fetch_data(subjects=[0], path=physionet_tmpdir) assert _keep_basename_only(paths[0]) == \ ['ST7011J0-PSG.edf', 'ST7011JP-Hypnogram.edf'] @@ -195,7 +170,7 @@ def test_sleep_physionet_temazepam(physionet_tmpdir, monkeypatch): {'name': 'ST7011JP-Hypnogram.edf', 'hash': 'ff28e5e01296cefed49ae0c27cfb3ebc42e710bf'}) base_path = temazepam.data_path(path=physionet_tmpdir) - _check_mocked_function_calls(my_func, EXPECTED_CALLS, base_path) + _check_mocked_function_calls(fake_retrieve, EXPECTED_CALLS, base_path) with pytest.raises( ValueError, match='This dataset contains subjects 0 to 21'):