diff --git a/meegkit/cca.py b/meegkit/cca.py index 3f03d8c9..ebba7d32 100644 --- a/meegkit/cca.py +++ b/meegkit/cca.py @@ -94,6 +94,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False, n_trials). shifts : array, shape=(n_shifts,) Array of shifts to apply to `y` relative to `x` (can be negative). + sfreq : float + Sampling frequency. If not 1, lags are assumed to be given in seconds. surrogate : bool If True, estimate SD of correlation over non-matching pairs. plot : bool @@ -133,8 +135,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False, # Calculate leave-one-out CCAs print('Calculate CCAs...') - AA = list() - BB = list() + AA = [] + BB = [] for t in tqdm(np.arange(n_trials)): # covariance of all trials except t CC = np.sum(C[..., np.arange(n_trials) != t], axis=-1, keepdims=True) @@ -142,7 +144,7 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False, CC = np.squeeze(CC, 3) # corresponding CCA - [A, B, R] = nt_cca(None, None, None, CC, xx[0].shape[1]) + A, B, _ = nt_cca(None, None, None, CC, xx[0].shape[1]) AA.append(A) BB.append(B) del A, B @@ -227,17 +229,17 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1): independently from each page. m : int Number of channels of X. - thresh: float + thresh : float Discard principal components below this value. sfreq : float Sampling frequency. If not 1, lags are assumed to be given in seconds. Returns ------- - A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y)) + A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y)[, n_lags]) Transform matrix mapping `X` to canonical space, where `n_comps` is equal to `min(n_chans_X, n_chans_Y)`. - B : array, shape=(n_chans_Y, n_comps) + B : array, shape=(n_chans_Y, n_comps[, n_lags]) Transform matrix mapping `Y` to canonical space, where `n_comps` is equal to `min(n_chans_X, n_chans_Y)`. R : array, shape=(n_comps, n_lags) @@ -246,16 +248,16 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1): Notes ----- Usage 1: CCA of X, Y - >> [A, B, R] = nt_cca(X, Y) # noqa + >> A, B, R = nt_cca(X, Y) # noqa Usage 2: CCA of X, Y for each value of lags. - >> [A, B, R] = nt_cca(X, Y, lags) # noqa + >> A, B, R = nt_cca(X, Y, lags) # noqa A positive lag indicates that Y is delayed relative to X. Usage 3: CCA from covariance matrix >> C = [X, Y].T * [X, Y] # noqa - >> [A, B, R] = nt_cca([], [], [], C, X.shape[1]) # noqa + >> A, B, R = nt_cca(None, None, None, C=C, m=X.shape[1]) # noqa Use the third form to handle multiple files or large data (covariance C can be calculated chunk-by-chunk). @@ -381,9 +383,10 @@ def whiten_nt(C, thresh=1e-12, keep=False): # break symmetry when x and y perfectly correlated (otherwise cols of x*A # and y*B are not orthogonal) d = d ** (1 - thresh) + d_norm = d / np.max(d) dd = np.zeros_like(d) - dd[d > thresh] = (1. / d[d > thresh]) + dd[d_norm > thresh] = (1. / d[d_norm > thresh]) D = np.diag(np.sqrt(dd)) W = np.dot(V, D) diff --git a/tests/data/ccadata_meg_2trials.npz b/tests/data/ccadata_meg_2trials.npz new file mode 100644 index 00000000..0b340264 Binary files /dev/null and b/tests/data/ccadata_meg_2trials.npz differ diff --git a/tests/test_cca.py b/tests/test_cca.py index cbd4cd57..0c469f40 100644 --- a/tests/test_cca.py +++ b/tests/test_cca.py @@ -78,6 +78,21 @@ def test_cca2(): # plt.show() +def test_cca_scaling(): + """Test CCA with MEG data.""" + data = np.load('./tests/data/ccadata_meg_2trials.npz') + raw = data['arr_0'] + env = data['arr_1'] + + # Test with scaling (unit: fT) + A0, B0, R0 = nt_cca(raw * 1e15, env) + + # Test without scaling (unit: T) + A1, B1, R1 = nt_cca(raw, env) + + np.testing.assert_almost_equal(R0, R1) + + def test_canoncorr(): """Compare with Matlab's canoncorr.""" x = np.array([[16, 2, 3, 13], @@ -130,6 +145,9 @@ def test_cca_lags(): lags = np.arange(-10, 11, 1) A1, B1, R1 = nt_cca(x, y, lags) + assert A1.ndim == B1.ndim == 3 + assert A1.shape[-1] == B1.shape[-1] == lags.size + # import matplotlib.pyplot as plt # f, ax1 = plt.subplots(1, 1) # ax1.plot(lags, R1.T)