Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions meegkit/cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
n_trials).
shifts : array, shape=(n_shifts,)
Array of shifts to apply to `y` relative to `x` (can be negative).
sfreq : float
Sampling frequency. If not 1, lags are assumed to be given in seconds.
surrogate : bool
If True, estimate SD of correlation over non-matching pairs.
plot : bool
Expand Down Expand Up @@ -133,16 +135,16 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,

# Calculate leave-one-out CCAs
print('Calculate CCAs...')
AA = list()
BB = list()
AA = []
BB = []
for t in tqdm(np.arange(n_trials)):
# covariance of all trials except t
CC = np.sum(C[..., np.arange(n_trials) != t], axis=-1, keepdims=True)
if CC.ndim == 4:
CC = np.squeeze(CC, 3)

# corresponding CCA
[A, B, R] = nt_cca(None, None, None, CC, xx[0].shape[1])
A, B, _ = nt_cca(None, None, None, CC, xx[0].shape[1])
AA.append(A)
BB.append(B)
del A, B
Expand Down Expand Up @@ -227,17 +229,17 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
independently from each page.
m : int
Number of channels of X.
thresh: float
thresh : float
Discard principal components below this value.
sfreq : float
Sampling frequency. If not 1, lags are assumed to be given in seconds.

Returns
-------
A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y))
A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y)[, n_lags])
Transform matrix mapping `X` to canonical space, where `n_comps` is
equal to `min(n_chans_X, n_chans_Y)`.
B : array, shape=(n_chans_Y, n_comps)
B : array, shape=(n_chans_Y, n_comps[, n_lags])
Transform matrix mapping `Y` to canonical space, where `n_comps` is
equal to `min(n_chans_X, n_chans_Y)`.
R : array, shape=(n_comps, n_lags)
Expand All @@ -246,16 +248,16 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
Notes
-----
Usage 1: CCA of X, Y
>> [A, B, R] = nt_cca(X, Y) # noqa
>> A, B, R = nt_cca(X, Y) # noqa

Usage 2: CCA of X, Y for each value of lags.
>> [A, B, R] = nt_cca(X, Y, lags) # noqa
>> A, B, R = nt_cca(X, Y, lags) # noqa

A positive lag indicates that Y is delayed relative to X.

Usage 3: CCA from covariance matrix
>> C = [X, Y].T * [X, Y] # noqa
>> [A, B, R] = nt_cca([], [], [], C, X.shape[1]) # noqa
>> A, B, R = nt_cca(None, None, None, C=C, m=X.shape[1]) # noqa

Use the third form to handle multiple files or large data (covariance C can
be calculated chunk-by-chunk).
Expand Down Expand Up @@ -381,9 +383,10 @@ def whiten_nt(C, thresh=1e-12, keep=False):
# break symmetry when x and y perfectly correlated (otherwise cols of x*A
# and y*B are not orthogonal)
d = d ** (1 - thresh)
d_norm = d / np.max(d)

dd = np.zeros_like(d)
dd[d > thresh] = (1. / d[d > thresh])
dd[d_norm > thresh] = (1. / d[d_norm > thresh])

D = np.diag(np.sqrt(dd))
W = np.dot(V, D)
Expand Down
Binary file added tests/data/ccadata_meg_2trials.npz
Binary file not shown.
18 changes: 18 additions & 0 deletions tests/test_cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ def test_cca2():
# plt.show()


def test_cca_scaling():
"""Test CCA with MEG data."""
data = np.load('./tests/data/ccadata_meg_2trials.npz')
raw = data['arr_0']
env = data['arr_1']

# Test with scaling (unit: fT)
A0, B0, R0 = nt_cca(raw * 1e15, env)

# Test without scaling (unit: T)
A1, B1, R1 = nt_cca(raw, env)

np.testing.assert_almost_equal(R0, R1)


def test_canoncorr():
"""Compare with Matlab's canoncorr."""
x = np.array([[16, 2, 3, 13],
Expand Down Expand Up @@ -130,6 +145,9 @@ def test_cca_lags():
lags = np.arange(-10, 11, 1)
A1, B1, R1 = nt_cca(x, y, lags)

assert A1.ndim == B1.ndim == 3
assert A1.shape[-1] == B1.shape[-1] == lags.size

# import matplotlib.pyplot as plt
# f, ax1 = plt.subplots(1, 1)
# ax1.plot(lags, R1.T)
Expand Down