5 changes: 2 additions & 3 deletions README.md
@@ -10,9 +10,8 @@ Works for Gwilliams2022 dataset and Brennan2018 dataset.

## TODOs

- [ ] Full reproducibility support. Will be useful for HP tuning.
- [ ] Match accuracy to numbers reported in the paper.
- [ ] Work on the huge memory consumption issue in Gwilliams multiprocessing.
- [ ] Achieve accuracies in the paper (Brennan -> Gwilliams).
- [ ] Reorganize Gwilliams dataclass to look more similar to Brennan (using Base class).

# Usage

Binary file added assets/questioned_lines1.png
22 changes: 16 additions & 6 deletions assets/reports.md
@@ -1,6 +1,19 @@
# Brennan2018
# Questions

## Experiments
- About "We ensure that there is no identical sentences across splits" below, can it be rephrased as "We ensure that there is no two segments across splits that are coming from the same sentence"?

- About "we restrict the test segments to those that contain a word at a fixed location (here 500ms into the sample)" below, can it be rephrased as "we dropped the test segments that don't contain any word until 500ms"? In that case why does it happend when we segment with word onsets?

<img src="questioned_lines1.png">

- After wav2vec2.0 embedding, the audio ends up at roughly 50Hz (wav2vec2.0 requires 16kHz input and downsamples it heavily), so we need to upsample the embeddings to match the brain signals' 120Hz. Do you actually do that? If so, which method do you use? We've tried linear interpolation with torchaudio and zero-padding, but neither worked well (see the interpolation sketch after this list).

- The learnable temperature of the CLIP loss is not mentioned and is absent from equation (2). It was used in the original CLIP paper; do you actually use it? (A sketch of what we mean follows below.)

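For reference, the linear-interpolation variant we tried (a minimal sketch, assuming wav2vec2.0 embeddings shaped `(batch, emb_dim, time)`; the shapes here are hypothetical):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: ~3 s of audio gives ~150 wav2vec2.0 frames (~50Hz),
# while 3 s of brain signal at 120Hz is 360 samples.
Y = torch.randn(8, 1024, 150)  # (batch, emb_dim, time@~50Hz)

# Linear interpolation along the time axis, up to the brain sampling rate
Y_up = F.interpolate(Y, size=360, mode="linear", align_corners=False)
print(Y_up.shape)  # torch.Size([8, 1024, 360])
```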

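And what we mean by the learnable temperature, following the original CLIP paper (a sketch, not necessarily what this paper does; `z_brain` and `z_audio` are assumed to be matched batches of embeddings):

```python
import torch
import torch.nn.functional as F


class ClipLossWithTemperature(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Log-parameterized temperature, initialized to log(1/0.07) as in CLIP
        self.log_temp = torch.nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

    def forward(self, z_brain: torch.Tensor, z_audio: torch.Tensor) -> torch.Tensor:
        z_brain = F.normalize(z_brain, dim=-1)
        z_audio = F.normalize(z_audio, dim=-1)
        logits = z_brain @ z_audio.T * self.log_temp.exp()
        targets = torch.arange(len(logits), device=logits.device)
        # Symmetric cross-entropy over brain->audio and audio->brain
        return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
```
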
# Experiment results

## Brennan2018

- Basically, we could not achieve the performance reported in the paper.

@@ -29,8 +42,5 @@
- Channel-wise
- Robust Scaler was applied channel-wise

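(By "Robust Scaler was applied channel-wise" we mean something like the following; a sketch in the spirit of sklearn's `RobustScaler`, applied per channel. The repo's actual scaling code is not shown in this diff.)

```python
import torch


def robust_scale_channelwise(X: torch.Tensor) -> torch.Tensor:
    """X: (channel, time). Center by the median, scale by the IQR, per channel."""
    med = X.median(dim=1, keepdim=True).values
    q1 = X.quantile(0.25, dim=1, keepdim=True)
    q3 = X.quantile(0.75, dim=1, keepdim=True)
    return (X - med) / (q3 - q1)
```
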
## Questions

- After wav2vec2.0 embedding, audios become somewhere like 50Hz (because wav2vec2.0 requires them to be originally 16kHz and it downsamples them a lot), so we need to upsample them to match brains' 120Hz. Do you actually do that? If so, which method do you use? We've tried linear interpolation by torchaudio and zero-padding but neither worked well.

- Learnable temperature of CLIP loss was not mentioned and absent in equation (2). It was mentioned in the original CLIP paper but do you actually use it?
## Gwilliams2022
4 changes: 2 additions & 2 deletions configs/config.yaml
@@ -18,7 +18,7 @@ split_mode: sentence # sentence, shallow, deep
num_workers: 6
batch_size: 64
updates: 1200
lr: 3e-4
lr: 3.0e-4
epochs: 300
reduction: mean

@@ -43,7 +43,7 @@ preprocs:
shift_brain: True # whether to shift M/EEG into the future relative to audio
shift_len: 150 # if True, by how many ms
last4layers: True # if True, the brain_encoder's emsize will be 1024, not 512
channel_wise: True # whether to scale each channel's EEG dataset individually (only for Brennan2018)
channel_wise: True # Whether to scale each channel of MEG/EEG datasets individually
clamp: True
clamp_lim: 20
y_upsample: interpolate # interpolate / pad
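The `lr: 3e-4` -> `lr: 3.0e-4` change likely matters because PyYAML's YAML 1.1 resolver only recognizes scientific notation as a float when the mantissa contains a dot; `3e-4` therefore loads as a string (assuming PyYAML ultimately parses this config):

```python
import yaml

print(yaml.safe_load("lr: 3e-4"))    # {'lr': '3e-4'}  -> a string!
print(yaml.safe_load("lr: 3.0e-4"))  # {'lr': 0.0003}  -> a float
```
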
26 changes: 26 additions & 0 deletions speech_decoding/dataclass/base.py
@@ -0,0 +1,26 @@
import torch
import torchaudio.functional as F

from termcolor import cprint


class SpeechDecodingDatasetBase(torch.utils.data.Dataset):
def __init__(self, args):
super().__init__()

def _resample_audio(
self, waveform: torch.Tensor, sample_rate: int, resample_rate: int = 16000
) -> torch.Tensor:
"""Resamples audio to 16kHz. (16kHz is required by wav2vec2.0)"""

waveform = F.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=self.lowpass_filter_width,
)

len_audio_s = waveform.shape[1] / resample_rate
cprint(f">>> Audio length: {len_audio_s} s.", color="cyan")

return waveform
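Note that `_resample_audio` reads `self.lowpass_filter_width`, which the base class never sets: subclasses are expected to assign it before calling the helper, as `Brennan2018Dataset` does below. A minimal hypothetical subclass, just to show the contract (`ToyDataset` is not part of the PR):

```python
class ToyDataset(SpeechDecodingDatasetBase):
    def __init__(self, args):
        super().__init__(args)
        # Required by _resample_audio; the base class does not define it
        self.lowpass_filter_width = args.preprocs.lowpass_filter_width

    def __len__(self):
        return 0

    def __getitem__(self, i):
        raise IndexError
```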
95 changes: 55 additions & 40 deletions speech_decoding/dataclass/brennan2018.py
@@ -20,6 +20,7 @@

from transformers import Wav2Vec2Model

from speech_decoding.dataclass.base import SpeechDecodingDatasetBase
from speech_decoding.utils.wav2vec_util import get_last4layers_avg
from speech_decoding.utils.preproc_utils import (
shift_brain_signal,
@@ -48,18 +49,16 @@
"""


class Brennan2018Dataset(Dataset):
class Brennan2018Dataset(SpeechDecodingDatasetBase):
def __init__(self, args):
super().__init__()
super().__init__(args)

# Both
chance = args.chance
force_recompute = args.rebuild_dataset
self.split_mode = args.split_mode
self.root_dir = args.root_dir
self.channel_wise = args.preprocs.channel_wise
self.seq_len_sec = args.preprocs.seq_len_sec
# EEG
self.channel_wise = args.preprocs.channel_wise
self.filter_brain = args.preprocs.filter
self.brain_filter_low = args.preprocs.brain_filter_low
self.brain_filter_high = args.preprocs.brain_filter_high
@@ -72,16 +71,22 @@ def __init__(self, args):
self.lowpass_filter_width = args.preprocs.lowpass_filter_width
self.wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model)
# Data Paths
self.matfile_paths = natsorted(glob.glob(f"{self.root_dir}/data/Brennan2018/raw/*.mat"))
self.audio_paths = natsorted(glob.glob(f"{self.root_dir}/data/Brennan2018/audio/*.wav"))
self.matfile_paths = natsorted(
glob.glob(f"{self.root_dir}/data/Brennan2018/raw/*.mat")
)
self.audio_paths = natsorted(
glob.glob(f"{self.root_dir}/data/Brennan2018/audio/*.wav")
)
self.onsets_path = f"{self.root_dir}/data/Brennan2018/AliceChapterOne-EEG.csv"
# Save Paths
X_path = f"{self.root_dir}/data/Brennan2018/X.pt"
Y_path = f"{self.root_dir}/data/Brennan2018/Y.pt"
sentence_idxs_path = f"{self.root_dir}/data/Brennan2018/sentence_idxs.npy"

# Rebuild dataset
if force_recompute or not (os.path.exists(X_path) and os.path.exists(Y_path)):
if args.rebuild_dataset or not (
os.path.exists(X_path) and os.path.exists(Y_path)
):
cprint(f"> Preprocessing EEG and audio.", color="cyan")

self.X, self.Y, self.sentence_idxs = self.rebuild_dataset(
@@ -109,15 +114,18 @@ def __init__(self, args):
self.num_subjects = self.X.shape[1]
cprint(f">>> Number of subjects: {self.num_subjects}", color="cyan")

cprint(f">> Upsampling audio embedding with: {args.preprocs.y_upsample}", color="cyan")
cprint(
f">> Upsampling audio embedding with: {args.preprocs.y_upsample}",
color="cyan",
)
if args.preprocs.y_upsample == "interpolate":
self.Y = interpolate_y_time(self.Y, self.brain_num_samples)
elif args.preprocs.y_upsample == "pad":
self.Y = pad_y_time(self.Y, self.brain_num_samples)
else:
raise ValueError(f"Unknown upsampling strategy: {args.preprocs.y_upsample}")

if chance:
if args.chance:
self.Y = self.Y[torch.randperm(len(self.Y))]

def __len__(self):
@@ -144,11 +152,11 @@ def rebuild_dataset(
# ----------------------
waveform = [torchaudio.load(path) for path in audio_paths]

audio_rate = self.get_audio_rate(waveform)
audio_rate = self._get_audio_rate(waveform)

waveform = torch.cat([w[0] for w in waveform], dim=1) # ( 1, time@44.1kHz )

audio = self.resample_audio(waveform, audio_rate, AUDIO_RESAMPLE_RATE)
audio = self._resample_audio(waveform, audio_rate, AUDIO_RESAMPLE_RATE)

cprint(
f">>> Resampled audio {audio_rate}Hz -> {AUDIO_RESAMPLE_RATE}Hz | shape: {waveform.shape} -> {audio.shape}",
@@ -159,12 +167,14 @@
# EEG Loading
# ----------------------
matfile_paths = [
path for path in matfile_paths if not path.split(".")[0][-3:] in EXCLUDED_SUBJECTS
path
for path in matfile_paths
if not path.split(".")[0][-3:] in EXCLUDED_SUBJECTS
]
mat_raws = [scipy.io.loadmat(path)["raw"][0, 0] for path in matfile_paths]
eeg_raws = [mat_raw["trial"][0, 0] for mat_raw in mat_raws]

eeg_rate = self.get_eeg_rate(mat_raws)
eeg_rate = self._get_eeg_rate(mat_raws)

# ----------------------
# Preprocessing
@@ -183,7 +193,9 @@

if self.split_mode == "sentence":
cprint(">> Dropping last segments of each sentence.", color="cyan")
X, audio, sentence_idxs = self.drop_last_segments(X, audio, onsets_path)

X, audio, sentence_idxs = self._drop_last_segments(X, audio, onsets_path)

cprint(f">>> X (EEG): {X.shape} | Audio: {audio.shape}", color="cyan")
else:
sentence_idxs = None
@@ -201,46 +213,46 @@

return X, Y, sentence_idxs

def drop_last_segments(
self, X: torch.Tensor, audio: torch.Tensor, onsets_path: str
@staticmethod
def _drop_last_segments(
X: torch.Tensor, audio: torch.Tensor, onsets_path: str
) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
"""Drops last segments of each sentence.
FIXME: currently drops last 5 words, but this number should be variable to cover 3 secs.
Args:
X: ( segment, subject, channel, time@120Hz//segment )
audio: ( segment, 1, time@16kHz//segment )
Returns:
X: ( _segment, subject, channel, time@120Hz//segment )
audio: ( _segment, 1, time@16kHz//segment )
"""
num_drops = 5
NUM_DROPS = 5

sentence_idxs = pd.read_csv(onsets_path).Sentence.to_numpy()
assert np.all(
np.diff(sentence_idxs) >= 0
), "sentence_idxs is not a non-decreasing step sequence."

sentence_ends = np.where(np.diff(sentence_idxs) > 0)[0]
# NOTE: end idx for the last sentence
sentence_ends = np.append(sentence_ends, sentence_idxs.shape[0] - 1)

# NOTE: There are sentences that are shorter than NUM_DROPS, but the overlap is not
# a problem as we convert it to a boolean array.
# assert np.all(
# np.diff(np.append(-1, sentence_ends)) > NUM_DROPS
# ), f"Some sentence(s) are not longer than NUM_DROPS={NUM_DROPS}."

drop_idxs = np.concatenate([np.arange(i - num_drops, i) + 1 for i in sentence_ends])
drop_idxs = np.concatenate(
[np.arange(i - NUM_DROPS, i) + 1 for i in sentence_ends]
)

# NOTE: to boolean array
drop_bools = np.ones(len(X), dtype=bool)
drop_bools[drop_idxs] = False

return X[drop_bools], audio[drop_bools], sentence_idxs[drop_bools]
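
A toy trace of the masking logic above, with `NUM_DROPS` shrunk to 2 for brevity:

```python
import numpy as np

sentence_idxs = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1])  # two toy "sentences"
NUM_DROPS = 2  # the real code uses 5

sentence_ends = np.where(np.diff(sentence_idxs) > 0)[0]           # [2]
sentence_ends = np.append(sentence_ends, len(sentence_idxs) - 1)  # [2, 9]

drop_idxs = np.concatenate([np.arange(i - NUM_DROPS, i) + 1 for i in sentence_ends])
keep = np.ones(len(sentence_idxs), dtype=bool)
keep[drop_idxs] = False
print(keep)  # [ True False False  True  True  True  True  True False False]
```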

def resample_audio(
self, waveform: torch.Tensor, sample_rate: int, resample_rate: int = 16000
) -> torch.Tensor:
"""Resamples audio to 16kHz. (16kHz is required by wav2vec2.0)"""

waveform = F.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=self.lowpass_filter_width,
)

len_audio_s = waveform.shape[1] / resample_rate
cprint(f">>> Audio length: {len_audio_s} s.", color="cyan")

return waveform

def embed_audio(self, audio: torch.Tensor) -> torch.Tensor:
"""
Args:
@@ -316,7 +328,6 @@ def resample_brain(

eeg_raw = torch.from_numpy(eeg_raw.astype(np.float32))

# NOTE: in the paper they say they downsampled to exactly 120Hz with Torchaudio, so I'll stick to that
eeg_resampled = F.resample(
waveform=eeg_raw,
orig_freq=fsample,
@@ -328,16 +339,20 @@

return torch.stack(X)

def get_audio_rate(self, waveform: List[Tuple[torch.Tensor, int]]) -> int:
@staticmethod
def _get_audio_rate(waveform: List[Tuple[torch.Tensor, int]]) -> int:
sample_rates = np.array([w[1] for w in waveform])
# is all 44.1kHz
assert np.all(sample_rates == sample_rates[0])

return sample_rates[0]

def get_eeg_rate(self, mat_raws: List) -> int:
@staticmethod
def _get_eeg_rate(mat_raws: List) -> int:
sample_rates = np.array([mat_raw["fsample"][0, 0] for mat_raw in mat_raws])
# is all 500Hz
assert np.all(sample_rates == sample_rates[0]), "Wrong EEG sampling rate detected."
assert np.all(
sample_rates == sample_rates[0]
), "Wrong EEG sampling rate detected."

return sample_rates[0]