fixes for test #5
| obj_type = all_types.pop() | ||
| is_epo = GetEpochsMixin in obj_type.__mro__ | ||
| is_tfr = BaseTFR in obj_type.__mro__ | ||
| is_arr = np.ndarray in obj_type.__mro__ | ||
| return is_epo, is_tfr, is_arr |
here, in the _validate_cluster_df helper function, it turned out to be more helpful to return some boolean variables rather than return the object type itself. It would have been possible to leave this as-is, but then these checks against __mro__ would have needed to happen a few times in a few different places, so I thought it was simpler to just do it once here.
side note: FYI, __mro__ (as in obj_type.__mro__) stands for "method resolution order": a tuple containing the class itself and its parent classes, in the order in which they're checked when looking up an object's attributes or methods. For example: EpochsTFR defines a few methods (average, drop, ...) but doesn't define e.g. crop or get_data. So when you do my_epochs_tfr_obj.get_data(), Python checks the __mro__ of the object's class and goes through it in order, looking at each parent class until it finds one that defines a get_data method, then uses the first one it finds. In this case, BaseTFR is next on the MRO list, and it does define get_data, so that's the method definition that gets used.
It turns out that EpochsTFR and EpochsTFRArray don't inherit from BaseEpochs, so isinstance(BaseEpochs) wasn't catching the EpochsTFR cases. But they do inherit from GetEpochsMixin (and so do time-domain Epochs and EpochsArray) so checking against __mro__ is a good way to capture both Epochs and EpochsTFR.
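To make the MRO check concrete, here is a minimal sketch using toy classes. The class names mirror the MNE ones, but these are hypothetical stand-ins with made-up bodies; only the inheritance relationships described above are assumed.

```python
# Toy stand-ins for the MNE classes (assumption: only the inheritance
# relationships matter here, not the real implementations).
class GetEpochsMixin:
    pass


class BaseEpochs(GetEpochsMixin):
    pass


class Epochs(BaseEpochs):
    pass


class BaseTFR:
    def get_data(self):
        return "BaseTFR.get_data"


class EpochsTFR(BaseTFR, GetEpochsMixin):
    def average(self):
        return "EpochsTFR.average"


def describe(obj_type):
    """Return (is_epo, is_tfr) flags by inspecting the class's MRO."""
    is_epo = GetEpochsMixin in obj_type.__mro__
    is_tfr = BaseTFR in obj_type.__mro__
    return is_epo, is_tfr


tfr = EpochsTFR()
# EpochsTFR does not define get_data, so lookup falls through to BaseTFR
method_source = tfr.get_data()
flags_tfr = describe(EpochsTFR)
flags_epochs = describe(Epochs)
# isinstance against BaseEpochs misses the TFR case
caught_by_isinstance = isinstance(tfr, BaseEpochs)
```

In this toy hierarchy the isinstance check misses EpochsTFR, while the mixin check flags both Epochs and EpochsTFR as epochs-like.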
| def _extract_data_array(series): | ||
| outer_func = np.concatenate if is_epo else np.array | ||
| axes = (-3, -1) if is_tfr else (-2, -1) | ||
|
|
||
| def func_arr(series): | ||
| return np.concatenate(series.values) | ||
|
|
||
| def _extract_data_mne(series): # 2D data | ||
| return np.array( | ||
| series.map(lambda inst: inst.get_data().swapaxes(-2, -1)).to_list() | ||
| def func_mne(series): | ||
| return outer_func( | ||
| series.map(lambda inst: inst.get_data().swapaxes(*axes)).to_list() | ||
| ) | ||
|
|
||
| def _extract_data_tfr(series): | ||
| return series.map(lambda inst: inst.get_data().swapaxes(-3, -1)).to_list() | ||
|
|
||
| if _dtype is np.ndarray: | ||
| func = _extract_data_array | ||
| elif _dtype is BaseTFR: | ||
| func = _extract_data_tfr | ||
| else: | ||
| func = _extract_data_mne |
[REFACTORED] I realized I was going to need to write a fourth extraction function (in addition to _extract_data_array, _extract_data_tfr, and _extract_data_mne) to handle Epochs and EpochsTFRs. Looking at the mne and tfr ones we'd already written, I noticed the TFR one was wrong (needed to be wrapped in an np.array), which made it clear how similar those two were --- only different in the swapaxes arguments --- and the new func for Epochs was only going to differ in the outer wrapper function (np.concatenate instead of np.array).
So now we can have two variables: outer_func and axes and generate the extraction function we need based on whether is_epo and is_tfr.
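A rough sketch of that idea, with plain NumPy arrays standing in for inst.get_data() and a list standing in for the pandas Series. The make_extractor helper is hypothetical and not part of the PR; it only illustrates how the two flags select the wrapper and the axes.

```python
import numpy as np


def make_extractor(is_epo, is_tfr):
    """Build an extraction function from the two boolean flags."""
    # epochs-like data already has a leading "epo" axis, so concatenate;
    # otherwise np.array adds the new leading axis
    outer_func = np.concatenate if is_epo else np.array
    # move channels to the last axis
    axes = (-3, -1) if is_tfr else (-2, -1)

    def extract(arrays):
        return outer_func([arr.swapaxes(*axes) for arr in arrays])

    return extract


n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5
evokeds = [np.zeros((n_chan, n_times)) for _ in range(2)]
epochs_tfrs = [np.zeros((n_epo, n_chan, n_freq, n_times)) for _ in range(2)]

evoked_shape = make_extractor(is_epo=False, is_tfr=False)(evokeds).shape
epochs_tfr_shape = make_extractor(is_epo=True, is_tfr=True)(epochs_tfrs).shape
```

With two inputs each, the Evoked-like case comes out as (2, n_times, n_chan) and the EpochsTFR-like case as (2 * n_epo, n_times, n_freq, n_chan).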
see my comment above, I am unsure if we properly handle the case of EpochsTFR (swapping epochs) and I think a quick dimension check might do the job.
I think the test confirms that we are in fact doing the right thing:
n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5
Since each dimension has a different size, I think that if we'd done something wrong then the test would fail.
| cmap_evokeds: None | str | tuple = None, | ||
| cmap_topo: None | str | tuple = None, | ||
| ci: float | bool | callable() | None = None, | ||
| ci: float | bool | callable | None = None, |
unrelated parameter typing fix
| from ..epochs import BaseEpochs, EvokedArray | ||
| from ..evoked import Evoked |
Some parts of the MNE API can be imported from multiple places. For user scripts, it's OK to do from mne import Evoked, but within the MNE package, we should always do from mne.evoked import Evoked (or here, from ..evoked import Evoked, which is the equivalent relative import) because it ensures that we avoid any circular import problems. The way to check for this is to run pytest mne/tests/test_import_nesting.py
| inst2 = Inst( | ||
| data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) | ||
| ) | ||
| rng = np.random.default_rng(seed=8675309) |
we should always explicitly set a random seed when testing. Otherwise there's a (slim) chance the test might fail occasionally just from randomness. Here, I create an np.random.Generator object called rng and then later call rng.normal() to actually generate the random values.
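As a quick illustration of why the seed matters, two generators created with the same seed produce identical draws. The array size below is arbitrary and chosen only for the example.

```python
import numpy as np

# two independent generators built from the same seed
rng_a = np.random.default_rng(seed=8675309)
rng_b = np.random.default_rng(seed=8675309)

noise_a = rng_a.normal(scale=0.1, size=(3, 5))
noise_b = rng_b.normal(scale=0.1, size=(3, 5))

# same seed, same sequence of calls, same values
same_seed_matches = np.array_equal(noise_a, noise_b)
```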
| # construct the instance | ||
| info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg") | ||
| kw = dict(times=np.arange(n_times), freqs=np.arange(n_freq)) if is_tfr else dict() | ||
| cond_a = Inst(data=data, info=info, **kw) | ||
| cond_b = cond_a.copy() |
now we make two identical MNE objects. kw will be empty if we're not working with TFRs on this test run
| # introduce a significant difference in a specific region, time, and frequency | ||
| ch_start, ch_end = 0, 2 # 2 channels | ||
| t_start, t_end = 2, 4 # 2 times | ||
| f_start, f_end = 2, 4 # 2 freqs | ||
| if is_tfr: | ||
| cond_b._data[..., ch_start:ch_end, f_start:f_end, t_start:t_end] += 2 | ||
| else: | ||
| data = [inst1, inst2] | ||
| cond_b._data[..., ch_start:ch_end, t_start:t_end] += 2 |
this is basically unchanged from what you'd written
| for _n in range(n_epo): | ||
| if not _n: | ||
| insts.append(cond) | ||
| continue | ||
| _cond = cond.copy() | ||
| _cond.data += rng.normal(scale=0.1, size=_cond.data.shape) | ||
| insts.append(_cond) |
here, if not _n means "if n is zero (the first iteration of the loop), just add cond to the list. Otherwise, make a copy of cond, add noise to it, then add it to the list"
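The same loop pattern, sketched with a plain NumPy array standing in for the MNE object (so the noise is added to the array directly rather than to a .data attribute; the shapes and n_epo value are made up for the example):

```python
import numpy as np

rng = np.random.default_rng(seed=8675309)
cond = np.ones((3, 5))  # stand-in for one condition's data
n_epo = 4

insts = []
for _n in range(n_epo):
    if not _n:
        # first iteration (_n == 0): keep the original, noise-free
        insts.append(cond)
        continue
    # later iterations: copy, then add noise to the copy only
    _cond = cond.copy()
    _cond += rng.normal(scale=0.1, size=_cond.shape)
    insts.append(_cond)
```

Because the noise is added to a copy, the first element of the list stays untouched.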
| axes = (2, 1, 0) if is_tfr else (1, 0) | ||
| Xa = list() | ||
| Xb = list() | ||
| for inst, cond in zip(insts, conds): | ||
| container = Xa if cond == "a" else Xb | ||
| container.append(inst.get_data().transpose(*axes)) | ||
| X = [np.stack(Xa), np.stack(Xb)] |
here (the non-epochs case) we create separate lists for condition a and condition b, append to each list the inst.get_data().transpose(...) arrays, then use np.stack to concatenate each sub-list before combining into a single 2-element list X.
note: np.stack(Xa) automatically creates a new dimension at the beginning, so it's equivalent to np.concatenate([x[np.newaxis] for x in Xa], axis=0)
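A small check of that equivalence, using made-up arrays in place of the transposed get_data() output:

```python
import numpy as np

# two arrays of identical shape, standing in for per-instance data
Xa = [np.arange(12).reshape(4, 3) + i for i in range(2)]

stacked = np.stack(Xa)
concatenated = np.concatenate([x[np.newaxis] for x in Xa], axis=0)

# both approaches add a new leading axis of length len(Xa)
stacked_shape = stacked.shape
```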
| assert len(result_new_api.clusters) == len(clusters) | ||
| for clu1, clu2 in zip(result_new_api.clusters, clusters): | ||
| assert_array_equal(clu1, clu2) |
here we can't use == because "the truth value of an array is ambiguous", so we check the lengths of the outer lists are the same, then assert that each array within the list matches.
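To see the problem in isolation, the sketch below compares two lists of arrays directly and then uses the length-plus-elementwise approach. The cluster arrays are invented for the example.

```python
import numpy as np
from numpy.testing import assert_array_equal

clusters_old = [np.array([True, False]), np.array([False, True, True])]
clusters_new = [c.copy() for c in clusters_old]

try:
    # list == list compares element by element and needs a single bool
    # per pair, but array == array returns an array
    clusters_old == clusters_new
    raised = False
except ValueError:
    raised = True

# the approach used in the test: compare lengths, then each array
same_length = len(clusters_new) == len(clusters_old)
for clu1, clu2 in zip(clusters_new, clusters_old):
    assert_array_equal(clu1, clu2)
```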
@CarinaFo have a look through the files changed tab and see if you understand the changes and LMK if you have any questions. If you're satisfied, feel free to merge this now and it will become part of your PR into mne-python upstream.
| # extract the data from the dataframe | ||
| def _extract_data_array(series): | ||
| outer_func = np.concatenate if is_epo else np.array | ||
| axes = (-3, -1) if is_tfr else (-2, -1) |
I am wondering about the EpochsTFR case, currently, we only swap the third and first dimension for tfr evoked (is_tfr check), but the same should be done for EpochsTFR, right?
```
# check whether series is 3D (TFR)
dimensions = series.size == 3
axes = (-3, -1) if is_tfr or dimensions else (-2, -1)
```
hmm. I'm inclined to say "I think it's correct as-is because the tests pass". But then I remember that I messed around a lot with which axes to swap in the old API part of the test, to get it to match the new API results' shape. So, either both this and the test are correct, or both this and the test are wrong. Let's try to reason through it:
- The reason we want to swap axes is to get the channels on the last dimension. We want that because it makes specifying the adjacency easier: we can pass an adjacency matrix for channels/vertices, and for other dims just use "immediate neighbor" adjacency. But note that we technically don't have to do this: in this tutorial the adjacency is computed for the existing/unmodified shape of the EpochsTFR.data array https://mne.tools/dev/auto_tutorials/stats-sensor-space/40_cluster_1samp_time_freq.html#define-adjacency-for-statistics.
- if we are going to swap axes, we should make sure that we also provide parameters like max_step for the non-channel dimensions. This is a bit tricky because we need to potentially do so for both time and frequency dimensions, but the function doesn't necessarily know which dimensions will be present.
- So assuming that we do swap axes, what matters is getting "channel" at the end. Does it matter whether we end up with "freq, time, chan" or "time, freq, chan"? I think it doesn't matter, as long as we're careful about our bookkeeping for which order we have. This doesn't matter yet because we haven't actually handled the max_step parameters for the freq and time dimensions (but that's something we will need to handle soon).
  - Epochs is (epo, chan, time), Evoked is (chan, time). In both cases swapping axis -1 (time) with axis -2 (chan) seems like the right solution (because we definitely want to keep "epo" as the first axis).
  - EpochsTFR is (epo, chan, freq, time). So swapping axis -1 (time) with axis -3 (chan) will give us (epo, time, freq, chan).
  - AverageTFR is (chan, freq, time). So swapping axis -1 (time) with axis -3 (chan) will give us (time, freq, chan).
So, I think what we're doing here is OK as long as we keep track of dimension order for TFRs and act accordingly.
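The dimension bookkeeping above can be checked with dummy arrays. This is only a sketch: the shapes follow the sizes used in the test (6, 3, 4, 5), and zero-filled arrays stand in for the real data.

```python
import numpy as np

n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5

# (original shape, axes to swap) for each object type discussed above
shapes = {
    "Evoked": ((n_chan, n_times), (-2, -1)),
    "Epochs": ((n_epo, n_chan, n_times), (-2, -1)),
    "AverageTFR": ((n_chan, n_freq, n_times), (-3, -1)),
    "EpochsTFR": ((n_epo, n_chan, n_freq, n_times), (-3, -1)),
}

swapped = {
    name: np.zeros(shape).swapaxes(*axes).shape
    for name, (shape, axes) in shapes.items()
}
```

In every case channels end up on the last axis, and for the epochs-like types the epoch axis stays first.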
...wait, maybe I misunderstood your question. Perhaps you think is_tfr is only for AverageTFR? in fact is_tfr should be true for both EpochsTFR and AverageTFR, so the same dims (-3, -1) will get swapped for any kind of TFR object.
I thought that is_tfr only handles AverageTFR, which might create issues. After checking the code again, I agree that it should be fine: for EpochsTFR, is_tfr == True as well as is_epo == True.
don't merge yet. will add some comments on monday. Thought it would be more helpful/instructive to open as PR-into-your-PR rather than just pushing to your branch.