diff --git a/README.md b/README.md index bd1ef49..81b6cf1 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,8 @@ Works for Gwilliams2022 dataset and Brennan2018 dataset. ## TODOs -- [ ] Full reproducibility support. Will be useful for HP tuning. -- [ ] Match accuracy to numbers reported in the paper. -- [ ] Work with huge memory consumption issue in Gwilliams multiprocessing +- [ ] Achieve accuracies in the paper (Brennan -> Gwilliams). +- [ ] Reorganize Gwilliams dataclass to look more similar to Brennan (using Base class). # Usage diff --git a/assets/questioned_lines1.png b/assets/questioned_lines1.png new file mode 100644 index 0000000..459385e Binary files /dev/null and b/assets/questioned_lines1.png differ diff --git a/assets/reports.md b/assets/reports.md index 43d7b01..3ee7154 100644 --- a/assets/reports.md +++ b/assets/reports.md @@ -1,6 +1,19 @@ -# Brennan2018 +# Questions -## Experiments +- About "We ensure that there is no identical sentences across splits" below, can it be rephrased as "We ensure that there is no two segments across splits that are coming from the same sentence"? + +- About "we restrict the test segments to those that contain a word at a fixed location (here 500ms into the sample)" below, can it be rephrased as "we dropped the test segments that don't contain any word until 500ms"? In that case why does it happend when we segment with word onsets? + + + +- After wav2vec2.0 embedding, audios become something like 50Hz (because wav2vec2.0 requires them to be originally 16kHz and it downsamples them a lot), so we need to upsample them to match brains' 120Hz. Do you actually do that? If so, which method do you use? We've tried linear interpolation by torchaudio and zero-padding but neither worked well. + +- Learnable temperature of CLIP loss was not mentioned and absent in equation (2). It was mentioned in the original CLIP paper but do you actually use it? + + +# Experiment results + +## Brennan2018 - Basically couldn't achieve performance in the paper. @@ -29,8 +42,5 @@ - Channel-wise - Robust Scaler was applied channel-wise -## Questions - -- After wav2vec2.0 embedding, audios become somewhere like 50Hz (because wav2vec2.0 requires them to be originally 16kHz and it downsamples them a lot), so we need to upsample them to match brains' 120Hz. Do you actually do that? If so, which method do you use? We've tried linear interpolation by torchaudio and zero-padding but neither worked well. -- Learnable temperature of CLIP loss was not mentioned and absent in equation (2). It was mentioned in the original CLIP paper but do you actually use it? \ No newline at end of file +## Gwilliams2022 diff --git a/configs/config.yaml b/configs/config.yaml index bf4962d..d2d79f9 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -18,7 +18,7 @@ split_mode: sentence # sentence, shallow, deep num_workers: 6 batch_size: 64 updates: 1200 -lr: 3e-4 +lr: 3.0e-4 epochs: 300 reduction: mean @@ -43,7 +43,7 @@ preprocs: shift_brain: True # whether to shift M/EEG into the future relative to audio shift_len: 150 # if True, by how many ms last4layers: True # if True, the brain_encoder's emsize will be 1024, not 512 - channel_wise: True # whether to scale each channel's EEG dataset individually (only for Brennan2018) + channel_wise: True # Whether to scale each channel of MEG/EEG datasets individually clamp: True clamp_lim: 20 y_upsample: interpolate # interpolate / pad diff --git a/speech_decoding/dataclass/base.py b/speech_decoding/dataclass/base.py new file mode 100644 index 0000000..3c0d5d0 --- /dev/null +++ b/speech_decoding/dataclass/base.py @@ -0,0 +1,26 @@ +import torch +import torchaudio.functional as F + +from termcolor import cprint + + +class SpeechDecodingDatasetBase(torch.utils.data.Dataset): + def __init__(self, args): + super().__init__() + + def _resample_audio( + self, waveform: torch.Tensor, sample_rate: int, resample_rate: int = 16000 + ) -> torch.Tensor: + """Resamples audio to 16kHz. (16kHz is required by wav2vec2.0)""" + + waveform = F.resample( + waveform, + sample_rate, + resample_rate, + lowpass_filter_width=self.lowpass_filter_width, + ) + + len_audio_s = waveform.shape[1] / resample_rate + cprint(f">>> Audio length: {len_audio_s} s.", color="cyan") + + return waveform diff --git a/speech_decoding/dataclass/brennan2018.py b/speech_decoding/dataclass/brennan2018.py index d10ffae..1320d0b 100644 --- a/speech_decoding/dataclass/brennan2018.py +++ b/speech_decoding/dataclass/brennan2018.py @@ -20,6 +20,7 @@ from transformers import Wav2Vec2Model +from speech_decoding.dataclass.base import SpeechDecodingDatasetBase from speech_decoding.utils.wav2vec_util import get_last4layers_avg from speech_decoding.utils.preproc_utils import ( shift_brain_signal, @@ -48,18 +49,16 @@ """ -class Brennan2018Dataset(Dataset): +class Brennan2018Dataset(SpeechDecodingDatasetBase): def __init__(self, args): - super().__init__() + super().__init__(args) # Both - chance = args.chance - force_recompute = args.rebuild_dataset self.split_mode = args.split_mode self.root_dir = args.root_dir - self.channel_wise = args.preprocs.channel_wise self.seq_len_sec = args.preprocs.seq_len_sec # EEG + self.channel_wise = args.preprocs.channel_wise self.filter_brain = args.preprocs.filter self.brain_filter_low = args.preprocs.brain_filter_low self.brain_filter_high = args.preprocs.brain_filter_high @@ -72,8 +71,12 @@ def __init__(self, args): self.lowpass_filter_width = args.preprocs.lowpass_filter_width self.wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model) # Data Paths - self.matfile_paths = natsorted(glob.glob(f"{self.root_dir}/data/Brennan2018/raw/*.mat")) - self.audio_paths = natsorted(glob.glob(f"{self.root_dir}/data/Brennan2018/audio/*.wav")) + self.matfile_paths = natsorted( + glob.glob(f"{self.root_dir}/data/Brennan2018/raw/*.mat") + ) + self.audio_paths = natsorted( + glob.glob(f"{self.root_dir}/data/Brennan2018/audio/*.wav") + ) self.onsets_path = f"{self.root_dir}/data/Brennan2018/AliceChapterOne-EEG.csv" # Save Paths X_path = f"{self.root_dir}/data/Brennan2018/X.pt" @@ -81,7 +84,9 @@ def __init__(self, args): sentence_idxs_path = f"{self.root_dir}/data/Brennan2018/sentence_idxs.npy" # Rebuild dataset - if force_recompute or not (os.path.exists(X_path) and os.path.exists(Y_path)): + if args.rebuild_dataset or not ( + os.path.exists(X_path) and os.path.exists(Y_path) + ): cprint(f"> Preprocessing EEG and audio.", color="cyan") self.X, self.Y, self.sentence_idxs = self.rebuild_dataset( @@ -109,7 +114,10 @@ def __init__(self, args): self.num_subjects = self.X.shape[1] cprint(f">>> Number of subjects: {self.num_subjects}", color="cyan") - cprint(f">> Upsampling audio embedding with: {args.preprocs.y_upsample}", color="cyan") + cprint( + f">> Upsampling audio embedding with: {args.preprocs.y_upsample}", + color="cyan", + ) if args.preprocs.y_upsample == "interpolate": self.Y = interpolate_y_time(self.Y, self.brain_num_samples) elif args.preprocs.y_upsample == "pad": @@ -117,7 +125,7 @@ def __init__(self, args): else: raise ValueError(f"Unknown upsampling strategy: {args.preprocs.y_upsample}") - if chance: + if args.chance: self.Y = self.Y[torch.randperm(len(self.Y))] def __len__(self): @@ -144,11 +152,11 @@ def rebuild_dataset( # ---------------------- waveform = [torchaudio.load(path) for path in audio_paths] - audio_rate = self.get_audio_rate(waveform) + audio_rate = self._get_audio_rate(waveform) waveform = torch.cat([w[0] for w in waveform], dim=1) # ( 1, time@44.1kHz ) - audio = self.resample_audio(waveform, audio_rate, AUDIO_RESAMPLE_RATE) + audio = self._resample_audio(waveform, audio_rate, AUDIO_RESAMPLE_RATE) cprint( f">>> Resampled audio {audio_rate}Hz -> {AUDIO_RESAMPLE_RATE}Hz | shape: {waveform.shape} -> {audio.shape}", @@ -159,12 +167,14 @@ def rebuild_dataset( # EEG Loading # ---------------------- matfile_paths = [ - path for path in matfile_paths if not path.split(".")[0][-3:] in EXCLUDED_SUBJECTS + path + for path in matfile_paths + if not path.split(".")[0][-3:] in EXCLUDED_SUBJECTS ] mat_raws = [scipy.io.loadmat(path)["raw"][0, 0] for path in matfile_paths] eeg_raws = [mat_raw["trial"][0, 0] for mat_raw in mat_raws] - eeg_rate = self.get_eeg_rate(mat_raws) + eeg_rate = self._get_eeg_rate(mat_raws) # ---------------------- # Preprocessing @@ -183,7 +193,9 @@ def rebuild_dataset( if self.split_mode == "sentence": cprint(">> Dropping last segments of each sentence.", color="cyan") - X, audio, sentence_idxs = self.drop_last_segments(X, audio, onsets_path) + + X, audio, sentence_idxs = self._drop_last_segments(X, audio, onsets_path) + cprint(f">>> X (EEG): {X.shape} | Audio: {audio.shape}", color="cyan") else: sentence_idxs = None @@ -201,13 +213,20 @@ def rebuild_dataset( return X, Y, sentence_idxs - def drop_last_segments( - self, X: torch.Tensor, audio: torch.Tensor, onsets_path: str + @staticmethod + def _drop_last_segments( + X: torch.Tensor, audio: torch.Tensor, onsets_path: str ) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: """Drops last segments of each sentence. FIXME: currently drops last 5 words, but this number should be variable to cover 3 secs. + Args: + X: ( segment, subject, channel, time@120Hz//segment ) + audio: ( segment, 1, time@16kHz//segment ) + Returns: + X: ( _segment, subject, channel, time@120Hz//segment ) + audio: ( _segment, 1, time@16kHz//segment ) """ - num_drops = 5 + NUM_DROPS = 5 sentence_idxs = pd.read_csv(onsets_path).Sentence.to_numpy() assert np.all( @@ -215,8 +234,18 @@ def drop_last_segments( ), "sentence_idxs is not a non-decreasing step sequence." sentence_ends = np.where(np.diff(sentence_idxs) > 0)[0] + # NOTE: end idx for the last sentence + sentence_ends = np.append(sentence_ends, sentence_idxs.shape[0] - 1) + + # NOTE: There are sentences that are shorter than NUM_DROPS, but the overlap is not + # a problem as we convert it to a boolean array. + # assert np.all( + # np.diff(np.append(-1, sentence_ends)) > NUM_DROPS + # ), f"Some sentence(s) are not longer than NUM_DROPS={NUM_DROPS}." - drop_idxs = np.concatenate([np.arange(i - num_drops, i) + 1 for i in sentence_ends]) + drop_idxs = np.concatenate( + [np.arange(i - NUM_DROPS, i) + 1 for i in sentence_ends] + ) # NOTE: to boolean array drop_bools = np.ones(len(X), dtype=bool) @@ -224,23 +253,6 @@ def drop_last_segments( return X[drop_bools], audio[drop_bools], sentence_idxs[drop_bools] - def resample_audio( - self, waveform: torch.Tensor, sample_rate: int, resample_rate: int = 16000 - ) -> torch.Tensor: - """Resamples audio to 16kHz. (16kHz is required by wav2vec2.0)""" - - waveform = F.resample( - waveform, - sample_rate, - resample_rate, - lowpass_filter_width=self.lowpass_filter_width, - ) - - len_audio_s = waveform.shape[1] / resample_rate - cprint(f">>> Audio length: {len_audio_s} s.", color="cyan") - - return waveform - def embed_audio(self, audio: torch.Tensor) -> torch.Tensor: """ Args: @@ -316,7 +328,6 @@ def resample_brain( eeg_raw = torch.from_numpy(eeg_raw.astype(np.float32)) - # NOTE: in the paper they say they downsampled to exactly 120Hz with Torchaudio, so I'll stick to that eeg_resampled = F.resample( waveform=eeg_raw, orig_freq=fsample, @@ -328,16 +339,20 @@ def resample_brain( return torch.stack(X) - def get_audio_rate(self, waveform: List[Tuple[torch.Tensor, int]]) -> int: + @staticmethod + def _get_audio_rate(waveform: List[Tuple[torch.Tensor, int]]) -> int: sample_rates = np.array([w[1] for w in waveform]) # is all 44.1kHz assert np.all(sample_rates == sample_rates[0]) return sample_rates[0] - def get_eeg_rate(self, mat_raws: List) -> int: + @staticmethod + def _get_eeg_rate(mat_raws: List) -> int: sample_rates = np.array([mat_raw["fsample"][0, 0] for mat_raw in mat_raws]) # is all 500Hz - assert np.all(sample_rates == sample_rates[0]), "Wrong EEG sampling rate detected." + assert np.all( + sample_rates == sample_rates[0] + ), "Wrong EEG sampling rate detected." return sample_rates[0] diff --git a/speech_decoding/dataclass/gwilliams2022.py b/speech_decoding/dataclass/gwilliams2022.py index 3a8ac58..dd7ba28 100644 --- a/speech_decoding/dataclass/gwilliams2022.py +++ b/speech_decoding/dataclass/gwilliams2022.py @@ -14,7 +14,7 @@ import mne, mne_bids from tqdm import tqdm import ast -from typing import Union, Tuple, List +from typing import Union, Tuple, List, Dict from psutil import virtual_memory as vm from termcolor import cprint from pprint import pprint @@ -23,12 +23,14 @@ from itertools import repeat from omegaconf import open_dict +from speech_decoding.dataclass.base import SpeechDecodingDatasetBase from speech_decoding.utils.wav2vec_util import get_last4layers_avg from speech_decoding.utils.preproc_utils import ( check_preprocs, continuous, scale_and_clamp, ) +from speech_decoding.constants import BRAIN_RESAMPLE_RATE, AUDIO_RESAMPLE_RATE mne.set_log_level(verbose="WARNING") @@ -38,12 +40,14 @@ global_sentence_idxs = manager.dict() -class Gwilliams2022DatasetBase(Dataset): +class Gwilliams2022DatasetBase(SpeechDecodingDatasetBase): def __init__(self, args): super().__init__() + # Both + self.root_dir = args.root_dir # + "/data/Gwilliams2022/" + # MEG self.wav2vec_model = args.wav2vec_model - self.root_dir = args.root_dir + "/data/Gwilliams2022/" self.brain_orig_rate = 1000 self.brain_resample_rate = args.preprocs["brain_resample_rate"] self.brain_filter_low = args.preprocs["brain_filter_low"] @@ -74,7 +78,7 @@ def __init__(self, args): # Preprocess X (MEG) # --------------------------- if args.rebuild_dataset or not args.preprocs["x_done"]: - _out = self.brain_preproc_handler() + _out = self.brain_preproc() self.X, self.meg_onsets, self.speech_onsets, self.sentence_idxs = _out np.save(self.x_path, self.X) @@ -89,8 +93,12 @@ def __init__(self, args): else: self.X = np.load(self.x_path, allow_pickle=True).item() self.meg_onsets = np.load(self.meg_onsets_path, allow_pickle=True).item() - self.speech_onsets = np.load(self.speech_onsets_path, allow_pickle=True).item() - self.sentence_idxs = np.load(self.sentence_idxs_path, allow_pickle=True).item() + self.speech_onsets = np.load( + self.speech_onsets_path, allow_pickle=True + ).item() + self.sentence_idxs = np.load( + self.sentence_idxs_path, allow_pickle=True + ).item() # ---------------------------------------- # Preprocess Y (embedded speech) @@ -117,19 +125,23 @@ def __init__(self, args): assert len(self.X) == len(self.meg_onsets) - self.valid_subjects = np.array(list(set([k.split("_")[0] for k in self.X.keys()]))) + self.valid_subjects = np.array( + list(set([k.split("_")[0] for k in self.X.keys()])) + ) self.num_subjects = len(self.valid_subjects) cprint(f"X keys: {self.X.keys()}", color="cyan") cprint(f"Y: {self.Y.shape}", color="cyan") - cprint(f"num_subjects: {self.num_subjects} (each has 2 or 1 sessions)", color="cyan") + cprint( + f"num_subjects: {self.num_subjects} (each has 2 or 1 sessions)", color="cyan" + ) print(self.valid_subjects) def __len__(self): return len(self.Y) def __getitem__(self, i): # NOTE: i is id of a speech segment - i_in_task, task = self.segment_to_task(i) + i_in_task, task = self._segment_to_task(i) key_no_task = np.random.choice(list(self.X.keys())) X = self.X[key_no_task][task] # ( 208, ~100000 ) @@ -142,7 +154,13 @@ def __getitem__(self, i): # NOTE: i is id of a speech segment return X, self.Y[i], subject_idx - def segment_to_task(self, i) -> Tuple[int, str]: + def rebuild_dataset(self): + audio_tasks = self.load_resample_audio() + + _X = self.load_resample_brain() + self.X, self.meg_onsets, self.speech_onsets, self.sentence_idxs = _X + + def _segment_to_task(self, i) -> Tuple[int, str]: nseg_task_accum = np.cumsum(self.num_segments_foreach_task) task = np.searchsorted(nseg_task_accum, i + 1) @@ -162,7 +180,9 @@ def segment_speech(self, data: torch.Tensor, key: str) -> torch.Tensor: def sentence_to_word_idxs(self, _sentence_idxs, key): return [ i - for si, i in zip(self.sentence_idxs[key], np.arange(len(self.sentence_idxs[key]))) + for si, i in zip( + self.sentence_idxs[key], np.arange(len(self.sentence_idxs[key])) + ) if si in _sentence_idxs ] @@ -189,22 +209,24 @@ def drop_task_missing_sessions(self) -> None: self.meg_onsets.pop(key) @staticmethod - def brain_preproc(dat): - subject_idx, d, speech_onsets, meg_onsets, sentence_idxs, session_idx, task_idx = dat - - num_channels = d["num_channels"] - brain_orig_rate = d["brain_orig_rate"] - brain_filter_low = d["brain_filter_low"] - brain_filter_high = d["brain_filter_high"] - brain_resample_rate = d["brain_resample_rate"] - root_dir = d["root_dir"] - preproc_dir = d["preproc_dir"] - - description = f"subject{str(subject_idx+1).zfill(2)}_sess{session_idx}_task{task_idx}" + def _load_resample_brain( + args: dict, + subject_idx: int, + speech_onsets: np.ndarray, + meg_onsets: np.ndarray, + sentence_idxs: np.ndarray, + session_idx: int, + task_idx: int, + num_channels: int, + root_dir: str, + preproc_dir: str, + ) -> None: + description = ( + f"subject{str(subject_idx+1).zfill(2)}_sess{session_idx}_task{task_idx}" + ) bids_path = mne_bids.BIDSPath( - subject=str(subject_idx + 1).zfill(2), - # '01', '02', ... + subject=str(subject_idx + 1).zfill(2), # '01', '02', ... session=str(session_idx), task=str(task_idx), datatype="meg", @@ -215,7 +237,7 @@ def brain_preproc(dat): raw = mne_bids.read_raw_bids(bids_path) except: cprint("No .con data was found", color="yellow") - return 1 + return None cprint(description, color="cyan") @@ -244,43 +266,48 @@ def brain_preproc(dat): speech_onsets.update({task_str: _speech_onsets}) sentence_idxs.update({task_str: _sentence_idxs}) - meg_raw = np.stack([df[key] for key in df.keys() if "MEG" in key]) # ( 224, ~396000 ) + meg_raw = np.stack( + [df[key] for key in df.keys() if "MEG" in key] + ) # ( 224, ~396000 ) # NOTE: (kind of) confirmed that last 16 channels are REF meg_raw = meg_raw[:num_channels] # ( 208, ~396000 ) - meg_filtered = mne.filter.filter_data( - meg_raw, - sfreq=brain_orig_rate, - l_freq=brain_filter_low, - h_freq=brain_filter_high, - ) + if filter: + meg_raw = mne.filter.filter_data( + meg_raw, + sfreq=args.brain_orig_rate, + l_freq=args.brain_filter_low, + h_freq=args.brain_filter_high, + ) # To 120 Hz - meg_resampled = mne.filter.resample( - meg_filtered, - down=brain_orig_rate / brain_resample_rate, + meg_resampled = F.resample( + waveform=meg_raw, + orig_freq=args.brain_orig_rate, + new_freq=args.brain_resample_rate, + lowpass_filter_width=args.lowpass_filter_width, ) # ( 208, 37853 ) np.save( f"{preproc_dir}_parts/{description}", meg_resampled, ) - return 0 + # return 0 - def brain_preproc_handler(self, num_subjects=27, num_channels=208): + def load_resample_brain(self, args, num_subjects: int, num_channels: int = 208): tmp_dir = self.preproc_dir + "_parts/" if not os.path.exists(tmp_dir): os.mkdir(tmp_dir) - consts = dict( - num_channels=num_channels, - brain_orig_rate=self.brain_orig_rate, - brain_filter_low=self.brain_filter_low, - brain_filter_high=self.brain_filter_high, - brain_resample_rate=self.brain_resample_rate, - root_dir=self.root_dir, - preproc_dir=self.preproc_dir, - ) + # consts = dict( + # num_channels=num_channels, + # brain_orig_rate=self.brain_orig_rate, + # brain_filter_low=self.brain_filter_low, + # brain_filter_high=self.brain_filter_high, + # brain_resample_rate=self.brain_resample_rate, + # root_dir=self.root_dir, + # preproc_dir=self.preproc_dir, + # ) subj_list = [] for subj in range(num_subjects): @@ -288,8 +315,8 @@ def brain_preproc_handler(self, num_subjects=27, num_channels=208): for task_idx in range(4): subj_list.append( ( + args.preprocs, subj, - consts, global_speech_onsets, global_meg_onsets, global_sentence_idxs, @@ -301,7 +328,7 @@ def brain_preproc_handler(self, num_subjects=27, num_channels=208): with Pool(processes=20) as p: res = list( tqdm( - p.imap(self.brain_preproc, subj_list), + p.imap(self._load_resample_brain, subj_list), total=len(subj_list), bar_format="{desc:<5.5}{percentage:3.0f}%|{bar:10}{r_bar}", ) @@ -316,16 +343,58 @@ def brain_preproc_handler(self, num_subjects=27, num_channels=208): # NOTE: assemble files into one and clean up fnames = natsorted(os.listdir(tmp_dir)) - # cprint(fnames, color='yellow') + + # NOTE: data MUST be task0, ... taskN, task0, ..., taskN (N=4) X = dict() - for fname in fnames: # NOTE: data MUST be task0, ... taskN, task0, ..., taskN (N=4) + for fname in fnames: key = os.path.splitext(fname)[0] X[key] = np.load(tmp_dir + fname, allow_pickle=True) cprint("removing temp files for EEG data", color="white") shutil.rmtree(tmp_dir) - return X, dict(global_meg_onsets), dict(global_speech_onsets), dict(global_sentence_idxs) + return ( + X, + dict(global_meg_onsets), + dict(global_speech_onsets), + dict(global_sentence_idxs), + ) + + def load_resample_audio(self) -> Dict[str, torch.Tensor]: + audio_tasks = {} + assert os.path.exists( + f"{self.root_dir}stimuli/audio" + ), "Path data/Gwilliams2022/stimuli/audio doesn't exist." + + for task_idx in self.speech_onsets.keys(): # 4 tasks for each subject + task_idx_ID = int(task_idx[-1]) + + audio_paths = natsorted( + glob.glob( + f"{self.root_dir}stimuli/audio/{self.task_prefixes[task_idx_ID]}*.wav" + ) + ) + + audio_list = [] + for path in audio_paths: + waveform, sample_rate = torchaudio.load(path) + + # Upsample to 16000Hz + # waveform = F.resample( + # waveform, + # orig_freq=sample_rate, + # new_freq=self.audio_resample_rate, + # lowpass_filter_width=self.lowpass_filter_width, + # ) + audio = self._resample_audio(waveform, sample_rate, AUDIO_RESAMPLE_RATE) + + audio_list.append(audio) + + audio = torch.cat(audio_list, dim=-1) + + audio_tasks.update({task_idx: audio}) + + return audio_tasks @torch.no_grad() def audio_preproc(self): @@ -341,7 +410,9 @@ def audio_preproc(self): task_idx_ID = int(task_idx[-1]) audio_paths = natsorted( - glob.glob(f"{self.root_dir}stimuli/audio/{self.task_prefixes[task_idx_ID]}*.wav") + glob.glob( + f"{self.root_dir}stimuli/audio/{self.task_prefixes[task_idx_ID]}*.wav" + ) ) audio_raw = [] @@ -647,7 +718,9 @@ def __init__(self, args): super(Gwilliams2022Collator, self).__init__() self.brain_resample_rate = args.preprocs["brain_resample_rate"] - self.baseline_len_samp = int(self.brain_resample_rate * args.preprocs["baseline_len_sec"]) + self.baseline_len_samp = int( + self.brain_resample_rate * args.preprocs["baseline_len_sec"] + ) self.clamp = args.preprocs["clamp"] self.clamp_lim = args.preprocs["clamp_lim"] diff --git a/speech_decoding/utils/get_dataloaders.py b/speech_decoding/utils/get_dataloaders.py index 69a89b1..427520c 100644 --- a/speech_decoding/utils/get_dataloaders.py +++ b/speech_decoding/utils/get_dataloaders.py @@ -1,4 +1,5 @@ from torch.utils.data import DataLoader, RandomSampler, BatchSampler +from termcolor import cprint def get_dataloaders(train_set, test_set, args, g, seed_worker, test_bsz=None): diff --git a/speech_decoding/utils/preproc_utils.py b/speech_decoding/utils/preproc_utils.py index aa3a718..063f1e1 100644 --- a/speech_decoding/utils/preproc_utils.py +++ b/speech_decoding/utils/preproc_utils.py @@ -104,7 +104,8 @@ def scale_and_clamp(X: torch.Tensor, clamp_lim: Union[int, float], channel_wise= if orig_dim == 4: num_segments = X.shape[0] - X = X.clone().flatten(end_dim=1) # ( segment * subject, channel, time//segment ) + X = X.clone().flatten(end_dim=1) + # ( segment * subject, channel, time//segment ) res = [] @@ -146,7 +147,7 @@ def interpolate_y_time(Y: torch.Tensor, num_samples: int) -> torch.Tensor: return F.interpolate(Y, size=num_samples, mode="linear") -# NOTE currently only works for gwilliams2022.yml +# NOTE: Works only for Gwilliams2022 dataset def check_preprocs(args, data_dir): is_processed = False preproc_dirs = glob.glob(data_dir + "*/") @@ -170,7 +171,11 @@ def check_preprocs(args, data_dir): # is_processed = np.all([v == args.preprocs[k] for k, v in settings.items() if not k in excluded_keys]) excluded_keys = ["preceding_chunk_for_baseline", "mode"] is_processed = np.all( - [v == args.preprocs[k] for k, v in settings.items() if k not in excluded_keys] + [ + v == args.preprocs[k] + for k, v in settings.items() + if k not in excluded_keys + ] ) if is_processed: cprint( diff --git a/tests/modules_for_test/models.py b/tests/modules_for_test/models.py index 8440ec1..2b37a6d 100644 --- a/tests/modules_for_test/models.py +++ b/tests/modules_for_test/models.py @@ -17,10 +17,14 @@ def __init__(self, args): self.D1 = args.D1 self.K = args.K self.spatial_attention = SpatialAttention(args) - self.conv = nn.Conv1d(in_channels=self.D1, out_channels=self.D1, kernel_size=1, stride=1) + self.conv = nn.Conv1d( + in_channels=self.D1, out_channels=self.D1, kernel_size=1, stride=1 + ) # NOTE: The below implementations are equivalent to learning a matrix: - self.subject_matrix = nn.Parameter(torch.rand(self.num_subjects, self.D1, self.D1)) + self.subject_matrix = nn.Parameter( + torch.rand(self.num_subjects, self.D1, self.D1) + ) # self.subject_layer = [ # nn.Conv1d(in_channels=self.D1, out_channels=self.D1, kernel_size=1, stride=1, device=device) # for _ in range(self.num_subjects) @@ -54,11 +58,19 @@ def __init__(self, num_subjects, D1, K, dataset_name, d_drop): # self.spatial_attention = SpatialAttentionX(D1, K, dataset_name) self.spatial_attention = SpatialAttention(D1, K, dataset_name, d_drop) - self.conv = nn.Conv1d(in_channels=self.D1, out_channels=self.D1, kernel_size=1, stride=1) - self.subject_matrix = nn.Parameter(torch.rand(self.num_subjects, self.D1, self.D1)) + self.conv = nn.Conv1d( + in_channels=self.D1, out_channels=self.D1, kernel_size=1, stride=1 + ) + self.subject_matrix = nn.Parameter( + torch.rand(self.num_subjects, self.D1, self.D1) + ) self.subject_layer = [ nn.Conv1d( - in_channels=self.D1, out_channels=self.D1, kernel_size=1, stride=1, device=device + in_channels=self.D1, + out_channels=self.D1, + kernel_size=1, + stride=1, + device=device, ) for _ in range(self.num_subjects) ] @@ -85,22 +97,20 @@ class SpatialAttentionTest1(nn.Module): I reimplemented to SpatialAttentionVer2 (which is not the final version). """ - def __init__(self, D1, K, dataset_name, z_re=None, z_im=None): + def __init__(self, args, z_re, z_im): super(SpatialAttentionTest1, self).__init__() - self.D1 = D1 - self.K = K + self.D1 = args.D1 + self.K = args.K - if z_re is None or z_im is None: - self.z_re = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) - self.z_im = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) - nn.init.kaiming_uniform_(self.z_re, a=np.sqrt(5)) - nn.init.kaiming_uniform_(self.z_im, a=np.sqrt(5)) - else: - self.z_re = z_re - self.z_im = z_im + # self.z_re = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) + # self.z_im = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) + # nn.init.kaiming_uniform_(self.z_re, a=np.sqrt(5)) + # nn.init.kaiming_uniform_(self.z_im, a=np.sqrt(5)) + self.z_re = z_re + self.z_im = z_im - self.ch_locations_2d = ch_locations_2d(dataset_name).to(device) + self.ch_locations_2d = ch_locations_2d(args).to(device) def fourier_space(self, j, x: torch.Tensor, y: torch.Tensor): # x: ( 60, ) y: ( 60, ) a_j = 0 @@ -130,16 +140,18 @@ def forward(self, X): # ( B, C, T ) (=( 128, 60, 256 )) class SpatialAttentionTest2(nn.Module): """Faster version of SpatialAttentionVer1""" - def __init__(self, args): + def __init__(self, args, z_re, z_im): super(SpatialAttentionTest2, self).__init__() self.D1 = args.D1 self.K = args.K - self.z_re = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) - self.z_im = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) - nn.init.kaiming_uniform_(self.z_re, a=np.sqrt(5)) - nn.init.kaiming_uniform_(self.z_im, a=np.sqrt(5)) + # self.z_re = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) + # self.z_im = nn.Parameter(torch.Tensor(self.D1, self.K, self.K)) + # nn.init.kaiming_uniform_(self.z_re, a=np.sqrt(5)) + # nn.init.kaiming_uniform_(self.z_im, a=np.sqrt(5)) + self.z_re = z_re + self.z_im = z_im self.K_arange = torch.arange(self.K).to(device) @@ -153,12 +165,16 @@ def fourier_space(self, x: torch.Tensor, y: torch.Tensor): # x: ( 60, ) y: ( 60 # ( 32, 1, 60 ) + ( 1, 32, 60 ) -> ( 32, 32, 60 ) rad = rad1.unsqueeze(1) + rad2.unsqueeze(0) - real = torch.einsum("dkl,klc->dc", self.z_re, torch.cos(2 * torch.pi * rad)) # ( 270, 60 ) + real = torch.einsum( + "dkl,klc->dc", self.z_re, torch.cos(2 * torch.pi * rad) + ) # ( 270, 60 ) imag = torch.einsum("dkl,klc->dc", self.z_im, torch.sin(2 * torch.pi * rad)) return real + imag # ( 270, 60 ) - def fourier_space_orig(self, x: torch.Tensor, y: torch.Tensor): # x: ( 60, ) y: ( 60, ) + def fourier_space_orig( + self, x: torch.Tensor, y: torch.Tensor + ): # x: ( 60, ) y: ( 60, ) """Slower version of fourier_space""" a = torch.zeros(self.D1, x.shape[0], device=device) # ( 270, 60 ) @@ -166,10 +182,14 @@ def fourier_space_orig(self, x: torch.Tensor, y: torch.Tensor): # x: ( 60, ) y: for l in range(self.K): # This einsum is same as torch.stack([_d * c for _d in d]) a += torch.einsum( - "d,c->dc", self.z_re[:, k, l], torch.cos(2 * torch.pi * (k * x + l * y)) + "d,c->dc", + self.z_re[:, k, l], + torch.cos(2 * torch.pi * (k * x + l * y)), ) # ( 270, 60 ) a += torch.einsum( - "d,c->dc", self.z_im[:, k, l], torch.sin(2 * torch.pi * (k * x + l * y)) + "d,c->dc", + self.z_im[:, k, l], + torch.sin(2 * torch.pi * (k * x + l * y)), ) return a # ( 270, 60 ) @@ -200,7 +220,8 @@ def forward(self, SA_wts): drop_center_id = np.random.randint(self.num_channels) distances = np.sqrt( - (self.x - self.x[drop_center_id]) ** 2 + (self.y - self.y[drop_center_id]) ** 2 + (self.x - self.x[drop_center_id]) ** 2 + + (self.y - self.y[drop_center_id]) ** 2 ) is_dropped = torch.where(distances < self.d_drop, 0.0, 1.0).to(device) # cprint( @@ -210,60 +231,6 @@ def forward(self, SA_wts): return SA_wts * is_dropped -class SpatialAttentionTest(nn.Module): - """ - Same as SpatialAttentionVer2, but a little more concise. - Also SpatialAttention is added. - """ - - def __init__(self, D1, K, dataset_name, d_drop): - super(SpatialAttentionTest, self).__init__() - - # vectorize of k's and l's - a = [] - for k in range(K): - for l in range(K): - a.append((k, l)) - a = torch.tensor(a) - k, l = a[:, 0], a[:, 1] - - # vectorize x- and y-positions of the sensors - loc = ch_locations_2d(dataset_name) - x, y = loc[:, 0], loc[:, 1] - - # make a complex-valued parameter, reshape k,l into one dimension - self.z = nn.Parameter(torch.rand(size=(D1, K**2), dtype=torch.cfloat)).to(device) - - # NOTE: pre-compute the values of cos and sin (they depend on k, l, x and y which repeat) - phi = ( - 2 * torch.pi * (torch.einsum("k,x->kx", k, x) + torch.einsum("l,y->ly", l, y)) - ) # torch.Size([1024, 60])) - self.cos = torch.cos(phi).to(device) - self.sin = torch.cos(phi).to(device) - - self.spatial_dropout = SpatialDropout(x, y, d_drop) - - def forward(self, X): - # NOTE: do hadamard product and and sum over l and m (i.e. m, which is l X m) - re = torch.einsum("jm, me -> je", self.z.real, self.cos) # torch.Size([270, 60]) - im = torch.einsum("jm, me -> je", self.z.imag, self.sin) - a = ( - re + im - ) # essentially (unnormalized) weights with which to mix input channels into ouput channels - # ( D1, num_channels ) - - # NOTE: to get the softmax spatial attention weights over input electrodes, - # we don't compute exp, etc (as in the eq. 5), we take softmax instead: - SA_wts = F.softmax(a, dim=-1) # each row sums to 1 - # ( D1, num_channels ) - - SA_wts = self.spatial_dropout(SA_wts) - - return torch.einsum( - "oi,bit->bot", SA_wts, X - ) # each output is a diff weighted sum over each input channel - - class ConvBlockTest(nn.Module): def __init__(self, k, D1, D2): super(ConvBlockTest, self).__init__() diff --git a/tests/test_brennan_dataclass.py b/tests/test_brennan_dataclass.py new file mode 100644 index 0000000..a1ca9de --- /dev/null +++ b/tests/test_brennan_dataclass.py @@ -0,0 +1,22 @@ +import torch +from hydra import initialize, compose + +from speech_decoding.dataclass.brennan2018 import Brennan2018Dataset + +with initialize(version_base=None, config_path="../configs/"): + args = compose(config_name="config.yaml") + + +def test_drop_last_segments(): + _X = torch.rand(2129, 33, 60, 360) + _audio = torch.rand(2129, 1, 48000) + onsets_path = "/home/sensho/speech_decoding/data/Brennan2018/AliceChapterOne-EEG.csv" + + X, audio, sentence_idxs = Brennan2018Dataset._drop_last_segments( + _X, _audio, onsets_path + ) + + print(X.shape) + + assert X.shape[0] < _X.shape[0] + assert X.shape[1:] == _X.shape[1:] diff --git a/tests/test_models.py b/tests/test_models.py index 5a0f008..806241c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,22 +5,35 @@ from speech_decoding.models import * from tests.modules_for_test.models import * +torch.manual_seed(0) + +device = "cuda:0" if torch.cuda.is_available() else "cpu" + with initialize(version_base=None, config_path="../configs/"): - args = compose(config_name="config.yaml") + args = compose(config_name="config") + + with open_dict(args): + args.root_dir = "/home/sensho/speech_decoding" + + +def test_spatial_attention() -> None: + # NOTE: 60 channels is for Brennan2018 + input = torch.rand(8, 60, 360).to(device) + + sa = SpatialAttention(args).eval().to(device) + z_re = sa.z.real.reshape(args.D1, args.K, args.K) + z_im = sa.z.imag.reshape(args.D1, args.K, args.K) -# def test_spatial_attention() -> None: -# with initialize(version_base=None, config_path="../configs"): -# args = compose(config_name="config") -# with open_dict(args): -# # FIXME: get_original_cwd() can't be called -# args.root_dir = "/home/sensho/speech_decoding" # get_original_cwd() + sa_test1 = SpatialAttentionTest1(args, z_re, z_im).eval().to(device) + sa_test2 = SpatialAttentionTest2(args, z_re, z_im).eval().to(device) -# input = torch.rand(8, 208, 256).to(device) -# output = SpatialAttention(args)(input) -# output_test = SpatialAttentionTest2(args)(input) + output = sa(input) + output_test1 = sa_test1(input) + output_test2 = sa_test2(input) -# assert output == output_test + assert torch.allclose(output, output_test1, rtol=1e-4, atol=1e-5) + assert torch.allclose(output, output_test2) def test_classifier(): @@ -35,10 +48,10 @@ def test_classifier(): assert torch.allclose(similarity_train, similarity_test) -def test_standard_normalization(): - input = torch.rand(64, 1024, 360) - output = BrainEncoder._standard_normalization(input) +# def test_standard_normalization(): +# input = torch.rand(64, 1024, 360) +# output = BrainEncoder._standard_normalization(input) - input0_norm = (input[0] - input[0].mean()) / input[0].std() +# input0_norm = (input[0] - input[0].mean()) / input[0].std() - assert torch.equal(output[0], input0_norm) +# assert torch.equal(output[0], input0_norm) diff --git a/train.py b/train.py index 2b41911..1056e79 100644 --- a/train.py +++ b/train.py @@ -130,11 +130,13 @@ def run(args: DictConfig) -> None: elif args.split_mode == "deep": train_set = torch.utils.data.Subset(dataset, range(train_size)) - test_set = torch.utils.data.Subset(dataset, range(train_size, train_size + test_size)) + test_set = torch.utils.data.Subset( + dataset, range(train_size, train_size + test_size) + ) elif args.split_mode == "sentence": # NOTE: sentence_idxs starts from 1 - num_sentences = dataset.sentence_idxs.max() + num_sentences = dataset.sentence_idxs.max() # 84 num_train_sentences = int(num_sentences * args.split_ratio) train_sentences, test_sentences = torch.utils.data.random_split( @@ -173,7 +175,9 @@ def run(args: DictConfig) -> None: raise ValueError("Unknown dataset") if args.use_wandb: - wandb.config = {k: v for k, v in dict(args).items() if k not in ["root_dir", "wandb"]} + wandb.config = { + k: v for k, v in dict(args).items() if k not in ["root_dir", "wandb"] + } wandb.init( project=args.wandb.project, entity=args.wandb.entity, @@ -200,8 +204,7 @@ def run(args: DictConfig) -> None: # Optimizer # -------------------- optimizer = torch.optim.Adam( - list(brain_encoder.parameters()) + list(loss_func.parameters()), - lr=float(args.lr), + list(brain_encoder.parameters()) + list(loss_func.parameters()), lr=args.lr ) # --------------------