Skip to content

Conversation

@drammock
Copy link

don't merge yet. will add some comments on monday. Thought it would be more helpful/instructive to open as PR-into-your-PR rather than just pushing to your branch.

Comment on lines +1788 to +1792
obj_type = all_types.pop()
is_epo = GetEpochsMixin in obj_type.__mro__
is_tfr = BaseTFR in obj_type.__mro__
is_arr = np.ndarray in obj_type.__mro__
return is_epo, is_tfr, is_arr
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here, in the _validate_cluster_df helper function, it turned out to be more helpful to return some boolean variables rather than return the object type itself. It would have been possible to leave this as-is, but then these checks against __mro__ would have needed to happen a few times in a few different places, so I thought it was simpler to just do it once here.

side note: FYI, obj.__mro__ stands for "method resolution order" --- a list of parent classes and the order in which they're checked when looking up an object's attributes or methods. For example: EpochsTFR defines a few methods (average, drop, ...) but doesn't define e.g. crop or get_data. So when you do my_epochs_tfr_obj.get_data(), python checks the __mro__ of the object, and goes through that list in order, looking at each parent class until it finds one that defines a get_data method, then uses the first one it finds. In this case, BaseTFR is next on the MRO list, and it does define get_data, so that's the method definition that gets used.

It turns out that EpochsTFR and EpochsTFRArray don't inherit from BaseEpochs, so isinstance(BaseEpochs) wasn't catching the EpochsTFR cases. But they do inherit from GetEpochsMixin (and so do time-domain Epochs and EpochsArray) so checking against __mro__ is a good way to capture both Epochs and EpochsTFR.

Comment on lines -1883 to -1899
def _extract_data_array(series):
outer_func = np.concatenate if is_epo else np.array
axes = (-3, -1) if is_tfr else (-2, -1)

def func_arr(series):
return np.concatenate(series.values)

def _extract_data_mne(series): # 2D data
return np.array(
series.map(lambda inst: inst.get_data().swapaxes(-2, -1)).to_list()
def func_mne(series):
return outer_func(
series.map(lambda inst: inst.get_data().swapaxes(*axes)).to_list()
)

def _extract_data_tfr(series):
return series.map(lambda inst: inst.get_data().swapaxes(-3, -1)).to_list()

if _dtype is np.ndarray:
func = _extract_data_array
elif _dtype is BaseTFR:
func = _extract_data_tfr
else:
func = _extract_data_mne
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[REFACTORED] I realized I was going to need to write a fourth extraction function (in addition to _extract_data_array, _extract_data_tfr, and _extract_data_mne) to handle Epochs and EpochsTFRs. Looking at the mne and tfr ones we'd already written, I noticed the TFR one was wrong (needed to be wrapped in an np.array), which made it clear how similar those two were --- only different in the swapaxes arguments --- and the new func for Epochs was only going to differ in the outer wrapper function (np.concatenate instead of np.array).

So now we can have two variables: outer_func and axes and generate the extraction function we need based on whether is_epo and is_tfr.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my comment above, I am unsure if we properly handle the case of EpochsTFR (swapping epochs) and I think a quick dimension check might do the job.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the test confirms that we are in fact doing the right thing:

n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5

since each dimension has a different size, I think that if we'd done something wrong then the test would fail.

cmap_evokeds: None | str | tuple = None,
cmap_topo: None | str | tuple = None,
ci: float | bool | callable() | None = None,
ci: float | bool | callable | None = None,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated parameter typing fix

Comment on lines +27 to +28
from ..epochs import BaseEpochs, EvokedArray
from ..evoked import Evoked
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some parts of the MNE API can be imported multiple places. For user scripts, it's OK to do from mne import Evoked, but within the MNE package, we should always do from mne.evoked import Evoked (or here, from ..evoked import Evoked, which is the equivalent relative import) because it ensures that we avoid any circular import problems. the way to check for this is pytest mne/tests/test_import_nesting.py

inst2 = Inst(
data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
rng = np.random.default_rng(seed=8675309)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should always explicitly set a random seed when testing. Otherwise there's a (slim) chance the test might fail occasionally just from randomness. Here, I create an np.random.Generator object called rng and then later call rng.normal() to actually generate the random values.

Comment on lines +929 to +933
# construct the instance
info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg")
kw = dict(times=np.arange(n_times), freqs=np.arange(n_freq)) if is_tfr else dict()
cond_a = Inst(data=data, info=info, **kw)
cond_b = cond_a.copy()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now we make two identical MNE objects. kw will be empty if we're not working with TFRs on this test run

Comment on lines +934 to +941
# introduce a significant difference in a specific region, time, and frequency
ch_start, ch_end = 0, 2 # 2 channels
t_start, t_end = 2, 4 # 2 times
f_start, f_end = 2, 4 # 2 freqs
if is_tfr:
cond_b._data[..., ch_start:ch_end, f_start:f_end, t_start:t_end] += 2
else:
data = [inst1, inst2]
cond_b._data[..., ch_start:ch_end, t_start:t_end] += 2
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is basically unchanged from what you'd written

Comment on lines +947 to +953
for _n in range(n_epo):
if not _n:
insts.append(cond)
continue
_cond = cond.copy()
_cond.data += rng.normal(scale=0.1, size=_cond.data.shape)
insts.append(_cond)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here, if not _n means "if n is zero (the first iteration of the loop), just add cond to the list. Otherwise, make a copy of cond, add noise to it, then add it to the list"

Comment on lines +973 to +979
axes = (2, 1, 0) if is_tfr else (1, 0)
Xa = list()
Xb = list()
for inst, cond in zip(insts, conds):
container = Xa if cond == "a" else Xb
container.append(inst.get_data().transpose(*axes))
X = [np.stack(Xa), np.stack(Xb)]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here (the non-epochs case) we create separate lists for condition a and condition b, append to each list the inst.get_data().transpose(...) arrays, then use np.stack to concatenate each sub-list before combining into a single 2-element list X.

note: np.stack(Xa) automatically creates a new dimension at the beginning, so it's equivalent to np.concatenate([x[np.newaxis] for x in Xa], axis=0)

Comment on lines +985 to +987
assert len(result_new_api.clusters) == len(clusters)
for clu1, clu2 in zip(result_new_api.clusters, clusters):
assert_array_equal(clu1, clu2)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we can't use == because "the truth value of an array is ambiguous", so we check the lengths of the outer lists are the same, then assert that each array within the list matches.

@drammock
Copy link
Author

@CarinaFo have a look through the files changed tab and see if you understand the changes and LMK if you have any questions. If you're satisfied, feel free to merge this now and it will become part of your PR into mne-python upstream.

# extract the data from the dataframe
def _extract_data_array(series):
outer_func = np.concatenate if is_epo else np.array
axes = (-3, -1) if is_tfr else (-2, -1)
Copy link
Owner

@CarinaFo CarinaFo Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering about the EpochsTFR case, currently, we only swap the third and first dimension for tfr evoked (is_tfr check), but the same should be done for EpochsTFR, right?

# check wether series is 3D (TFR)
dimensions = series.size == 3

axes = (-3,-1) if is_tfr or dimensions else (-2, -1)

```

Copy link
Author

@drammock drammock Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm. I'm inclined to say "I think it's correct as-is because the tests pass". But then I remember that I messed around a lot with which axes to swap in the old API part of the test, to get it to match the new API results' shape. So, either both this and the test are correct, or both this and the test are wrong. Let's try to reason through it:

  • The reason we want to swap axes is to get the channels on the last dimension. We want that because it makes specifying the adjacency easier: we can pass an adjacency matrix for channels/vertices, and for other dims just use "immediate neighbor" adjacency. But note that we technically don't have to do this: in this tutorial the adjacency is computed for the existing/unmodified shape of the EpochsTFR.data array https://mne.tools/dev/auto_tutorials/stats-sensor-space/40_cluster_1samp_time_freq.html#define-adjacency-for-statistics.
  • if we are going to swap axes, we should make sure that we also provide parameters like max_step for the non-channel dimensions. This is a bit tricky because we need to potentially do so for both time and frequency dimensions, but the function doesn't necessarily know which dimensions will be present.
  • So assuming that we do swap axes, what matters is getting "channel" at the end. Does it matter whether we end up with "freq, time, chan" or "time, freq, chan"? I think it doesn't matter, as long as we're careful about our bookkeeping for which order we have. This doesn't matter yet because we haven't actually handled the max_step parameters for the freq and time dimensions (but that's something we will need to handle soon).
    • Epochs is (epo, chan, time), Evoked is (chan, time). In both cases swapping axis -1 (time) with axis -2 (chan) seems like the right solution (because we definitely want to keep "epo" as the first axis).
    • EpochsTFR is (epo, chan, freq, time). So swapping axis -1 (time) with axis -3 (chan) will give us (epo, time, freq, chan).
    • AverageTFR is (chan, freq, time). So swapping axis -1 (time) with axis -3 (chan) will give us (time, freq, chan).

So, I think what we're doing here is OK as long as we keep track of dimension order for TFRs and act accordingly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...wait, maybe I misunderstood your question. Perhaps you think is_tfr is only for AverageTFR? in fact is_tfr should be true for both EpochsTFR and AverageTFR, so the same dims (-3, -1) will get swapped for any kind of TFR object.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that is_tfr only handles AverageTFR, which might create issues. After checking the code again, I agree that it should be fine, EpochsTFR means is_tfr == True as well as is_epo == True.

@CarinaFo CarinaFo merged commit 81ce0d0 into CarinaFo:new_cluster_stats_api_GSOC24 Aug 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants