diff --git a/backend/app.py b/backend/app.py index 11e736d9..226311b3 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,9 +7,13 @@ from classification.exceptions import ClassificationError from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS from classification.model import SleepStagesClassifier +from classification.request import ClassificationRequest +from classification.response import ClassificationResponse +from classification.features.preprocessing import preprocess +from classification.spectrogram_generator import SpectrogramGenerator app = Flask(__name__) -model = SleepStagesClassifier() +sleep_stage_classifier = SleepStagesClassifier() def allowed_file(filename): @@ -46,29 +50,28 @@ def analyze_sleep(): return 'File format not allowed', HTTPStatus.BAD_REQUEST form_data = request.form.to_dict() + raw_array = get_raw_array(file) try: - age = int(form_data['age']) - sex = Sex[form_data['sex']] - stream_start = int(form_data['stream_start']) - bedtime = int(form_data['bedtime']) - wakeup = int(form_data['wakeup']) - except (KeyError, ValueError): + classification_request = ClassificationRequest( + age=int(form_data['age']), + sex=Sex[form_data['sex']], + stream_start=int(form_data['stream_start']), + bedtime=int(form_data['bedtime']), + wakeup=int(form_data['wakeup']), + raw_eeg=raw_array, + ) + except (KeyError, ValueError, ClassificationError): return 'Missing or invalid request parameters', HTTPStatus.BAD_REQUEST - try: - raw_array = get_raw_array(file) - model.predict(raw_array, info={ - 'sex': sex, - 'age': age, - 'in_bed_seconds': bedtime - stream_start, - 'out_of_bed_seconds': wakeup - stream_start - }) - except ClassificationError as e: - return e.message, HTTPStatus.BAD_REQUEST - - with open("assets/mock_response.json", "r") as mock_response_file: - return mock_response_file.read() + preprocessed_epochs = preprocess(classification_request) + predictions = sleep_stage_classifier.predict(preprocessed_epochs, classification_request) + spectrogram_generator = SpectrogramGenerator(preprocessed_epochs) + classification_response = ClassificationResponse( + classification_request, predictions, spectrogram_generator.generate() + ) + + return classification_response.response CORS(app, diff --git a/backend/classification/config/constants.py b/backend/classification/config/constants.py index 10414a75..cc0bb101 100644 --- a/backend/classification/config/constants.py +++ b/backend/classification/config/constants.py @@ -8,6 +8,14 @@ class Sex(Enum): M = 2 +class SleepStage(Enum): + W = 0 + N1 = 1 + N2 = 2 + N3 = 3 + REM = 4 + + class HiddenMarkovModelProbability(Enum): emission = auto() start = auto() @@ -20,8 +28,8 @@ def get_filename(self): ALLOWED_FILE_EXTENSIONS = ('.txt', '.csv') EEG_CHANNELS = [ - 'EEG Fpz-Cz', - 'EEG Pz-Oz' + 'Fpz-Cz', + 'Pz-Oz' ] EPOCH_DURATION = 30 @@ -38,5 +46,3 @@ def get_filename(self): [85, 125] ] ACCEPTED_AGE_RANGE = [AGE_FEATURE_BINS[0][0], AGE_FEATURE_BINS[-1][-1]] - -N_STAGES = 5 diff --git a/backend/classification/features/__init__.py b/backend/classification/features/__init__.py index c21ae168..ee76c1f9 100644 --- a/backend/classification/features/__init__.py +++ b/backend/classification/features/__init__.py @@ -6,24 +6,17 @@ ) -def get_features(signal, info): +def get_features(signal, request): """Returns the raw features Input: - raw_eeg: instance of mne.io.RawArray Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) - - info: dict - Should contain the following keys: - - sex: instance of Sex enum - - age: indicates the subject's age - - in_bed_seconds: timespan, in seconds, from which - the subject started the recording and went to bed - - out_of_bed_seconds: timespan, in seconds, from which - the subject started the recording and got out of bed + - info: instance of ClassificationRequest Returns ------- - features X in a vector of (nb_epochs, nb_features) """ - X_eeg = get_eeg_features(signal, info['in_bed_seconds'], info['out_of_bed_seconds']) - X_categorical = get_non_eeg_features(info['age'], info['sex'], X_eeg.shape[0]) + X_eeg = get_eeg_features(signal, request.in_bed_seconds, request.out_of_bed_seconds) + X_categorical = get_non_eeg_features(request.age, request.sex, X_eeg.shape[0]) return np.append(X_categorical, X_eeg, axis=1).astype(np.float32) diff --git a/backend/classification/features/extraction.py b/backend/classification/features/extraction.py index 494a9b00..e53667e9 100644 --- a/backend/classification/features/extraction.py +++ b/backend/classification/features/extraction.py @@ -6,14 +6,13 @@ AGE_FEATURE_BINS, ) from classification.features.pipeline import get_feature_union -from classification.features.preprocessing import preprocess -def get_eeg_features(raw_data, in_bed_seconds, out_of_bed_seconds): +def get_eeg_features(epochs, in_bed_seconds, out_of_bed_seconds): """Returns the continuous feature matrix Input ------- - raw_signal: MNE.Raw object with signals with or without annotations + epochs: mne.Epochs object with signals with or without annotations in_bed_seconds: timespan, in seconds, from which the subject started the recording and went to bed out_of_bed_seconds: timespan, in seconds, from which the subject @@ -23,21 +22,21 @@ def get_eeg_features(raw_data, in_bed_seconds, out_of_bed_seconds): ------- Array of size (nb_epochs, nb_continuous_features) """ - features_file = [] + features = [] feature_union = get_feature_union() for channel in EEG_CHANNELS: - chan_data = preprocess(raw_data, channel, in_bed_seconds, out_of_bed_seconds) + channel_epochs = epochs.copy().pick_channels({channel}) + channel_features = feature_union.transform(channel_epochs) - X_features = feature_union.transform(chan_data) - features_file.append(X_features) + features.append(channel_features) print( - f"Done extracting {X_features.shape[1]} features " - f"on {X_features.shape[0]} epochs for {channel}\n" + f"Done extracting {channel_features.shape[1]} features " + f"on {channel_features.shape[0]} epochs for {channel}\n" ) - return np.hstack(tuple(features_file)) + return np.hstack(tuple(features)) def get_non_eeg_features(age, sex, nb_epochs): diff --git a/backend/classification/features/preprocessing.py b/backend/classification/features/preprocessing.py index 0561baa7..eb1b82cd 100644 --- a/backend/classification/features/preprocessing.py +++ b/backend/classification/features/preprocessing.py @@ -12,21 +12,19 @@ ) -def preprocess(raw_data, channel, bed_seconds, out_of_bed_seconds): +def preprocess(classification_request): """Returns preprocessed epochs of the specified channel Input ------- - raw_data: instance of mne.Raw - channel: channel to preprocess - bed_seconds: number of seconds between start of recording & moment at - which the subjet went to bed (in seconds) - out_of_bed_seconds: number of seconds between start of recording & moment - at which the subjet got out of bed (in seconds) + classification_request: instance of ClassificationRequest """ - raw_data = raw_data.copy() + raw_data = classification_request.raw_eeg.copy() - raw_data = _drop_other_channels(raw_data, channel) - raw_data = _crop_raw_data(raw_data, bed_seconds, out_of_bed_seconds) + raw_data = _crop_raw_data( + raw_data, + classification_request.in_bed_seconds, + classification_request.out_of_bed_seconds, + ) raw_data = _apply_high_pass_filter(raw_data) raw_data = raw_data.resample(DATASET_SAMPLE_RATE) raw_data = _convert_to_epochs(raw_data) @@ -34,14 +32,6 @@ def preprocess(raw_data, channel, bed_seconds, out_of_bed_seconds): return raw_data -def _drop_other_channels(raw_data, channel_to_keep): - """returns mne.Raw with only the channel to keep""" - raw_data.drop_channels( - [ch for ch in raw_data.info['ch_names'] if ch != channel_to_keep]) - - return raw_data - - def _crop_raw_data( raw_data, bed_seconds, diff --git a/backend/classification/file_loading.py b/backend/classification/file_loading.py index c9307180..030de041 100644 --- a/backend/classification/file_loading.py +++ b/backend/classification/file_loading.py @@ -14,11 +14,11 @@ The Cyton board logging format is also described here: [https://docs.openbci.com/docs/02Cyton/CytonSDCard#data-logging-format] """ -from io import StringIO from mne import create_info from mne.io import RawArray import numpy as np +from classification.exceptions import ClassificationError from classification.config.constants import ( EEG_CHANNELS, OPENBCI_CYTON_SAMPLE_RATE, @@ -31,6 +31,7 @@ FILE_COLUMN_OFFSET = 1 CYTON_TOTAL_NB_CHANNELS = 8 +SKIP_ROWS = 2 def get_raw_array(file): @@ -40,31 +41,32 @@ def get_raw_array(file): Returns: - mne.RawArray of the two EEG channels of interest """ - file_content = StringIO(file.stream.read().decode("UTF8")) + lines = file.readlines() + eeg_raw = np.zeros((len(lines) - SKIP_ROWS, len(EEG_CHANNELS))) - eeg_raw = [] - for line in file_content.readlines(): - line_splitted = line.split(',') + for index, line in enumerate(lines[SKIP_ROWS:]): + line_splitted = line.decode('utf-8').split(',') - if len(line_splitted) >= CYTON_TOTAL_NB_CHANNELS: - eeg_raw.append(_get_decimals_from_hexadecimal_strings(line_splitted)) + if len(line_splitted) < CYTON_TOTAL_NB_CHANNELS: + raise ClassificationError() - eeg_raw = SCALE_V_PER_COUNT * np.array(eeg_raw, dtype='object') + eeg_raw[index] = _get_decimals_from_hexadecimal_strings(line_splitted) raw_object = RawArray( - np.transpose(eeg_raw), + SCALE_V_PER_COUNT * np.transpose(eeg_raw), info=create_info( ch_names=EEG_CHANNELS, sfreq=OPENBCI_CYTON_SAMPLE_RATE, ch_types='eeg'), verbose=False, ) - - print('First sample values: ', raw_object[:, 0]) - print('Second sample values: ', raw_object[:, 1]) - print('Number of samples: ', raw_object.n_times) - print('Duration of signal (h): ', raw_object.n_times / (3600 * OPENBCI_CYTON_SAMPLE_RATE)) - print('Channel names: ', raw_object.ch_names) + print(f""" + First sample values: {raw_object[:, 0]} + Second sample values: {raw_object[:, 1]} + Number of samples: {raw_object.n_times} + Duration of signal (h): {raw_object.n_times / (3600 * OPENBCI_CYTON_SAMPLE_RATE)} + Channel names: {raw_object.ch_names} + """) return raw_object diff --git a/backend/classification/model.py b/backend/classification/model.py index db1f2a00..eb889153 100644 --- a/backend/classification/model.py +++ b/backend/classification/model.py @@ -1,7 +1,6 @@ """defines models which predict sleep stages based off EEG signals""" from classification.features import get_features -from classification.validation import validate from classification.postprocessor import get_hmm_model from classification.load_model import load_model, load_hmm @@ -14,32 +13,22 @@ def __init__(self): self.postprocessor_state = load_hmm() self.postprocessor = get_hmm_model(self.postprocessor_state) - def predict(self, raw_eeg, info): + def predict(self, epochs, request): """ Input: - - raw_eeg: instance of mne.io.RawArray + - raw_eeg: instance of mne.Epochs Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) - - info: dict - Should contain the following keys: - - sex: instance of Sex enum - - age: indicates the subject's age - - in_bed_seconds: timespan, in seconds, from which - the subject started the recording and went to bed - - out_of_bed_seconds: timespan, in seconds, from which - the subject started the recording and got out of bed + - request: instance of ClassificationRequest Returns: array of predicted sleep stages """ - validate(raw_eeg, info) - features = get_features(raw_eeg, info) + features = get_features(epochs, request) print(features, features.shape) predictions = self._get_predictions(features) predictions = self._get_postprocessed_predictions(predictions) - print(predictions) - return predictions def _get_predictions(self, features): diff --git a/backend/classification/postprocessor.py b/backend/classification/postprocessor.py index 260fdae9..bd918e17 100644 --- a/backend/classification/postprocessor.py +++ b/backend/classification/postprocessor.py @@ -2,7 +2,7 @@ from classification.config.constants import ( HiddenMarkovModelProbability, - N_STAGES, + SleepStage, ) @@ -15,7 +15,7 @@ def get_hmm_model(state): describes the according hidden markov model state Returns: an instance of a trained MultinomialHMM """ - hmm_model = MultinomialHMM(n_components=N_STAGES) + hmm_model = MultinomialHMM(n_components=len(SleepStage)) hmm_model.emissionprob_ = state[HiddenMarkovModelProbability.emission.name] hmm_model.startprob_ = state[HiddenMarkovModelProbability.start.name] diff --git a/backend/classification/request.py b/backend/classification/request.py new file mode 100644 index 00000000..38b1e2ee --- /dev/null +++ b/backend/classification/request.py @@ -0,0 +1,73 @@ + +from classification.config.constants import EPOCH_DURATION +from classification.config.constants import ( + FILE_MINIMUM_DURATION, + ACCEPTED_AGE_RANGE, +) +from classification.exceptions import ( + TimestampsError, + FileSizeError, + ClassificationError, +) + + +class ClassificationRequest(): + def __init__(self, sex, age, stream_start, bedtime, wakeup, raw_eeg): + self.sex = sex + self.age = age + self.stream_start = stream_start + self.bedtime = bedtime + self.wakeup = wakeup + + self.stream_duration = raw_eeg.times[-1] + self.raw_eeg = raw_eeg + + self._validate() + + @property + def in_bed_seconds(self): + """timespan, in seconds, from which the subject started the recording and went to bed""" + return self.bedtime - self.stream_start + + @property + def out_of_bed_seconds(self): + """timespan, in seconds, from which the subject started the recording and got out of bed""" + return self.wakeup - self.stream_start + + @property + def n_epochs(self): + return (self.wakeup - self.bedtime) / EPOCH_DURATION + + def _validate(self): + self._validate_timestamps() + self._validate_file_with_timestamps() + self._validate_age() + + def _validate_timestamps(self): + has_positive_timespan = self.bedtime > self.stream_start and self.wakeup > self.stream_start + has_got_out_of_bed_after_in_bed = self.wakeup > self.bedtime + has_respected_minimum_bed_time = (self.wakeup - self.bedtime) > FILE_MINIMUM_DURATION + + if not( + has_positive_timespan + and has_got_out_of_bed_after_in_bed + and has_respected_minimum_bed_time + ): + raise TimestampsError() + + def _validate_file_with_timestamps(self): + has_raw_respected_minimum_file_size = self.raw_eeg.times[-1] > FILE_MINIMUM_DURATION + + if not has_raw_respected_minimum_file_size: + raise FileSizeError() + + is_raw_at_least_as_long_as_out_of_bed = self.raw_eeg.times[-1] >= self.out_of_bed_seconds + + if not is_raw_at_least_as_long_as_out_of_bed: + raise TimestampsError() + + def _validate_age(self): + is_in_accepted_range = ACCEPTED_AGE_RANGE[0] <= int(self.age) <= ACCEPTED_AGE_RANGE[1] + + if not(is_in_accepted_range): + raise ClassificationError('invalid age') diff --git a/backend/classification/response.py b/backend/classification/response.py new file mode 100644 index 00000000..dce5a2b3 --- /dev/null +++ b/backend/classification/response.py @@ -0,0 +1,56 @@ +import numpy as np + +from classification.config.constants import EPOCH_DURATION, SleepStage + + +class ClassificationResponse(): + def __init__(self, request, predictions, spectrogram): + self.sex = request.sex + self.age = request.age + self.stream_start = request.stream_start + self.stream_duration = request.stream_duration + self.bedtime = request.bedtime + self.wakeup = request.wakeup + self.n_epochs = request.n_epochs + + self.spectrogram = spectrogram + self.predictions = predictions + + @property + def sleep_stages(self): + ordered_sleep_stage_names = np.array([SleepStage(stage_index).name for stage_index in range(len(SleepStage))]) + return ordered_sleep_stage_names[self.predictions] + + @property + def epochs(self): + timestamps = np.arange(self.n_epochs * EPOCH_DURATION, step=EPOCH_DURATION) + self.bedtime + return {'timestamps': timestamps.tolist(), 'stages': self.sleep_stages.tolist()} + + @property + def metadata(self): + return { + "sessionStartTime": self.stream_start, + "sessionEndTime": self.stream_duration + self.stream_start, + "totalSessionTime": self.stream_duration, + "bedTime": self.bedtime, + "wakeUpTime": None, + "totalBedTime": None, + } + + @property + def subject(self): + return { + 'age': self.age, + 'sex': self.sex.name, + } + + @property + def response(self): + return { + 'epochs': self.epochs, + 'report': None, + 'metadata': self.metadata, + 'subject': self.subject, + 'board': None, + 'spectrograms': self.spectrogram, + } diff --git a/backend/classification/spectrogram_generator.py b/backend/classification/spectrogram_generator.py new file mode 100644 index 00000000..d4f66296 --- /dev/null +++ b/backend/classification/spectrogram_generator.py @@ -0,0 +1,34 @@ +from itertools import chain + +from mne.time_frequency import psd_welch +import numpy as np + +from classification.features.constants import FREQ_BANDS_RANGE +from classification.config.constants import EEG_CHANNELS + + +class SpectrogramGenerator(): + def __init__(self, epochs): + self.epochs = epochs + + range_frequencies = set(chain(*FREQ_BANDS_RANGE.values())) + self.spectrogram_min_freq = min(range_frequencies) + self.spectrogram_max_freq = max(range_frequencies) + + def generate(self): + psds, freqs = psd_welch( + self.epochs, + fmin=self.spectrogram_min_freq, + fmax=self.spectrogram_max_freq, + ) + psds_db = self._convert_amplitudes_to_decibel(psds) + + spectrogram = {'frequencies': freqs.tolist()} + + for index, eeg_channel in enumerate(EEG_CHANNELS): + spectrogram[eeg_channel.lower()] = psds_db[:, index, :].tolist() + + return spectrogram + + def _convert_amplitudes_to_decibel(self, amplitudes): + return 10 * np.log10(np.maximum(amplitudes, np.finfo(float).tiny)) diff --git a/backend/classification/validation.py b/backend/classification/validation.py deleted file mode 100644 index 587f928e..00000000 --- a/backend/classification/validation.py +++ /dev/null @@ -1,47 +0,0 @@ -from classification.config.constants import ( - FILE_MINIMUM_DURATION, - ACCEPTED_AGE_RANGE, -) -from classification.exceptions import ( - TimestampsError, - FileSizeError, - ClassificationError, -) - - -def validate(raw_eeg, info): - _validate_timestamps(info['in_bed_seconds'], info['out_of_bed_seconds']) - _validate_file_with_timestamps(raw_eeg, info['out_of_bed_seconds']) - _validate_age(info['age']) - - -def _validate_timestamps(in_bed_seconds, out_of_bed_seconds): - has_positive_timespan = in_bed_seconds > 0 and out_of_bed_seconds > 0 - has_got_out_of_bed_after_in_bed = out_of_bed_seconds > in_bed_seconds - has_respected_minimum_bed_time = (out_of_bed_seconds - in_bed_seconds) > FILE_MINIMUM_DURATION - - if not( - has_positive_timespan - and has_got_out_of_bed_after_in_bed - and has_respected_minimum_bed_time - ): - raise TimestampsError() - - -def _validate_file_with_timestamps(raw_eeg, out_of_bed_seconds): - has_raw_respected_minimum_file_size = raw_eeg.times[-1] > FILE_MINIMUM_DURATION - - if not has_raw_respected_minimum_file_size: - raise FileSizeError() - - is_raw_at_least_as_long_as_out_of_bed = raw_eeg.times[-1] >= out_of_bed_seconds - - if not is_raw_at_least_as_long_as_out_of_bed: - raise TimestampsError() - - -def _validate_age(age): - is_in_accepted_range = ACCEPTED_AGE_RANGE[0] <= int(age) <= ACCEPTED_AGE_RANGE[1] - - if not(is_in_accepted_range): - raise ClassificationError('invalid age') diff --git a/polydodo.code-workspace b/polydodo.code-workspace index fe66e646..02033be8 100644 --- a/polydodo.code-workspace +++ b/polydodo.code-workspace @@ -43,7 +43,8 @@ "python.linting.enabled": true, "python.linting.flake8Enabled": true, "python.formatting.provider": "autopep8", - "python.pythonPath": "/usr/bin/python" + "python.pythonPath": "/usr/bin/python", + "git.pullTags": false }, "extensions": { "recommendations": [