diff --git a/mne/stats/tests/test_permutations.py b/mne/stats/tests/test_permutations.py index 2a0b86381b7..245ac140182 100644 --- a/mne/stats/tests/test_permutations.py +++ b/mne/stats/tests/test_permutations.py @@ -5,6 +5,7 @@ from numpy.testing import assert_array_equal, assert_allclose import numpy as np from scipy import stats, sparse +import pytest from mne.stats import permutation_cluster_1samp_test from mne.stats.permutations import ( @@ -55,9 +56,25 @@ def test_permutation_t_test(): assert_allclose(t_obs_clust, t_obs) assert_allclose(p_values_clust, p_values[keep], atol=1e-2) + +@pytest.mark.parametrize( + "tail_name,tail_code", + [ + ("two-sided", 0), + pytest.param( + "less", -1, marks=pytest.mark.xfail(reason="Bug in permutation function") + ), + pytest.param( + "greater", 1, marks=pytest.mark.xfail(reason="Bug in permutation function") + ), + ], +) +def test_permutation_t_test_tail(tail_name, tail_code): + """Test that tails work properly.""" X = np.random.randn(18, 1) - t_obs, p_values, H0 = permutation_t_test(X, n_permutations="all") - t_obs_scipy, p_values_scipy = stats.ttest_1samp(X[:, 0], 0) + + t_obs, p_values, _ = permutation_t_test(X, n_permutations="all", tail=tail_code) + t_obs_scipy, p_values_scipy = stats.ttest_1samp(X[:, 0], 0, alternative=tail_name) assert_allclose(t_obs[0], t_obs_scipy, 8) assert_allclose(p_values[0], p_values_scipy, rtol=1e-2)