diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 90c55853fe7..400fc1155a1 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -191,6 +191,10 @@ Enhancements - Add :func:`mne.preprocessing.ieeg.project_sensors_onto_brain` to project ECoG sensors onto the pial surface to compensate for brain shift (:gh:`9800` by `Alex Rockhill`_) +- All functions for reading and writing files should now automatically handle ``~`` (the tilde character) and expand it to the user's home directory. Should you come across any function that doesn't do it, please do let us know! (:gh:`9613` by `Richard Höchenberger`_) + +- All functions accepting a FreeSurfer subjects directory via a ``subjects_dir`` parameter can now consume :py:class:`pathlib.Path` objects too (used to be only strings) (:gh:`9613` by `Richard Höchenberger`_) + Bugs ~~~~ - Fix bug in :meth:`mne.io.Raw.pick` and related functions when parameter list contains channels which are not in info instance (:gh:`9708` **by new contributor** |Evgeny Goldstein|_) diff --git a/mne/cov.py b/mne/cov.py index 22c384076b0..893cf70a1fc 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -38,7 +38,7 @@ warn, copy_function_doc_to_method_doc, _pl, _undo_scaling_cov, _scaled_array, _validate_type, _check_option, eigh, fill_doc, _on_missing, - _check_on_missing) + _check_on_missing, _check_fname) from . import viz from .fixes import (BaseEstimator, EmpiricalCovariance, _logdet, @@ -153,7 +153,8 @@ def save(self, fname): """ check_fname(fname, 'covariance', ('-cov.fif', '-cov.fif.gz', '_cov.fif', '_cov.fif.gz')) - + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) fid = start_file(fname) try: @@ -384,6 +385,7 @@ def read_cov(fname, verbose=None): """ check_fname(fname, 'covariance', ('-cov.fif', '-cov.fif.gz', '_cov.fif', '_cov.fif.gz')) + fname = _check_fname(fname=fname, must_exist=True, overwrite='read') f, tree = fiff_open(fname)[:2] with f as fid: return Covariance(**_read_cov(fid, tree, FIFF.FIFFV_MNE_NOISE_COV, diff --git a/mne/epochs.py b/mne/epochs.py index 58ae9e4c77c..d7099d8c3e1 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -59,7 +59,8 @@ _check_combine, ShiftTimeMixin, _build_data_frame, _check_pandas_index_arguments, _convert_times, _scale_dataframe_data, _check_time_format, object_size, - _on_missing, _validate_type, _ensure_events) + _on_missing, _validate_type, _ensure_events, + _path_like) from .utils.docs import fill_doc from .data.html_templates import epochs_template @@ -1738,8 +1739,8 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, Parameters ---------- fname : str - The name of the file, which should end with -epo.fif or - -epo.fif.gz. + The name of the file, which should end with ``-epo.fif`` or + ``-epo.fif.gz``. split_size : str | int Large raw files are automatically split into multiple pieces. This parameter specifies the maximum size of each piece. If the @@ -1771,8 +1772,8 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, check_fname(fname, 'epochs', ('-epo.fif', '-epo.fif.gz', '_epo.fif', '_epo.fif.gz')) - # check for file existence - _check_fname(fname, overwrite) + # check for file existence and expand `~` if present + fname = _check_fname(fname=fname, overwrite=overwrite) split_size_bytes = _get_split_size(split_size) @@ -3059,14 +3060,11 @@ def read_epochs(fname, proj=True, preload=True, verbose=None): Parameters ---------- - fname : str | file-like - The epochs filename to load. Filename should end with -epo.fif or - -epo.fif.gz. If a file-like object is provided, preloading must be - used. + %(epochs_fname)s %(proj_epochs)s preload : bool - If True, read all epochs from disk immediately. If False, epochs will - be read on demand. + If True, read all epochs from disk immediately. If ``False``, epochs + will be read on demand. %(verbose)s Returns @@ -3100,9 +3098,7 @@ class EpochsFIF(BaseEpochs): Parameters ---------- - fname : str | file-like - The name of the file, which should end with -epo.fif or -epo.fif.gz. If - a file-like object is provided, preloading must be used. + %(epochs_fname)s %(proj_epochs)s preload : bool If True, read all epochs from disk immediately. If False, epochs will @@ -3119,9 +3115,13 @@ class EpochsFIF(BaseEpochs): @verbose def __init__(self, fname, proj=True, preload=True, verbose=None): # noqa: D102 - if isinstance(fname, str): - check_fname(fname, 'epochs', ('-epo.fif', '-epo.fif.gz', - '_epo.fif', '_epo.fif.gz')) + if _path_like(fname): + check_fname( + fname=fname, filetype='epochs', + endings=('-epo.fif', '-epo.fif.gz', '_epo.fif', '_epo.fif.gz') + ) + fname = _check_fname(fname=fname, must_exist=True, + overwrite='read') elif not preload: raise ValueError('preload must be used with file-like objects') diff --git a/mne/evoked.py b/mne/evoked.py index 0179f8e09c6..def5bcd5d80 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -22,7 +22,7 @@ fill_doc, _check_option, ShiftTimeMixin, _build_data_frame, _check_pandas_installed, _check_pandas_index_arguments, _convert_times, _scale_dataframe_data, _check_time_format, - _check_preload) + _check_preload, _check_fname) from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field, plot_evoked_image, plot_evoked_topo) from .viz.evoked import plot_evoked_white, plot_evoked_joint @@ -127,6 +127,7 @@ def __init__(self, fname, condition=None, proj=True, verbose=None): # noqa: D102 _validate_type(proj, bool, "'proj'") # Read the requested data + fname = _check_fname(fname=fname, must_exist=True, overwrite='read') self.info, self.nave, self._aspect_kind, self.comment, self.times, \ self.data, self.baseline = _read_evoked(fname, condition, kind, allow_maxshield) @@ -278,7 +279,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): return self def save(self, fname): - """Save dataset to file. + """Save evoked data to a file. Parameters ---------- @@ -295,6 +296,8 @@ def save(self, fname): Information on baseline correction will be stored with the data, and will be restored when reading again via `mne.read_evokeds`. """ + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) write_evokeds(fname, self) def __repr__(self): # noqa: D105 @@ -1398,6 +1401,8 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise'): if check: check_fname(fname, 'evoked', ('-ave.fif', '-ave.fif.gz', '_ave.fif', '_ave.fif.gz')) + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) if not isinstance(evoked, (list, tuple)): evoked = [evoked] diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 96907f5a685..24d0a7d731c 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -437,7 +437,7 @@ def read_forward_solution(fname, include=(), exclude=(), verbose=None): """ check_fname(fname, 'forward', ('-fwd.fif', '-fwd.fif.gz', '_fwd.fif', '_fwd.fif.gz')) - + fname = _check_fname(fname=fname, must_exist=True, overwrite='read') # Open the file, create directory logger.info('Reading forward solution from %s...' % fname) f, tree, _ = fiff_open(fname) @@ -718,8 +718,8 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None): Parameters ---------- fname : str - File name to save the forward solution to. It should end with -fwd.fif - or -fwd.fif.gz. + File name to save the forward solution to. It should end with + ``-fwd.fif`` or ``-fwd.fif.gz``. fwd : Forward Forward solution. %(overwrite)s @@ -747,8 +747,8 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None): check_fname(fname, 'forward', ('-fwd.fif', '-fwd.fif.gz', '_fwd.fif', '_fwd.fif.gz')) - # check for file existence - _check_fname(fname, overwrite) + # check for file existence and expand `~` if present + fname = _check_fname(fname, overwrite) fid = start_file(fname) start_block(fid, FIFF.FIFFB_MNE) diff --git a/mne/io/base.py b/mne/io/base.py index 06745464470..7b918713566 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1416,7 +1416,6 @@ def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=None, or all forms of SSS). It is recommended not to concatenate and then save raw files for this reason. """ - fname = op.abspath(fname) endings = ('raw.fif', 'raw_sss.fif', 'raw_tsss.fif', '_meg.fif', '_eeg.fif', '_ieeg.fif') endings += tuple([f'{e}.gz' for e in endings]) @@ -1447,8 +1446,8 @@ def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=None, raise ValueError('Complex data must be saved as "single" or ' '"double", not "short"') - # check for file existence - _check_fname(fname, overwrite) + # check for file existence and expand `~` if present + fname = _check_fname(fname=fname, overwrite=overwrite) if proj: info = deepcopy(self.info) @@ -2148,6 +2147,9 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start, raise RuntimeError('Cannot write raw file with no data: %s -> %s ' '(max: %s) requested' % (start, stop, n_times_max)) + # Expand `~` if present + fname = _check_fname(fname=fname, overwrite=overwrite) + base, ext = op.splitext(fname) if part_idx > 0: if split_naming == 'neuromag': @@ -2181,7 +2183,9 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start, raw, info, picks, fid, cals, part_idx, start, stop, buffer_size, prev_fname, split_size, use_fname, projector, drop_small_buffer, fmt, fname, reserved_fname, - data_type, reset_range, split_naming, overwrite) + data_type, reset_range, split_naming, + overwrite=True # we've started writing already above + ) if final_fname != use_fname: assert split_naming == 'bids' logger.info(f'Renaming BIDS split file {op.basename(final_fname)}') diff --git a/mne/io/eeglab/eeglab.py b/mne/io/eeglab/eeglab.py index ea86f7a009f..01fe901f38a 100644 --- a/mne/io/eeglab/eeglab.py +++ b/mne/io/eeglab/eeglab.py @@ -446,6 +446,8 @@ def __init__(self, input_fname, events=None, event_id=None, tmin=0, baseline=None, reject=None, flat=None, reject_tmin=None, reject_tmax=None, eog=(), verbose=None, uint16_codec=None): # noqa: D102 + input_fname = _check_fname(fname=input_fname, must_exist=True, + overwrite='read') eeg = _check_load_mat(input_fname, uint16_codec) if not ((events is None and event_id is None) or @@ -507,7 +509,6 @@ def __init__(self, input_fname, events=None, event_id=None, tmin=0, events = read_events(events) logger.info('Extracting parameters from %s...' % input_fname) - input_fname = op.abspath(input_fname) info, eeg_montage, _ = _get_info(eeg, eog=eog) for key, val in event_id.items(): diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 98199016bf9..31a6c28195d 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -5,6 +5,7 @@ # License: BSD-3-Clause from copy import deepcopy +from pathlib import Path from functools import partial from io import BytesIO import os @@ -1786,3 +1787,22 @@ def test_corrupted(tmpdir): with pytest.warns(RuntimeWarning, match='.*tag directory.*corrupt.*'): raw_bad = read_raw_fif(bad_fname) assert_allclose(raw.get_data(), raw_bad.get_data()) + + +@testing.requires_testing_data +def test_expand_user(tmp_path, monkeypatch): + """Test that we're expanding `~` before reading and writing.""" + monkeypatch.setenv('HOME', str(tmp_path)) + monkeypatch.setenv('USERPROFILE', str(tmp_path)) # Windows + + path_in = Path(fif_fname) + path_out = tmp_path / path_in.name + path_home = Path('~') / path_in.name + + shutil.copyfile( + src=path_in, + dst=path_out + ) + + raw = read_raw_fif(fname=path_home, preload=True) + raw.save(fname=path_home, overwrite=True) diff --git a/mne/io/kit/kit.py b/mne/io/kit/kit.py index 16dd04be4f4..d962cd5af21 100644 --- a/mne/io/kit/kit.py +++ b/mne/io/kit/kit.py @@ -17,7 +17,7 @@ from ..pick import pick_types from ...utils import (verbose, logger, warn, fill_doc, _check_option, - _stamp_to_dt) + _stamp_to_dt, _check_fname) from ...transforms import apply_trans, als_ras_trans from ..base import BaseRaw from ..utils import _mult_cal_one @@ -386,8 +386,9 @@ def __init__(self, input_fname, events, event_id=None, tmin=0, if isinstance(events, str): events = read_events(events) + input_fname = _check_fname(fname=input_fname, must_exist=True, + overwrite='read') logger.info('Extracting KIT Parameters from %s...' % input_fname) - input_fname = op.abspath(input_fname) self.info, kit_info = get_kit_info( input_fname, allow_unknown_format, standardize_names) kit_info.update(filename=input_fname) diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index 8d14288dfbb..853d0f37c17 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -41,9 +41,10 @@ from ..source_estimate import _make_stc, _get_src_type from ..utils import (check_fname, logger, verbose, warn, _validate_type, _check_compensation_grade, _check_option, - _check_depth, _check_src_normal) + _check_depth, _check_src_normal, _check_fname) from ..data.html_templates import inverse_operator_template + INVERSE_METHODS = ('MNE', 'dSPM', 'sLORETA', 'eLORETA') @@ -132,7 +133,7 @@ def read_inverse_operator(fname, verbose=None): """ check_fname(fname, 'inverse operator', ('-inv.fif', '-inv.fif.gz', '_inv.fif', '_inv.fif.gz')) - + fname = _check_fname(fname=fname, must_exist=True, overwrite='read') # # Open the file, create directory # @@ -350,6 +351,8 @@ def write_inverse_operator(fname, inv, verbose=None): """ check_fname(fname, 'inverse operator', ('-inv.fif', '-inv.fif.gz', '_inv.fif', '_inv.fif.gz')) + fname = _check_fname(fname=fname, overwrite=True) + _validate_type(inv, InverseOperator, 'inv') # diff --git a/mne/report/report.py b/mne/report/report.py index 37e88fc68e0..52aab770f35 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -37,7 +37,7 @@ from .._freesurfer import _reorient_image, _mri_orientation from ..utils import (logger, verbose, get_subjects_dir, warn, _ensure_int, fill_doc, _check_option, _validate_type, _safe_input, - _check_path_like, use_log_level, deprecated) + _path_like, use_log_level, deprecated) from ..viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap, plot_compare_evokeds, set_3d_view, get_3d_backend) from ..viz.misc import _plot_mri_contours, _get_bem_plotting_surfaces @@ -1498,7 +1498,7 @@ def add_figure(self, fig, title, *, caption=None, image_format=None, figs = (fig,) for fig in figs: - if _check_path_like(fig): + if _path_like(fig): raise TypeError( f'It seems you passed a path to `add_figure`. However, ' f'only Matplotlib figures, Mayavi scenes, and NumPy ' @@ -1582,9 +1582,9 @@ def add_figs_to_section(self, figs, captions, section='custom', _check_scale(scale) if ( - _check_path_like(figs) or + _path_like(figs) or (hasattr(figs, '__iter__') and - any(_check_path_like(f) for f in figs)) + any(_path_like(f) for f in figs)) ): raise TypeError( 'It seems you passed a path to `add_figs_to_section`. ' diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index bb8feacc858..43d2da70e9b 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -466,8 +466,9 @@ def test_add_html(): assert (repr(report)) +@testing.requires_testing_data def test_multiple_figs(tmpdir): - """Test adding a slider with a series of figures to mne report.""" + """Test adding a slider with a series of figures to a Report.""" tempdir = str(tmpdir) report = Report(info_fname=raw_fname, subject='sample', subjects_dir=subjects_dir) diff --git a/mne/source_estimate.py b/mne/source_estimate.py index e0db73b1e3f..b3cd8cbc8d7 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -31,8 +31,8 @@ fill_doc, _check_option, _validate_type, _check_src_normal, _check_stc_units, _check_pandas_installed, _check_pandas_index_arguments, _convert_times, _ensure_int, - _build_data_frame, _check_time_format, _check_path_like, - sizeof_fmt, object_size) + _build_data_frame, _check_time_format, _path_like, + sizeof_fmt, object_size, _check_fname) from .viz import (plot_source_estimates, plot_vector_source_estimates, plot_volume_source_estimates) from .io.base import TimeMixin @@ -250,7 +250,11 @@ def read_source_estimate(fname, subject=None): """ # noqa: E501 fname_arg = fname _validate_type(fname, 'path-like', 'fname') - fname = str(fname) + + # expand `~` without checking whether the file actually exists – we'll + # take care of that later, as it's complicated by the different suffixes + # STC files can have + fname = _check_fname(fname=fname, overwrite='read', must_exist=False) # make sure corresponding file(s) can be found ftype = None @@ -622,7 +626,8 @@ def save(self, fname, ftype='h5', verbose=None): %(verbose_meth)s """ _validate_type(fname, 'path-like', 'fname') - fname = str(fname) + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) if ftype != 'h5': raise ValueError('%s objects can only be written as HDF5 files.' % (self.__class__.__name__,)) @@ -632,7 +637,9 @@ def save(self, fname, ftype='h5', verbose=None): dict(vertices=self.vertices, data=self.data, tmin=self.tmin, tstep=self.tstep, subject=self.subject, src_type=self._src_type), - title='mnepython', overwrite=True) + title='mnepython', + # TODO: Add `overwrite` param to method signature + overwrite=True) @copy_function_doc_to_method_doc(plot_source_estimates) def plot(self, subject=None, surface='inflated', hemi='lh', @@ -1595,7 +1602,8 @@ def save(self, fname, ftype='stc', verbose=None): %(verbose_meth)s """ _validate_type(fname, 'path-like', 'fname') - fname = str(fname) + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) _check_option('ftype', ftype, ['stc', 'w', 'h5']) lh_data = self.data[:len(self.lh_vertno)] @@ -2084,7 +2092,8 @@ def save_as_volume(self, fname, src, dest='mri', mri_resolution=False, """ import nibabel as nib _validate_type(fname, 'path-like', 'fname') - fname = str(fname) + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) img = self.as_volume(src, dest=dest, mri_resolution=mri_resolution, format=format) nib.save(img, fname) @@ -2187,7 +2196,8 @@ def save(self, fname, ftype='stc', verbose=None): %(verbose_meth)s """ _validate_type(fname, 'path-like', 'fname') - fname = str(fname) + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) _check_option('ftype', ftype, ['stc', 'w', 'h5']) if ftype != 'h5' and len(self.vertices) != 1: raise ValueError('Can only write to .stc or .w if a single volume ' @@ -2987,7 +2997,7 @@ def _volume_labels(src, labels, mri_resolution): extra = ' when using a volume source space' _import_nibabel('use volume atlas labels') _validate_type(labels, ('path-like', list, tuple), 'labels' + extra) - if _check_path_like(labels): + if _path_like(labels): mri = labels infer_labels = True else: diff --git a/mne/source_space.py b/mne/source_space.py index e8dea0b0857..5cda384828a 100644 --- a/mne/source_space.py +++ b/mne/source_space.py @@ -35,7 +35,7 @@ read_freesurfer_lut, get_mni_fiducials, _check_mri) from .utils import (get_subjects_dir, check_fname, logger, verbose, fill_doc, _ensure_int, check_version, _get_call_line, warn, - _check_fname, _check_path_like, _check_sphere, + _check_fname, _path_like, _check_sphere, _validate_type, _check_option, _is_numeric, _pl, _suggest, object_size, sizeof_fmt) from .parallel import parallel_func, check_n_jobs @@ -2187,7 +2187,7 @@ def _ensure_src(src, kind=None, extra='', verbose=None): _check_option( 'kind', kind, (None, 'surface', 'volume', 'mixed', 'discrete')) msg = 'src must be a string or instance of SourceSpaces%s' % (extra,) - if _check_path_like(src): + if _path_like(src): src = str(src) if not op.isfile(src): raise IOError('Source space file "%s" not found' % src) diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index 1bba4f8ea58..765dbb26198 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -459,6 +459,7 @@ def save(self, fname): if not fname.endswith('.h5'): fname += '.h5' + # TODO: Add `overwrite` param to method signature write_hdf5(fname, self.__getstate__(), overwrite=True, title='conpy') def copy(self): diff --git a/mne/transforms.py b/mne/transforms.py index bc7e50b896f..7f086219738 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -20,7 +20,7 @@ from .io.write import start_file, end_file, write_coord_trans from .defaults import _handle_default from .utils import (check_fname, logger, verbose, _ensure_int, _validate_type, - _check_path_like, get_subjects_dir, fill_doc, _check_fname, + _path_like, get_subjects_dir, fill_doc, _check_fname, _check_option, _require_version, wrapped_stdout) @@ -450,7 +450,7 @@ def _get_trans(trans, fro='mri', to='head', allow_none=True): if allow_none: types += (None,) _validate_type(trans, types, 'trans') - if _check_path_like(trans): + if _path_like(trans): trans = str(trans) if trans == 'fsaverage': trans = op.join(op.dirname(__file__), 'data', 'fsaverage', @@ -574,6 +574,8 @@ def write_trans(fname, trans): """ check_fname(fname, 'trans', ('-trans.fif', '-trans.fif.gz', '_trans.fif', '_trans.fif.gz')) + # TODO: Add `overwrite` param to method signature + fname = _check_fname(fname=fname, overwrite=True) fid = start_file(fname) write_coord_trans(fid, trans) end_file(fid) diff --git a/mne/utils/__init__.py b/mne/utils/__init__.py index 4afb1d2e825..b12af51abf8 100644 --- a/mne/utils/__init__.py +++ b/mne/utils/__init__.py @@ -14,7 +14,7 @@ _validate_type, _check_info_inv, _check_channels_spatial_filter, _check_one_ch_type, _check_rank, _check_option, _check_depth, _check_combine, - _check_path_like, _check_src_normal, _check_stc_units, + _path_like, _check_src_normal, _check_stc_units, _check_pyqt5_version, _check_sphere, _check_time_format, _check_freesurfer_home, _suggest, _require_version, _on_missing, _check_on_missing, int_like, _safe_input, diff --git a/mne/utils/check.py b/mne/utils/check.py index db8c5310977..45681a527c4 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -156,6 +156,12 @@ def _check_fname(fname, overwrite=False, must_exist=False, name='File', need_dir=False): """Check for file existence, and return string of its absolute path.""" _validate_type(fname, 'path-like', name) + fname = str( + Path(fname) + .expanduser() + .absolute() + ) + if op.exists(fname): if not overwrite: raise FileExistsError('Destination file exists. Please use option ' @@ -178,7 +184,8 @@ def _check_fname(fname, overwrite=False, must_exist=False, name='File', f'{name} does not have read permissions: {fname}') elif must_exist: raise FileNotFoundError(f'{name} does not exist: {fname}') - return str(op.abspath(fname)) + + return fname def _check_subject(first, second, *, raise_error=True, @@ -420,7 +427,7 @@ def _validate_type(item, types=None, item_name=None, type_name=None): f"got {type(item)} instead.") -def _check_path_like(item): +def _path_like(item): """Validate that `item` is `path-like`. Parameters diff --git a/mne/utils/config.py b/mne/utils/config.py index 59811d89241..6a21b0f33a6 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -17,7 +17,8 @@ import numpy as np -from .check import _validate_type, _check_pyqt5_version, _check_option +from .check import (_validate_type, _check_pyqt5_version, _check_option, + _check_fname) from .docs import fill_doc from ._logging import warn, logger @@ -366,7 +367,11 @@ def get_subjects_dir(subjects_dir=None, raise_error=False): if subjects_dir is None: subjects_dir = get_config('SUBJECTS_DIR', raise_error=raise_error) if subjects_dir is not None: - subjects_dir = str(subjects_dir) + subjects_dir = _check_fname( + fname=subjects_dir, overwrite='read', must_exist=True, + need_dir=True, name='subjects_dir' + ) + return subjects_dir diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 9d998f80212..7c597be9723 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -268,6 +268,11 @@ If proj is False no projections will be applied which is the recommended value if SSPs are not used for cleaning the data. """ +docdict['epochs_fname'] = """ +fname : path-like | file-like + The epochs to load. If a filename, should end with ``-epo.fif`` or + ``-epo.fif.gz``. If a file-like object, preloading must be used. +""" # Reject by annotation docdict['reject_by_annotation_all'] = """ diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index d55da37794c..1b6bd956b54 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -18,7 +18,7 @@ from mne.utils import (check_random_state, _check_fname, check_fname, _check_subject, requires_mayavi, traits_test, _check_mayavi_version, _check_info_inv, _check_option, - check_version, _check_path_like, _validate_type, + check_version, _path_like, _validate_type, _suggest, _on_missing, requires_nibabel, _safe_input) data_path = testing.data_path(download=False) @@ -176,15 +176,15 @@ def test_check_option(): assert _check_option('option', 'bad', ['valid']) -def test_check_path_like(): - """Test _check_path_like().""" +def test_path_like(): + """Test _path_like().""" str_path = str(base_dir) pathlib_path = Path(base_dir) no_path = dict(foo='bar') - assert _check_path_like(str_path) is True - assert _check_path_like(pathlib_path) is True - assert _check_path_like(no_path) is False + assert _path_like(str_path) is True + assert _path_like(pathlib_path) is True + assert _path_like(no_path) is False def test_validate_type(): diff --git a/mne/utils/tests/test_config.py b/mne/utils/tests/test_config.py index f381119daee..17db9fc5ee7 100644 --- a/mne/utils/tests/test_config.py +++ b/mne/utils/tests/test_config.py @@ -89,17 +89,23 @@ def test_sys_info(): assert 'Platform: macOS-' in out -def test_get_subjects_dir(monkeypatch, tmpdir): +def test_get_subjects_dir(tmp_path, monkeypatch): """Test get_subjects_dir().""" + subjects_dir = tmp_path / 'foo' + subjects_dir.mkdir() + # String - subjects_dir = '/foo' - assert get_subjects_dir(subjects_dir) == subjects_dir + assert get_subjects_dir(str(subjects_dir)) == str(subjects_dir) # Path - subjects_dir = Path('/foo') assert get_subjects_dir(subjects_dir) == str(subjects_dir) # `None` - monkeypatch.setenv('_MNE_FAKE_HOME_DIR', str(tmpdir)) + monkeypatch.setenv('_MNE_FAKE_HOME_DIR', str(tmp_path)) monkeypatch.delenv('SUBJECTS_DIR', raising=False) assert get_subjects_dir() is None + + # Expand `~` + monkeypatch.setenv('HOME', str(tmp_path)) + monkeypatch.setenv('USERPROFILE', str(tmp_path)) # Windows + assert get_subjects_dir('~/foo') == str(subjects_dir) diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index d7448e3088b..d1f9ddde934 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -81,7 +81,7 @@ def test_plot_evoked_cov(): evoked.plot(noise_cov=cov, time_unit='s') with pytest.raises(TypeError, match='Covariance'): evoked.plot(noise_cov=1., time_unit='s') - with pytest.raises(IOError, match='No such file'): + with pytest.raises(FileNotFoundError, match='File does not exist'): evoked.plot(noise_cov='nonexistent-cov.fif', time_unit='s') raw = read_raw_fif(raw_sss_fname) events = make_fixed_length_events(raw)