diff --git a/doc/changes/devel/12454.newfeature.rst b/doc/changes/devel/12454.newfeature.rst new file mode 100644 index 00000000000..5a4a9cc9cdb --- /dev/null +++ b/doc/changes/devel/12454.newfeature.rst @@ -0,0 +1 @@ +Completing PR 12453. Add option to pass ``image_kwargs`` per channel type to :class:`mne.Report.add_epochs`. \ No newline at end of file diff --git a/mne/report/report.py b/mne/report/report.py index 7e80047a32b..43c3d7c7ac4 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -1124,6 +1124,13 @@ def add_epochs( image_kwargs : dict | None Keyword arguments to pass to the "epochs image"-generating function (:meth:`mne.Epochs.plot_image`). + Keys are channel types, values are dicts containing kwargs to pass. + For example, to use the rejection limits per channel type you could pass:: + + image_kwargs=dict( + grad=dict(vmin=-reject['grad'], vmax=-reject['grad']), + mag=dict(vmin=-reject['mag'], vmax=reject['mag']), + ) .. versionadded:: 1.7 %(topomap_kwargs)s @@ -3888,15 +3895,16 @@ def _add_epochs( ch_types = _get_data_ch_types(epochs) epochs.load_data() - if image_kwargs is None: - image_kwargs = dict() + _validate_type(image_kwargs, (dict, None), "image_kwargs") + # ensure dict with shallow copy because we will modify it + image_kwargs = dict() if image_kwargs is None else image_kwargs.copy() for ch_type in ch_types: with use_log_level(_verbose_safe_false(level="error")): figs = ( epochs.copy() .pick(ch_type, verbose=False) - .plot_image(show=False, **image_kwargs) + .plot_image(show=False, **image_kwargs.pop(ch_type, dict())) ) assert len(figs) == 1 @@ -3920,6 +3928,12 @@ def _add_epochs( replace=replace, own_figure=True, ) + if image_kwargs: + raise ValueError( + f"Ensure the keys in image_kwargs map onto channel types plotted in " + f"epochs.plot_image() of {ch_types}, could not use: " + f"{list(image_kwargs)}" + ) # Drop log if epochs._bad_dropped: diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index 65d3ceb697a..7374868c559 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -936,8 +936,10 @@ def test_manual_report_2d(tmp_path, invisible_fig): tags=("epochs",), psd=False, projs=False, - image_kwargs=dict(colorbar=False), + image_kwargs=dict(mag=dict(colorbar=False)), ) + with pytest.raises(ValueError, match="map onto channel types"): + r.add_epochs(epochs=epochs_without_metadata, image_kwargs=dict(a=1), title="a") r.add_epochs( epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False )