Skip to content
21 changes: 19 additions & 2 deletions mne/stats/tests/test_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numpy.testing import assert_array_equal, assert_allclose
import numpy as np
from scipy import stats, sparse
import pytest

from mne.stats import permutation_cluster_1samp_test
from mne.stats.permutations import (
Expand Down Expand Up @@ -55,9 +56,25 @@ def test_permutation_t_test():
assert_allclose(t_obs_clust, t_obs)
assert_allclose(p_values_clust, p_values[keep], atol=1e-2)


@pytest.mark.parametrize(
"tail_name,tail_code",
[
("two-sided", 0),
pytest.param(
"less", -1, marks=pytest.mark.xfail(reason="Bug in permutation function")
),
pytest.param(
"greater", 1, marks=pytest.mark.xfail(reason="Bug in permutation function")
),
],
)
def test_permutation_t_test_tail(tail_name, tail_code):
"""Test that tails work properly."""
X = np.random.randn(18, 1)
t_obs, p_values, H0 = permutation_t_test(X, n_permutations="all")
t_obs_scipy, p_values_scipy = stats.ttest_1samp(X[:, 0], 0)

t_obs, p_values, _ = permutation_t_test(X, n_permutations="all", tail=tail_code)
t_obs_scipy, p_values_scipy = stats.ttest_1samp(X[:, 0], 0, alternative=tail_name)
assert_allclose(t_obs[0], t_obs_scipy, 8)
assert_allclose(p_values[0], p_values_scipy, rtol=1e-2)

Expand Down