Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from classification.exceptions import ClassificationError
from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS
from classification.model import SleepStagesClassifier
from classification.request import ClassificationRequest
from classification.response import ClassificationResponse
from classification.features.preprocessing import preprocess
from classification.spectrogram_generator import SpectrogramGenerator

app = Flask(__name__)
model = SleepStagesClassifier()
sleep_stage_classifier = SleepStagesClassifier()


def allowed_file(filename):
Expand Down Expand Up @@ -46,29 +50,28 @@ def analyze_sleep():
return 'File format not allowed', HTTPStatus.BAD_REQUEST

form_data = request.form.to_dict()
raw_array = get_raw_array(file)

try:
age = int(form_data['age'])
sex = Sex[form_data['sex']]
stream_start = int(form_data['stream_start'])
bedtime = int(form_data['bedtime'])
wakeup = int(form_data['wakeup'])
except (KeyError, ValueError):
classification_request = ClassificationRequest(
age=int(form_data['age']),
sex=Sex[form_data['sex']],
stream_start=int(form_data['stream_start']),
bedtime=int(form_data['bedtime']),
wakeup=int(form_data['wakeup']),
raw_eeg=raw_array,
)
except (KeyError, ValueError, ClassificationError):
return 'Missing or invalid request parameters', HTTPStatus.BAD_REQUEST

try:
raw_array = get_raw_array(file)
model.predict(raw_array, info={
'sex': sex,
'age': age,
'in_bed_seconds': bedtime - stream_start,
'out_of_bed_seconds': wakeup - stream_start
})
except ClassificationError as e:
return e.message, HTTPStatus.BAD_REQUEST

with open("assets/mock_response.json", "r") as mock_response_file:
return mock_response_file.read()
preprocessed_epochs = preprocess(classification_request)
predictions = sleep_stage_classifier.predict(preprocessed_epochs, classification_request)
spectrogram_generator = SpectrogramGenerator(preprocessed_epochs)
classification_response = ClassificationResponse(
classification_request, predictions, spectrogram_generator.generate()
)

return classification_response.response


CORS(app,
Expand Down
14 changes: 10 additions & 4 deletions backend/classification/config/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ class Sex(Enum):
M = 2


class SleepStage(Enum):
W = 0
N1 = 1
N2 = 2
N3 = 3
REM = 4


class HiddenMarkovModelProbability(Enum):
emission = auto()
start = auto()
Expand All @@ -20,8 +28,8 @@ def get_filename(self):
ALLOWED_FILE_EXTENSIONS = ('.txt', '.csv')

EEG_CHANNELS = [
'EEG Fpz-Cz',
'EEG Pz-Oz'
'Fpz-Cz',
'Pz-Oz'
]

EPOCH_DURATION = 30
Expand All @@ -38,5 +46,3 @@ def get_filename(self):
[85, 125]
]
ACCEPTED_AGE_RANGE = [AGE_FEATURE_BINS[0][0], AGE_FEATURE_BINS[-1][-1]]

N_STAGES = 5
15 changes: 4 additions & 11 deletions backend/classification/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,17 @@
)


def get_features(signal, info):
def get_features(signal, request):
"""Returns the raw features
Input:
- raw_eeg: instance of mne.io.RawArray
Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ)
- info: dict
Should contain the following keys:
- sex: instance of Sex enum
- age: indicates the subject's age
- in_bed_seconds: timespan, in seconds, from which
the subject started the recording and went to bed
- out_of_bed_seconds: timespan, in seconds, from which
the subject started the recording and got out of bed
- info: instance of ClassificationRequest
Returns
-------
- features X in a vector of (nb_epochs, nb_features)
"""
X_eeg = get_eeg_features(signal, info['in_bed_seconds'], info['out_of_bed_seconds'])
X_categorical = get_non_eeg_features(info['age'], info['sex'], X_eeg.shape[0])
X_eeg = get_eeg_features(signal, request.in_bed_seconds, request.out_of_bed_seconds)
X_categorical = get_non_eeg_features(request.age, request.sex, X_eeg.shape[0])

return np.append(X_categorical, X_eeg, axis=1).astype(np.float32)
19 changes: 9 additions & 10 deletions backend/classification/features/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
AGE_FEATURE_BINS,
)
from classification.features.pipeline import get_feature_union
from classification.features.preprocessing import preprocess


def get_eeg_features(raw_data, in_bed_seconds, out_of_bed_seconds):
def get_eeg_features(epochs, in_bed_seconds, out_of_bed_seconds):
"""Returns the continuous feature matrix
Input
-------
raw_signal: MNE.Raw object with signals with or without annotations
epochs: mne.Epochs object with signals with or without annotations
in_bed_seconds: timespan, in seconds, from which the subject started
the recording and went to bed
out_of_bed_seconds: timespan, in seconds, from which the subject
Expand All @@ -23,21 +22,21 @@ def get_eeg_features(raw_data, in_bed_seconds, out_of_bed_seconds):
-------
Array of size (nb_epochs, nb_continuous_features)
"""
features_file = []
features = []
feature_union = get_feature_union()

for channel in EEG_CHANNELS:
chan_data = preprocess(raw_data, channel, in_bed_seconds, out_of_bed_seconds)
channel_epochs = epochs.copy().pick_channels({channel})
channel_features = feature_union.transform(channel_epochs)

X_features = feature_union.transform(chan_data)
features_file.append(X_features)
features.append(channel_features)

print(
f"Done extracting {X_features.shape[1]} features "
f"on {X_features.shape[0]} epochs for {channel}\n"
f"Done extracting {channel_features.shape[1]} features "
f"on {channel_features.shape[0]} epochs for {channel}\n"
)

return np.hstack(tuple(features_file))
return np.hstack(tuple(features))


def get_non_eeg_features(age, sex, nb_epochs):
Expand Down
26 changes: 8 additions & 18 deletions backend/classification/features/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,26 @@
)


def preprocess(raw_data, channel, bed_seconds, out_of_bed_seconds):
def preprocess(classification_request):
"""Returns preprocessed epochs of the specified channel
Input
-------
raw_data: instance of mne.Raw
channel: channel to preprocess
bed_seconds: number of seconds between start of recording & moment at
which the subjet went to bed (in seconds)
out_of_bed_seconds: number of seconds between start of recording & moment
at which the subjet got out of bed (in seconds)
classification_request: instance of ClassificationRequest
"""
raw_data = raw_data.copy()
raw_data = classification_request.raw_eeg.copy()

raw_data = _drop_other_channels(raw_data, channel)
raw_data = _crop_raw_data(raw_data, bed_seconds, out_of_bed_seconds)
raw_data = _crop_raw_data(
raw_data,
classification_request.in_bed_seconds,
classification_request.out_of_bed_seconds,
)
raw_data = _apply_high_pass_filter(raw_data)
raw_data = raw_data.resample(DATASET_SAMPLE_RATE)
raw_data = _convert_to_epochs(raw_data)

return raw_data


def _drop_other_channels(raw_data, channel_to_keep):
"""returns mne.Raw with only the channel to keep"""
raw_data.drop_channels(
[ch for ch in raw_data.info['ch_names'] if ch != channel_to_keep])

return raw_data


def _crop_raw_data(
raw_data,
bed_seconds,
Expand Down
32 changes: 17 additions & 15 deletions backend/classification/file_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
The Cyton board logging format is also described here:
[https://docs.openbci.com/docs/02Cyton/CytonSDCard#data-logging-format]
"""
from io import StringIO
from mne import create_info
from mne.io import RawArray
import numpy as np

from classification.exceptions import ClassificationError
from classification.config.constants import (
EEG_CHANNELS,
OPENBCI_CYTON_SAMPLE_RATE,
Expand All @@ -31,6 +31,7 @@

FILE_COLUMN_OFFSET = 1
CYTON_TOTAL_NB_CHANNELS = 8
SKIP_ROWS = 2


def get_raw_array(file):
Expand All @@ -40,31 +41,32 @@ def get_raw_array(file):
Returns:
- mne.RawArray of the two EEG channels of interest
"""
file_content = StringIO(file.stream.read().decode("UTF8"))
lines = file.readlines()
eeg_raw = np.zeros((len(lines) - SKIP_ROWS, len(EEG_CHANNELS)))

eeg_raw = []
for line in file_content.readlines():
line_splitted = line.split(',')
for index, line in enumerate(lines[SKIP_ROWS:]):
line_splitted = line.decode('utf-8').split(',')

if len(line_splitted) >= CYTON_TOTAL_NB_CHANNELS:
eeg_raw.append(_get_decimals_from_hexadecimal_strings(line_splitted))
if len(line_splitted) < CYTON_TOTAL_NB_CHANNELS:
raise ClassificationError()

eeg_raw = SCALE_V_PER_COUNT * np.array(eeg_raw, dtype='object')
eeg_raw[index] = _get_decimals_from_hexadecimal_strings(line_splitted)

raw_object = RawArray(
np.transpose(eeg_raw),
SCALE_V_PER_COUNT * np.transpose(eeg_raw),
info=create_info(
ch_names=EEG_CHANNELS,
sfreq=OPENBCI_CYTON_SAMPLE_RATE,
ch_types='eeg'),
verbose=False,
)

print('First sample values: ', raw_object[:, 0])
print('Second sample values: ', raw_object[:, 1])
print('Number of samples: ', raw_object.n_times)
print('Duration of signal (h): ', raw_object.n_times / (3600 * OPENBCI_CYTON_SAMPLE_RATE))
print('Channel names: ', raw_object.ch_names)
print(f"""
First sample values: {raw_object[:, 0]}
Second sample values: {raw_object[:, 1]}
Number of samples: {raw_object.n_times}
Duration of signal (h): {raw_object.n_times / (3600 * OPENBCI_CYTON_SAMPLE_RATE)}
Channel names: {raw_object.ch_names}
""")

return raw_object

Expand Down
19 changes: 4 additions & 15 deletions backend/classification/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""defines models which predict sleep stages based off EEG signals"""

from classification.features import get_features
from classification.validation import validate
from classification.postprocessor import get_hmm_model
from classification.load_model import load_model, load_hmm

Expand All @@ -14,32 +13,22 @@ def __init__(self):
self.postprocessor_state = load_hmm()
self.postprocessor = get_hmm_model(self.postprocessor_state)

def predict(self, raw_eeg, info):
def predict(self, epochs, request):
"""
Input:
- raw_eeg: instance of mne.io.RawArray
- raw_eeg: instance of mne.Epochs
Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ)
- info: dict
Should contain the following keys:
- sex: instance of Sex enum
- age: indicates the subject's age
- in_bed_seconds: timespan, in seconds, from which
the subject started the recording and went to bed
- out_of_bed_seconds: timespan, in seconds, from which
the subject started the recording and got out of bed
- request: instance of ClassificationRequest
Returns: array of predicted sleep stages
"""

validate(raw_eeg, info)
features = get_features(raw_eeg, info)
features = get_features(epochs, request)

print(features, features.shape)

predictions = self._get_predictions(features)
predictions = self._get_postprocessed_predictions(predictions)

print(predictions)

return predictions

def _get_predictions(self, features):
Expand Down
4 changes: 2 additions & 2 deletions backend/classification/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from classification.config.constants import (
HiddenMarkovModelProbability,
N_STAGES,
SleepStage,
)


Expand All @@ -15,7 +15,7 @@ def get_hmm_model(state):
describes the according hidden markov model state
Returns: an instance of a trained MultinomialHMM
"""
hmm_model = MultinomialHMM(n_components=N_STAGES)
hmm_model = MultinomialHMM(n_components=len(SleepStage))

hmm_model.emissionprob_ = state[HiddenMarkovModelProbability.emission.name]
hmm_model.startprob_ = state[HiddenMarkovModelProbability.start.name]
Expand Down
Loading