Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Current

Changelog
~~~~~~~~~
- Generation of simulated evoked responses by `Alex Gramfort`_, `Daniel Strohmeier`_, and `Martin Luessi`_

- Fit AR models to raw data for temporal whitening by `Alex Gramfort`_.

Expand Down
39 changes: 17 additions & 22 deletions mne/simulation/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def select_source_in_label(src, label, random_state=None):
return lh_vertno, rh_vertno


def generate_sparse_stc(src, labels, stc_data, tmin, tstep, random_state=0):
def generate_sparse_stc(src, labels, stc_data, tmin, tstep, random_state=None):
"""Generate sparse sources time courses from waveforms and labels

This function randomly selects a single vertex in each label and assigns
Expand Down Expand Up @@ -83,14 +83,20 @@ def generate_sparse_stc(src, labels, stc_data, tmin, tstep, random_state=0):
vertno[0] += lh_vertno
vertno[1] += rh_vertno
if len(lh_vertno) != 0:
lh_data.append(label_data)
lh_data.append(np.atleast_2d(label_data))
elif len(rh_vertno) != 0:
rh_data.append(label_data)
rh_data.append(np.atleast_2d(label_data))
else:
raise ValueError('No vertno found.')

vertno = map(np.array, vertno)
data = np.r_[lh_data, rh_data]

# the data is in the order left, right
lh_data.extend(rh_data)
data = np.concatenate(lh_data)

stc = _make_stc(data, tmin, tstep, vertno)

return stc


Expand Down Expand Up @@ -137,8 +143,7 @@ def generate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):

vertno = [[], []]
stc_data_extended = [[], []]
hemi_to_ind = {}
hemi_to_ind['lh'], hemi_to_ind['rh'] = 0, 1
hemi_to_ind = {'lh': 0, 'rh': 1}
for i, label in enumerate(labels):
hemi_ind = hemi_to_ind[label['hemi']]
src_sel = np.intersect1d(src[hemi_ind]['vertno'],
Expand All @@ -153,7 +158,7 @@ def generate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):
data = np.tile(stc_data[i], (len(src_sel), 1))

vertno[hemi_ind].append(src_sel)
stc_data_extended[hemi_ind].append(data)
stc_data_extended[hemi_ind].append(np.atleast_2d(data))

# format the vertno list
for idx in (0, 1):
Expand All @@ -163,21 +168,11 @@ def generate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):
vertno[idx] = vertno[idx][0]
vertno = map(np.array, vertno)

# the data is in the same order as the vertices in vertno
n_vert_tot = len(vertno[0]) + len(vertno[1])
stc_data = np.zeros((n_vert_tot, stc_data.shape[1]))
for idx in (0, 1):
if len(stc_data_extended[idx]) == 0:
continue
if len(stc_data_extended[idx]) == 1:
data = stc_data_extended[idx][0]
else:
data = np.concatenate(stc_data_extended[idx])

if idx == 0:
stc_data[:len(vertno[0]), :] = data
else:
stc_data[len(vertno[0]):, :] = data
# the data is in the order left, right
lh_data = stc_data_extended[0]
rh_data = stc_data_extended[1]
lh_data.extend(rh_data)
stc_data = np.concatenate(lh_data)

stc = _make_stc(stc_data, tmin, tstep, vertno)
return stc
26 changes: 24 additions & 2 deletions mne/simulation/tests/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import copy

import numpy as np
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_almost_equal, assert_array_equal
from nose.tools import assert_true

from ...datasets import sample
from ... import read_label
from ... import read_forward_solution

from ..source import generate_stc
from ..source import generate_stc, generate_sparse_stc

examples_folder = op.join(op.dirname(__file__), '..', '..', '..' '/examples')
data_path = sample.data_path(examples_folder)
Expand Down Expand Up @@ -46,3 +46,25 @@ def test_generate_stc():
# the first label has value 0, the second value 2
assert_array_almost_equal(stc.data[0], np.zeros(n_times))
assert_array_almost_equal(stc.data[-1], 4 * np.ones(n_times))


def test_generate_sparse_stc():
""" Test generation of sparse source estimate """

n_times = 10
tmin = 0
tstep = 1e-3

stc_data = np.ones((len(labels), n_times))
stc_1 = generate_sparse_stc(fwd['src'], labels, stc_data, tmin, tstep, 0)

assert_true(np.all(stc_1.data == 1.0))
assert_true(stc_1.data.shape[0] == len(labels))
assert_true(stc_1.data.shape[1] == n_times)

# make sure we get the same result when using the same seed
stc_2 = generate_sparse_stc(fwd['src'], labels, stc_data, tmin, tstep, 0)

assert_array_equal(stc_1.lh_vertno, stc_2.lh_vertno)
assert_array_equal(stc_1.rh_vertno, stc_2.rh_vertno)