diff --git a/doc/source/whats_new.rst b/doc/source/whats_new.rst index 1a6d0aad658..87496676443 100644 --- a/doc/source/whats_new.rst +++ b/doc/source/whats_new.rst @@ -6,6 +6,7 @@ Current Changelog ~~~~~~~~~ + - Generation of simulated evoked responses by `Alex Gramfort`_, `Daniel Strohmeier`_, and `Martin Luessi`_ - Fit AR models to raw data for temporal whitening by `Alex Gramfort`_. diff --git a/mne/simulation/source.py b/mne/simulation/source.py index bddb5ed8432..8fec4be1ef2 100644 --- a/mne/simulation/source.py +++ b/mne/simulation/source.py @@ -45,7 +45,7 @@ def select_source_in_label(src, label, random_state=None): return lh_vertno, rh_vertno -def generate_sparse_stc(src, labels, stc_data, tmin, tstep, random_state=0): +def generate_sparse_stc(src, labels, stc_data, tmin, tstep, random_state=None): """Generate sparse sources time courses from waveforms and labels This function randomly selects a single vertex in each label and assigns @@ -83,14 +83,20 @@ def generate_sparse_stc(src, labels, stc_data, tmin, tstep, random_state=0): vertno[0] += lh_vertno vertno[1] += rh_vertno if len(lh_vertno) != 0: - lh_data.append(label_data) + lh_data.append(np.atleast_2d(label_data)) elif len(rh_vertno) != 0: - rh_data.append(label_data) + rh_data.append(np.atleast_2d(label_data)) else: raise ValueError('No vertno found.') + vertno = map(np.array, vertno) - data = np.r_[lh_data, rh_data] + + # the data is in the order left, right + lh_data.extend(rh_data) + data = np.concatenate(lh_data) + stc = _make_stc(data, tmin, tstep, vertno) + return stc @@ -137,8 +143,7 @@ def generate_stc(src, labels, stc_data, tmin, tstep, value_fun=None): vertno = [[], []] stc_data_extended = [[], []] - hemi_to_ind = {} - hemi_to_ind['lh'], hemi_to_ind['rh'] = 0, 1 + hemi_to_ind = {'lh': 0, 'rh': 1} for i, label in enumerate(labels): hemi_ind = hemi_to_ind[label['hemi']] src_sel = np.intersect1d(src[hemi_ind]['vertno'], @@ -153,7 +158,7 @@ def generate_stc(src, labels, stc_data, tmin, tstep, value_fun=None): data = np.tile(stc_data[i], (len(src_sel), 1)) vertno[hemi_ind].append(src_sel) - stc_data_extended[hemi_ind].append(data) + stc_data_extended[hemi_ind].append(np.atleast_2d(data)) # format the vertno list for idx in (0, 1): @@ -163,21 +168,11 @@ def generate_stc(src, labels, stc_data, tmin, tstep, value_fun=None): vertno[idx] = vertno[idx][0] vertno = map(np.array, vertno) - # the data is in the same order as the vertices in vertno - n_vert_tot = len(vertno[0]) + len(vertno[1]) - stc_data = np.zeros((n_vert_tot, stc_data.shape[1])) - for idx in (0, 1): - if len(stc_data_extended[idx]) == 0: - continue - if len(stc_data_extended[idx]) == 1: - data = stc_data_extended[idx][0] - else: - data = np.concatenate(stc_data_extended[idx]) - - if idx == 0: - stc_data[:len(vertno[0]), :] = data - else: - stc_data[len(vertno[0]):, :] = data + # the data is in the order left, right + lh_data = stc_data_extended[0] + rh_data = stc_data_extended[1] + lh_data.extend(rh_data) + stc_data = np.concatenate(lh_data) stc = _make_stc(stc_data, tmin, tstep, vertno) return stc diff --git a/mne/simulation/tests/test_source.py b/mne/simulation/tests/test_source.py index bae2dd4acb9..f61dfa4d94b 100644 --- a/mne/simulation/tests/test_source.py +++ b/mne/simulation/tests/test_source.py @@ -2,14 +2,14 @@ import copy import numpy as np -from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_array_almost_equal, assert_array_equal from nose.tools import assert_true from ...datasets import sample from ... import read_label from ... import read_forward_solution -from ..source import generate_stc +from ..source import generate_stc, generate_sparse_stc examples_folder = op.join(op.dirname(__file__), '..', '..', '..' '/examples') data_path = sample.data_path(examples_folder) @@ -46,3 +46,25 @@ def test_generate_stc(): # the first label has value 0, the second value 2 assert_array_almost_equal(stc.data[0], np.zeros(n_times)) assert_array_almost_equal(stc.data[-1], 4 * np.ones(n_times)) + + +def test_generate_sparse_stc(): + """ Test generation of sparse source estimate """ + + n_times = 10 + tmin = 0 + tstep = 1e-3 + + stc_data = np.ones((len(labels), n_times)) + stc_1 = generate_sparse_stc(fwd['src'], labels, stc_data, tmin, tstep, 0) + + assert_true(np.all(stc_1.data == 1.0)) + assert_true(stc_1.data.shape[0] == len(labels)) + assert_true(stc_1.data.shape[1] == n_times) + + # make sure we get the same result when using the same seed + stc_2 = generate_sparse_stc(fwd['src'], labels, stc_data, tmin, tstep, 0) + + assert_array_equal(stc_1.lh_vertno, stc_2.lh_vertno) + assert_array_equal(stc_1.rh_vertno, stc_2.rh_vertno) +