Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backend/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Configure the "backend" package logger at import time so that every module
# under backend/ (they log via logging.getLogger(__name__), whose records
# propagate up to this package logger) emits INFO-level output to stdout.
from logging import INFO

from config.logging import config_logger

config_logger(
    logger_name=__name__,
    log_level=INFO,
)
16 changes: 16 additions & 0 deletions backend/backend/analyze_sleep.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import falcon
import logging

from backend.request import ClassificationRequest
from backend.response import ClassificationResponse
Expand All @@ -10,9 +11,12 @@
from classification.model import SleepStagesClassifier
from classification.features.preprocessing import preprocess

_logger = logging.getLogger(__name__)


class AnalyzeSleep:
def __init__(self):
    # Constructing the classifier loads the trained model (which, per
    # classification/load_model.py, may be downloaded on first use) —
    # log first so startup isn't silent while that happens.
    _logger.info("Initializing sleep stage classifier.")
    self.sleep_stage_classifier = SleepStagesClassifier()

@staticmethod
Expand Down Expand Up @@ -55,6 +59,7 @@ def on_post(self, request, response):
}
"""

_logger.info("Validating and parsing form fields and EEG file")
try:
form_data, file = self._parse_form(request.get_media())
raw_array = get_raw_array(file)
Expand All @@ -67,16 +72,27 @@ def on_post(self, request, response):
raw_eeg=raw_array,
)
except (KeyError, ValueError, ClassificationError):
_logger.warn(
Comment thread
abelfodil marked this conversation as resolved.
"An error occured when validating and parsing form fields. "
"Request parameters are either missing or invalid."
)
response.status = falcon.HTTP_400
response.content_type = falcon.MEDIA_TEXT
response.body = 'Missing or invalid request parameters'
return

_logger.info("Preprocessing of raw EEG data.")
preprocessed_epochs = preprocess(classification_request)

_logger.info("Prediction of EEG data to sleep stages.")
predictions = self.sleep_stage_classifier.predict(preprocessed_epochs, classification_request)

_logger.info("Computations of visualisation data & of sleep report metrics...")
spectrogram_generator = SpectrogramGenerator(preprocessed_epochs)
classification_response = ClassificationResponse(
classification_request, predictions, spectrogram_generator.generate()
)

response.body = json.dumps(classification_response.response)

_logger.info("Request completed")
7 changes: 7 additions & 0 deletions backend/backend/app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import falcon
import logging

from backend.ping import Ping
from backend.analyze_sleep import AnalyzeSleep

_logger = logging.getLogger(__name__)


def App():
app = falcon.App(cors_enable=True)
Expand All @@ -13,4 +16,8 @@ def App():
analyze = AnalyzeSleep()
app.add_route('/analyze-sleep', analyze)

_logger.info(
'Completed local server initialization. '
'Please go back to your browser in order to submit your sleep EEG file. '
)
return app
9 changes: 9 additions & 0 deletions backend/backend/request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging

from classification.config.constants import EPOCH_DURATION
from classification.config.constants import (
Expand All @@ -10,6 +11,8 @@
ClassificationError,
)

_logger = logging.getLogger(__name__)


class ClassificationRequest():
def __init__(self, sex, age, stream_start, bedtime, wakeup, raw_eeg, stream_duration=None):
Expand Down Expand Up @@ -56,21 +59,27 @@ def _validate_timestamps(self):
and has_got_out_of_bed_after_in_bed
and has_respected_minimum_bed_time
):
_logger.warn("Received timestamps are invalid.")
raise TimestampsError()

def _validate_file_with_timestamps(self):
    """Validates the uploaded EEG file's duration against the submitted timestamps.

    Raises:
    - FileSizeError: when the recording is shorter than FILE_MINIMUM_DURATION seconds.
    - TimestampsError: when the recording ends before the reported out-of-bed time.
    """
    # raw_eeg.times[-1] is the timestamp (in seconds) of the last sample,
    # i.e. the total duration of the recording.
    has_raw_respected_minimum_file_size = self.raw_eeg.times[-1] > FILE_MINIMUM_DURATION

    if not has_raw_respected_minimum_file_size:
        # Logger.warn is deprecated in favour of Logger.warning.
        _logger.warning(f"Uploaded file must at least have {FILE_MINIMUM_DURATION} seconds of data.")
        raise FileSizeError()

    is_raw_at_least_as_long_as_out_of_bed = self.raw_eeg.times[-1] >= self.out_of_bed_seconds

    if not is_raw_at_least_as_long_as_out_of_bed:
        _logger.warning(
            "Uploaded file must at least last the time between the start of the "
            f"stream and out of bed time, which is {self.out_of_bed_seconds} seconds.")
        raise TimestampsError()

def _validate_age(self):
    """Validates that the submitted age falls within the accepted range.

    Raises:
    - ClassificationError: when the age is outside ACCEPTED_AGE_RANGE.
    """
    is_in_accepted_range = ACCEPTED_AGE_RANGE[0] <= int(self.age) <= ACCEPTED_AGE_RANGE[1]

    if not is_in_accepted_range:
        # Logger.warn is deprecated in favour of Logger.warning.
        _logger.warning(f"Age must be in the following range: {ACCEPTED_AGE_RANGE}")
        raise ClassificationError('invalid age')
1 change: 1 addition & 0 deletions backend/backend/spectrogram_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def generate(self):
self.epochs,
fmin=self.spectrogram_min_freq,
fmax=self.spectrogram_max_freq,
verbose=False,
)
psds_db = self._convert_amplitudes_to_decibel(psds)

Expand Down
9 changes: 9 additions & 0 deletions backend/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Configure the "classification" package logger at import time so every module
# under classification/ emits INFO-level output to stdout. message_sublevel=True
# indents these messages one extra tab (see config/logging.py), visually marking
# them as sub-steps of the higher-level backend request handling.
from logging import INFO

from config.logging import config_logger

config_logger(
    logger_name=__name__,
    log_level=INFO,
    message_sublevel=True,
)
5 changes: 0 additions & 5 deletions backend/classification/features/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def get_eeg_features(epochs, in_bed_seconds, out_of_bed_seconds):

features.append(channel_features)

print(
f"Done extracting {channel_features.shape[1]} features "
f"on {channel_features.shape[0]} epochs for {channel}\n"
)

return np.hstack(tuple(features))


Expand Down
2 changes: 1 addition & 1 deletion backend/classification/features/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_psds_from_epochs(epochs):
--------
psds with associated frequencies calculated with the welch method.
"""
psds, freqs = psd_welch(epochs, fmin=0.5, fmax=30.)
psds, freqs = psd_welch(epochs, fmin=0.5, fmax=30., verbose=False)
return psds, freqs


Expand Down
10 changes: 10 additions & 0 deletions backend/classification/features/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import mne
from scipy.signal import cheby1

Expand All @@ -11,6 +12,8 @@
HIGH_PASS_MAX_RIPPLE_DB,
)

_logger = logging.getLogger(__name__)


def preprocess(classification_request):
"""Returns preprocessed epochs of the specified channel
Expand All @@ -20,13 +23,20 @@ def preprocess(classification_request):
"""
raw_data = classification_request.raw_eeg.copy()

_logger.info("Cropping data from bed time to out of bed time.")
raw_data = _crop_raw_data(
raw_data,
classification_request.in_bed_seconds,
classification_request.out_of_bed_seconds,
)

_logger.info(f"Applying high pass filter at {DATASET_HIGH_PASS_FREQ}Hz.")
raw_data = _apply_high_pass_filter(raw_data)

_logger.info(f"Resampling data at the dataset's sampling rate of {DATASET_SAMPLE_RATE} Hz.")
raw_data = raw_data.resample(DATASET_SAMPLE_RATE)

_logger.info(f"Epoching data with a {EPOCH_DURATION} seconds duration.")
raw_data = _convert_to_epochs(raw_data)

return raw_data
Expand Down
15 changes: 13 additions & 2 deletions backend/classification/load_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
import logging
from os import path, makedirs
from pathlib import Path
import re
Expand All @@ -12,6 +13,10 @@

from classification.config.constants import HiddenMarkovModelProbability


_logger = logging.getLogger(__name__)


SCRIPT_PATH = Path(path.realpath(sys.argv[0])).parent

BUCKET_NAME = 'polydodo'
Expand Down Expand Up @@ -50,9 +55,13 @@ def _has_latest_object(filename, local_path):

def load_model():
    """Returns an ONNX runtime inference session for the sleep stage model.

    The model file is fetched from remote storage when it is missing locally
    or when the local copy is not the latest available object.
    """
    model_is_current = path.exists(MODEL_PATH) and _has_latest_object(MODEL_FILENAME, MODEL_PATH)

    if not model_is_current:
        _logger.info(
            "Downloading latest sleep stage classification model... "
            f"This could take a few minutes. (storing it at {MODEL_PATH})"
        )
        _download_file(MODEL_URL, MODEL_PATH)

    _logger.info(f"Loading latest sleep stage classification model... (from {MODEL_PATH})")
    return onnxruntime.InferenceSession(str(MODEL_PATH))


Expand All @@ -67,8 +76,10 @@ def load_hmm():
model_path = SCRIPT_PATH / HMM_FOLDER / hmm_file

if not path.exists(model_path) or not _has_latest_object(hmm_file, model_path):
_logger.info(f"Downloading postprocessing model... (storing it at {model_path})")
_download_file(url=f"{BUCKET_URL}/{hmm_file}", output=model_path)

_logger.info(f"Loading postprocessing model... (from {model_path})")
hmm_matrices[hmm_probability.name] = np.load(str(model_path))

return hmm_matrices
11 changes: 8 additions & 3 deletions backend/classification/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""defines models which predict sleep stages based off EEG signals"""
import logging

from classification.features import get_features
from classification.postprocessor import get_hmm_model
from classification.load_model import load_model, load_hmm

_logger = logging.getLogger(__name__)


class SleepStagesClassifier():
def __init__(self):
Expand All @@ -21,12 +24,14 @@ def predict(self, epochs, request):
- request: instance of ClassificationRequest
Returns: array of predicted sleep stages
"""

_logger.info("Extracting features...")
features = get_features(epochs, request)
_logger.info(f"Finished extracting {features.shape[1]} features over {features.shape[0]} epochs.")

print(features, features.shape)

_logger.info("Classifying sleep stages from extracted features...")
predictions = self._get_predictions(features)

_logger.info("Applying postprocessing step to the resulted sleep stages...")
predictions = self._get_postprocessed_predictions(predictions)

return predictions
Expand Down
31 changes: 18 additions & 13 deletions backend/classification/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
The Cyton board logging format is also described here:
[https://docs.openbci.com/docs/02Cyton/CytonSDCard#data-logging-format]
"""
import logging

from mne import create_info
from mne.io import RawArray

Expand All @@ -18,6 +20,9 @@
from classification.parser.sample_rate import detect_sample_rate


_logger = logging.getLogger(__name__)


def get_raw_array(file):
"""Converts a file following a logging format into a mne.RawArray
Input:
Expand All @@ -27,15 +32,15 @@ def get_raw_array(file):
"""

filetype = detect_file_type(file)
print(f"""
Detected {filetype.name} format.
""")

sample_rate = detect_sample_rate(file, filetype)
print(f"""
Detected {sample_rate}Hz sample rate.
""")

_logger.info(
f"EEG data has been detected to be in the {filetype.name} format "
f"and has a {sample_rate}Hz sample rate."
)

_logger.info("Parsing EEG file to a mne.RawArray object...")
eeg_raw = filetype.parser(file)

raw_object = RawArray(
Expand All @@ -47,12 +52,12 @@ def get_raw_array(file):
verbose=False,
)

print(f"""
First sample values: {raw_object[:, 0]}
Second sample values: {raw_object[:, 1]}
Number of samples: {raw_object.n_times}
Duration of signal (h): {raw_object.n_times / (3600 * sample_rate)}
Channel names: {raw_object.ch_names}
""")
_logger.info(
f"Finished converting EEG file to mne.RawArray object "
f"with the first sample being {*(raw_object[:, 0][0]),}, "
f"with {raw_object.n_times} samples, "
f"with a {raw_object.n_times / (3600 * sample_rate):.2f} hours duration and "
f"with channels named {raw_object.ch_names}."
)

return raw_object
20 changes: 20 additions & 0 deletions backend/config/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging
import sys

STD_OUTPUT_FORMAT = "[%(asctime)s - %(levelname)s]:\t%(message)s (%(name)s)"
SUBLEVEL_OUTPUT_FORMAT = "[%(asctime)s - %(levelname)s]:\t\t%(message)s (%(name)s)"


def config_logger(logger_name, log_level, message_sublevel=False):
"""Configures logging with std output"""
logger = logging.getLogger(logger_name)
logger.setLevel(log_level)
logger.addHandler(_get_console_handler(message_sublevel))
logger.propagate = False


def _get_console_handler(message_sublevel):
console_handler = logging.StreamHandler(sys.stdout)
formatter = SUBLEVEL_OUTPUT_FORMAT if message_sublevel else STD_OUTPUT_FORMAT
console_handler.setFormatter(logging.Formatter(formatter))
return console_handler
Comment thread
abelfodil marked this conversation as resolved.