From 94000f30a25a1b3c518d330a87647a07f9ea21b5 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 25 Oct 2020 21:25:26 -0400 Subject: [PATCH 01/24] added notebook prototype with ONNX --- ai/export_to_onnx.ipynb | 479 +++++++++++++++++++++ ai/models/RF_HMM_smaller.ipynb | 736 +++++++++++++++++++++++++++++++++ backend/requirements.txt | 4 +- 3 files changed, 1218 insertions(+), 1 deletion(-) create mode 100644 ai/export_to_onnx.ipynb create mode 100644 ai/models/RF_HMM_smaller.ipynb diff --git a/ai/export_to_onnx.ipynb b/ai/export_to_onnx.ipynb new file mode 100644 index 00000000..908f88bd --- /dev/null +++ b/ai/export_to_onnx.ipynb @@ -0,0 +1,479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save & reload trained model with ONNX\n", + "___\n", + "\n", + "This notebook aims to save, reload and check if the model can be correctly serialized through ONNX, and the Scikit-learn ONNX package." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import numpy as np\n", + "\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", + "from sklearn.ensemble import (RandomForestClassifier,\n", + " VotingClassifier)\n", + "from sklearn.metrics import (confusion_matrix,\n", + " classification_report,\n", + " cohen_kappa_score)\n", + "from skl2onnx import convert_sklearn\n", + "from skl2onnx.common.data_types import FloatTensorType\n", + "from skl2onnx.helpers.onnx_helper import save_onnx_model\n", + "from onnxruntime import InferenceSession\n", + "\n", + "from models.model_utils import (train_test_split_according_to_age)\n", + "from constants import (SLEEP_STAGES_VALUES,)" + ] + }, + { + "cell_type": "markdown", 
+ "metadata": {}, + "source": [ + "## Generate trained pipeline\n", + "____" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "SUBJECT_IDX = 0 \n", + "NIGHT_IDX = 1\n", + "USE_CONTINUOUS_AGE = False\n", + "DOWNSIZE_SET = False\n", + "TEST_SET_SUBJECTS = [0.0, 24.0, 49.0, 71.0]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(168954, 50)\n", + "(168954,)\n", + "Number of subjects: 78\n", + "Number of nights: 153\n", + "Subjects available: [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.\n", + " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.\n", + " 36. 37. 38. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.\n", + " 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 70. 71. 72. 73. 74.\n", + " 75. 76. 77. 80. 81. 82.]\n", + "Selected subjects for the test set are: [0.0, 24.0, 49.0, 71.0]\n", + "(8123, 50) (160831, 50) (8123,) (160831,)\n" + ] + } + ], + "source": [ + "def load_features():\n", + " if USE_CONTINUOUS_AGE:\n", + " X_file_name = \"data/x_features-age-continuous.npy\"\n", + " y_file_name = \"data/y_observations-age-continuous.npy\"\n", + " else:\n", + " X_file_name = \"data/x_features.npy\"\n", + " y_file_name = \"data/y_observations.npy\"\n", + "\n", + " X_init = np.load(X_file_name, allow_pickle=True)\n", + " y_init = np.load(y_file_name, allow_pickle=True)\n", + "\n", + " X_init = np.vstack(X_init)\n", + " y_init = np.hstack(y_init)\n", + "\n", + " print(X_init.shape)\n", + " print(y_init.shape)\n", + " print(\"Number of subjects: \", np.unique(X_init[:,SUBJECT_IDX]).shape[0]) # Some subject indexes are skipped, thus total number is below 83 (as we can see in https://physionet.org/content/sleep-edfx/1.0.0/)\n", + " print(\"Number of nights: \", len(np.unique([f\"{int(x[0])}-{int(x[1])}\" for x in 
X_init[:,SUBJECT_IDX:NIGHT_IDX+1]])))\n", + " print(\"Subjects available: \", np.unique(X_init[:,SUBJECT_IDX]))\n", + " \n", + " return X_init, y_init\n", + "\n", + "def split_data(X_init, y_init):\n", + " X_test, X_train_valid, y_test, y_train_valid = train_test_split_according_to_age(\n", + " X_init,\n", + " y_init,\n", + " use_continuous_age=USE_CONTINUOUS_AGE,\n", + " subjects_test=TEST_SET_SUBJECTS)\n", + " \n", + " print(X_test.shape, X_train_valid.shape, y_test.shape, y_train_valid.shape)\n", + " \n", + " return X_test, X_train_valid, y_test, y_train_valid\n", + "\n", + "X_init, y_init = load_features()\n", + "X_test, X_train_valid, y_test, y_train_valid = split_data(X_init, y_init)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1522 52 2 4 44]\n", + " [ 240 145 325 1 272]\n", + " [ 38 57 3210 188 110]\n", + " [ 4 0 31 576 0]\n", + " [ 57 83 264 0 898]]\n", + " precision recall f1-score support\n", + "\n", + " W 0.82 0.94 0.87 1624\n", + " N1 0.43 0.15 0.22 983\n", + " N2 0.84 0.89 0.86 3603\n", + " N3 0.75 0.94 0.83 611\n", + " REM 0.68 0.69 0.68 1302\n", + "\n", + " accuracy 0.78 8123\n", + " macro avg 0.70 0.72 0.70 8123\n", + "weighted avg 0.75 0.78 0.76 8123\n", + "\n", + "Agreement score (Cohen Kappa): 0.6913101923642638\n" + ] + } + ], + "source": [ + "def get_voting_classifier_pipeline():\n", + " NB_CATEGORICAL_FEATURES = 2\n", + " NB_FEATURES = 48\n", + "\n", + " estimator_list = [\n", + " ('random_forest', RandomForestClassifier(\n", + " random_state=42, # enables deterministic behaviour\n", + " n_jobs=-1\n", + " )),\n", + " ('knn', Pipeline([\n", + " ('knn_dim_red', LinearDiscriminantAnalysis()),\n", + " ('knn_clf', KNeighborsClassifier(\n", + " weights='uniform',\n", + " n_neighbors=300,\n", + " leaf_size=100,\n", + " metric='chebyshev',\n", + " n_jobs=-1\n", + " ))\n", + " ])),\n", + " ]\n", + " \n", + " return Pipeline([\n", + " 
('scaling', ColumnTransformer([\n", + " ('pass-through-categorical', 'passthrough', list(range(NB_CATEGORICAL_FEATURES))),\n", + " ('scaling-continuous', StandardScaler(copy=False), list(range(NB_CATEGORICAL_FEATURES,NB_FEATURES)))\n", + " ])),\n", + " ('voting_clf', VotingClassifier(\n", + " estimators=estimator_list,\n", + " voting='soft',\n", + " weights=np.array([0.83756205, 0.16243795]),\n", + " flatten_transform=False,\n", + " n_jobs=-1,\n", + " ))\n", + " ])\n", + "\n", + "testing_pipeline = get_voting_classifier_pipeline()\n", + "testing_pipeline.fit(X_train_valid[:, 2:], y_train_valid)\n", + "y_test_pred = testing_pipeline.predict(X_test[:,2:])\n", + "\n", + "print(confusion_matrix(y_test, y_test_pred))\n", + "print(classification_report(y_test, y_test_pred, target_names=SLEEP_STAGES_VALUES.keys()))\n", + "print(\"Agreement score (Cohen Kappa): \", cohen_kappa_score(y_test, y_test_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving pipeline to ONNX\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m initial_types=[(\n\u001b[1;32m 4\u001b[0m \u001b[0;34m'float_input'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mFloatTensorType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mX_train_valid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m )]\n\u001b[1;32m 7\u001b[0m )\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/convert.py\u001b[0m in \u001b[0;36mconvert_sklearn\u001b[0;34m(model, name, initial_types, doc_string, target_opset, custom_conversion_functions, custom_shape_calculators, custom_parsers, options, dtype, intermediate, white_op, black_op, final_types)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;31m# Convert our Topology object into ONNX. The outcome is an ONNX model.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m onnx_model = convert_topology(topology, name, doc_string, target_opset,\n\u001b[0;32m--> 154\u001b[0;31m dtype=dtype, options=options)\n\u001b[0m\u001b[1;32m 155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0monnx_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtopology\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mintermediate\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0monnx_model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/common/_topology.py\u001b[0m in \u001b[0;36mconvert_topology\u001b[0;34m(topology, model_name, doc_string, target_opset, channel_first_inputs, dtype, options)\u001b[0m\n\u001b[1;32m 1052\u001b[0m type(getattr(operator, 'raw_model', None))))\n\u001b[1;32m 1053\u001b[0m 
\u001b[0mcontainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_options\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moperator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1054\u001b[0;31m \u001b[0mconv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1055\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1056\u001b[0m \u001b[0;31m# Create a graph from its main components\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/common/_registration.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_operator\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_allowed_options\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_operator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0mget_allowed_options\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/operator_converters/voting_classifier.py\u001b[0m in \u001b[0;36mconvert_voting_classifier\u001b[0;34m(scope, operator, container)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0mop_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msklearn_operator_name_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mthis_operator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdeclare_local_operator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyError\u001b[0m: " + ] + } + ], + "source": [ + "onnx_pipeline = convert_sklearn(\n", + " testing_pipeline,\n", + " initial_types=[(\n", + " 'float_input',\n", + " FloatTensorType([None, X_train_valid[:,2:].shape[1]])\n", + " )]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see the Voting classifier conversion do not currently support a Pipeline typed estimator in its estimators list.\n", + "\n", + "Considering that;\n", + "- the option of adding a pipeline as an estimator in the voting classifier is not supported.\n", + "- the size of a KNearestNeighbor classifier would be too big without its LDA, and that the performance of the RandomForest would be significantly 
decreased with an LDA beforehand.\n", + "- the voting classifier had a Cohen Kappa agreement's score of 0.6913 on the testing set, whilst we obtained 0.6916 with the fat Random Forest, and we obtained 0.6879 with the skinny RF.\n", + "- the voting classifier had a Cohen Kappa agreement's score of 0.62 ± 0.043 on the validation set, whilst we obtained [TO BE DEFINED] with the fat Random Forest, and we obtained 0.62 ± 0.043 with the skinny RF (where the validation set is a CV of 5 partitions and considering subjects)\n", + "- the size of the small RF is 322.8 Mbytes\n", + "- the size of the fat RF is 1.91 Gbytes\n", + "- the size of the voting classifier is 376.8 Mbytes\n", + "\n", + "We have decided to temporarily use the skinny random forest, as its performance is quite similar to the voting classifier's and the fat random forest's." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate RF trained pipeline\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1512 65 3 3 41]\n", + " [ 220 147 332 0 284]\n", + " [ 39 45 3212 194 113]\n", + " [ 4 0 32 575 0]\n", + " [ 49 81 284 0 888]]\n", + " precision recall f1-score support\n", + "\n", + " W 0.83 0.93 0.88 1624\n", + " N1 0.43 0.15 0.22 983\n", + " N2 0.83 0.89 0.86 3603\n", + " N3 0.74 0.94 0.83 611\n", + " REM 0.67 0.68 0.68 1302\n", + "\n", + " accuracy 0.78 8123\n", + " macro avg 0.70 0.72 0.69 8123\n", + "weighted avg 0.75 0.78 0.75 8123\n", + "\n", + "Agreement score (Cohen Kappa): 0.6879671218212182\n", + "CPU times: user 3min 41s, sys: 2.5 s, total: 3min 43s\n", + "Wall time: 1min 49s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "def get_random_forest_model():\n", + " NB_CATEGORICAL_FEATURES = 2\n", + " NB_FEATURES = 48\n", + " \n", + " return Pipeline([\n", + " ('scaling', ColumnTransformer([\n", + " ('pass-through-categorical', 
'passthrough', list(range(NB_CATEGORICAL_FEATURES))),\n", + " ('scaling-continuous', StandardScaler(copy=False), list(range(NB_CATEGORICAL_FEATURES,NB_FEATURES)))\n", + " ])),\n", + " ('classifier', RandomForestClassifier(\n", + " n_estimators=100,\n", + " max_depth=24,\n", + " random_state=42, # enables deterministic behaviour\n", + " n_jobs=-1\n", + " ))\n", + " ])\n", + "\n", + "testing_pipeline = get_random_forest_model()\n", + "testing_pipeline.fit(X_train_valid[:, 2:], y_train_valid)\n", + "y_test_pred = testing_pipeline.predict(X_test[:,2:])\n", + "\n", + "print(confusion_matrix(y_test, y_test_pred))\n", + "print(classification_report(y_test, y_test_pred, target_names=SLEEP_STAGES_VALUES.keys()))\n", + "print(\"Agreement score (Cohen Kappa): \", cohen_kappa_score(y_test, y_test_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving with ONNX\n", + "____" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "onnx_pipeline = convert_sklearn(\n", + " testing_pipeline,\n", + " initial_types=[(\n", + " 'float_input',\n", + " FloatTensorType([None, X_train_valid[:,2:].shape[1]])\n", + " )]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_onnx_model(onnx_pipeline, 'trained_model/rf_pipeline.onnx')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing ONNX pipeline vs normal pipeline results\n", + "____" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "sess = InferenceSession('trained_model/rf_pipeline.onnx')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "y_test_pred_onnx = sess.run(None, {'float_input': 
X_test[:,2:].astype(np.float32)})[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.000738643358365136" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum(~(y_test_pred_onnx == y_test_pred))/len(y_test_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ONNX Pipeline drawing" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer\n", + "pydot_graph = GetPydotGraph(onnx_pipeline.graph, name=onnx_pipeline.graph.name, rankdir=\"TP\",\n", + " node_producer=GetOpNodeProducer(\"docstring\"))\n", + "pydot_graph.write_dot(\"graph.dot\")\n", + "\n", + "import os\n", + "os.system('dot -O -V -Tpng graph.dot')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/ai/models/RF_HMM_smaller.ipynb b/ai/models/RF_HMM_smaller.ipynb new file mode 100644 index 00000000..7e6f6838 --- /dev/null +++ b/ai/models/RF_HMM_smaller.ipynb @@ -0,0 +1,736 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sleep stage classification: 
Random Forest & Hidden Markov Model\n", + "____\n", + "\n", + "This model aims to classify sleep stages based on two EEG channels. We will use the features extracted in the `pipeline.ipynb` notebook as the input to a Random Forest. The output of this model will then be used as the input of an HMM. We will implement our HMM the same way as in this paper (Malafeev et al., « Automatic Human Sleep Stage Scoring Using Deep Neural Networks »)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "import sys\n", + "\n", + "# Ensure parent folder is in PYTHONPATH\n", + "module_path = os.path.abspath(os.path.join('..'))\n", + "if module_path not in sys.path:\n", + " sys.path.append(module_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import sys\n", + "from itertools import groupby\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import joblib\n", + "\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import (GridSearchCV,\n", + " RandomizedSearchCV,\n", + " GroupKFold,\n", + " cross_validate)\n", + "from sklearn.metrics import (accuracy_score,\n", + " confusion_matrix,\n", + " classification_report,\n", + " f1_score,\n", + " cohen_kappa_score,\n", + " make_scorer)\n", + "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", + "from sklearn.decomposition import PCA\n", + "\n", + "from scipy.signal import medfilt\n", + "\n", + "from hmmlearn.hmm import MultinomialHMM\n", + "from constants import (SLEEP_STAGES_VALUES,\n", + " N_STAGES,\n", + " EPOCH_DURATION)\n", + "from model_utils import 
(print_hypnogram,\n", + " train_test_split_one_subject,\n", + " train_test_split_according_to_age)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the features\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# position of the subject information and night information in the X matrix\n", + "SUBJECT_IDX = 0 \n", + "NIGHT_IDX = 1\n", + "USE_CONTINUOUS_AGE = False\n", + "DOWNSIZE_SET = False\n", + "TEST_SET_SUBJECTS = [0.0, 24.0, 49.0, 71.0]\n", + "\n", + "if USE_CONTINUOUS_AGE:\n", + " X_file_name = \"../data/x_features-age-continuous.npy\"\n", + " y_file_name = \"../data/y_observations-age-continuous.npy\"\n", + "else:\n", + " X_file_name = \"../data/x_features.npy\"\n", + " y_file_name = \"../data/y_observations.npy\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "X_init = np.load(X_file_name, allow_pickle=True)\n", + "y_init = np.load(y_file_name, allow_pickle=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(168954, 50)\n", + "(168954,)\n" + ] + } + ], + "source": [ + "X_init = np.vstack(X_init)\n", + "y_init = np.hstack(y_init)\n", + "print(X_init.shape)\n", + "print(y_init.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of subjects: 78\n", + "Number of nights: 153\n" + ] + } + ], + "source": [ + "print(\"Number of subjects: \", np.unique(X_init[:,SUBJECT_IDX]).shape[0]) # Some subject indexes are skipped, thus total number is below 83 (as we can see in https://physionet.org/content/sleep-edfx/1.0.0/)\n", + "print(\"Number of nights: \", len(np.unique([f\"{int(x[0])}-{int(x[1])}\" for x in X_init[:,SUBJECT_IDX:NIGHT_IDX+1]])))\n" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "## Downsizing sets\n", + "___\n", + "\n", + "We will use the same set for all experiments. It includes the first 20 subjects, and excludes the 13th, because it only has one night.\n", + "\n", + "The last subject will be put in the test set. " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "if DOWNSIZE_SET:\n", + " # Filtering to only keep first 20 subjects\n", + " X_20 = X_init[np.isin(X_init[:,SUBJECT_IDX], range(20))]\n", + " y_20 = y_init[np.isin(X_init[:,SUBJECT_IDX], range(20))]\n", + "\n", + " # Exclude the subject with only one night recording (13th)\n", + " MISSING_NIGHT_SUBJECT = 13\n", + "\n", + " X = X_20[X_20[:,SUBJECT_IDX] != MISSING_NIGHT_SUBJECT]\n", + " y = y_20[X_20[:,SUBJECT_IDX] != MISSING_NIGHT_SUBJECT]\n", + "\n", + " print(X.shape)\n", + " print(y.shape)\n", + "else:\n", + " X = X_init\n", + " y = y_init" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of subjects: 78\n", + "Subjects available: [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.\n", + " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.\n", + " 36. 37. 38. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.\n", + " 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 70. 71. 72. 73. 74.\n", + " 75. 76. 77. 80. 81. 
82.]\n", + "Number of nights: 153\n" + ] + } + ], + "source": [ + "print(\"Number of subjects: \", np.unique(X[:,SUBJECT_IDX]).shape[0]) # Some subject indexes are skipped, thus total number is below 83 (as we can see in https://physionet.org/content/sleep-edfx/1.0.0/)\n", + "print(\"Subjects available: \", np.unique(X[:,SUBJECT_IDX]))\n", + "print(\"Number of nights: \", len(np.unique([f\"{int(x[0])}-{int(x[1])}\" for x in X[:,SUBJECT_IDX:NIGHT_IDX+1]])))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train, validation and test sets\n", + "___\n", + "\n", + "If we downsize the dataset, the test set will only contain the two night recordings of the last subject (no. 19). The rest will be the train and validation sets.\n", + "\n", + "If we do not downsize the dataset, we will randomly pick a subject from each age group to be in the test set. Both nights (if there are two) are placed in the test set so that the classifier does not train on any recordings from a subject placed in the test set.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected subjects for the test set are: [0.0, 24.0, 49.0, 71.0]\n", + "(8123, 50) (160831, 50) (8123,) (160831,)\n" + ] + } + ], + "source": [ + "if DOWNSIZE_SET:\n", + " X_test, X_train_valid, y_test, y_train_valid = train_test_split_one_subject(X, y)\n", + "else:\n", + " X_test, X_train_valid, y_test, y_train_valid = train_test_split_according_to_age(X,\n", + " y,\n", + " subjects_test=TEST_SET_SUBJECTS,\n", + " use_continuous_age=USE_CONTINUOUS_AGE)\n", + " \n", + "print(X_test.shape, X_train_valid.shape, y_test.shape, y_train_valid.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Random forest validation\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + 
"NB_KFOLDS = 5\n", + "NB_CATEGORICAL_FEATURES = 2\n", + "NB_FEATURES = 48\n", + "\n", + "CLASSIFIER_PIPELINE_KEY = 'classifier'\n", + "\n", + "def get_random_forest_model():\n", + " return Pipeline([\n", + " ('scaling', ColumnTransformer([\n", + " ('pass-through-categorical', 'passthrough', list(range(NB_CATEGORICAL_FEATURES))),\n", + " ('scaling-continuous', StandardScaler(copy=False), list(range(NB_CATEGORICAL_FEATURES,NB_FEATURES)))\n", + " ])),\n", + " (CLASSIFIER_PIPELINE_KEY, RandomForestClassifier(\n", + " n_estimators=100,\n", + " random_state=42, # enables deterministic behaviour\n", + " n_jobs=-1\n", + " ))\n", + " ])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the cross validation, we will use the `GroupKFold` technique. For each fold, we make sure to train and validate on different subjects, to avoid overfitting over subjects." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 9 µs, sys: 1 µs, total: 10 µs\n", + "Wall time: 14.1 µs\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "def cross_validate_pipeline(pipeline):\n", + " accuracies = []\n", + " macro_f1_scores = []\n", + " weighted_f1_scores = []\n", + " kappa_agreements = []\n", + " emission_matrix = np.zeros((N_STAGES,N_STAGES))\n", + "\n", + " for train_index, valid_index in GroupKFold(n_splits=5).split(X_train_valid, groups=X_train_valid[:,SUBJECT_IDX]):\n", + " # We drop the subject and night indexes\n", + " X_train, X_valid = X_train_valid[train_index, 2:], X_train_valid[valid_index, 2:]\n", + " y_train, y_valid = y_train_valid[train_index], y_train_valid[valid_index]\n", + "\n", + " pipeline.fit(X_train, y_train)\n", + " y_valid_pred = pipeline.predict(X_valid)\n", + "\n", + " print(\"----------------------------- FOLD RESULTS --------------------------------------\\n\")\n", + " current_kappa = cohen_kappa_score(y_valid, 
y_valid_pred)\n", + "\n", + " print(\"TRAIN:\", train_index, \"VALID:\", valid_index, \"\\n\\n\")\n", + " print(confusion_matrix(y_valid, y_valid_pred), \"\\n\")\n", + " print(classification_report(y_valid, y_valid_pred, target_names=SLEEP_STAGES_VALUES.keys()), \"\\n\")\n", + " print(\"Agreement score (Cohen Kappa): \", current_kappa, \"\\n\")\n", + "\n", + " accuracies.append(round(accuracy_score(y_valid, y_valid_pred),2))\n", + " macro_f1_scores.append(f1_score(y_valid, y_valid_pred, average=\"macro\"))\n", + " weighted_f1_scores.append(f1_score(y_valid, y_valid_pred, average=\"weighted\"))\n", + " kappa_agreements.append(current_kappa)\n", + "\n", + " for y_pred, y_true in zip(y_valid_pred, y_valid):\n", + " emission_matrix[y_true, y_pred] += 1\n", + "\n", + " emission_matrix = emission_matrix / emission_matrix.sum(axis=1, keepdims=True)\n", + " \n", + " print(f\"Mean accuracy : {np.mean(accuracies):0.2f} ± {np.std(accuracies):0.3f}\")\n", + " print(f\"Mean macro F1-score : {np.mean(macro_f1_scores):0.2f} ± {np.std(macro_f1_scores):0.3f}\")\n", + " print(f\"Mean weighted F1-score : {np.mean(weighted_f1_scores):0.2f} ± {np.std(weighted_f1_scores):0.3f}\")\n", + " print(f\"Mean Kappa's agreement : {np.mean(kappa_agreements):0.2f} ± {np.std(kappa_agreements):0.3f}\")\n", + "\n", + " return emission_matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 2137 2138 2139 ... 158843 158844 158845] VALID: [ 0 1 2 ... 
160828 160829 160830] \n", + "\n", + "\n", + "[[ 7206 194 111 2 139]\n", + " [ 1235 534 1404 1 543]\n", + " [ 993 439 10654 360 492]\n", + " [ 155 7 632 2132 5]\n", + " [ 842 907 1233 5 2579]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.69 0.94 0.80 7652\n", + " N1 0.26 0.14 0.18 3717\n", + " N2 0.76 0.82 0.79 12938\n", + " N3 0.85 0.73 0.79 2931\n", + " REM 0.69 0.46 0.55 5566\n", + "\n", + " accuracy 0.70 32804\n", + " macro avg 0.65 0.62 0.62 32804\n", + "weighted avg 0.68 0.70 0.68 32804\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.5914311657565539 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 5807 5808 5809 ... 158843 158844 158845] \n", + "\n", + "\n", + "[[ 6893 550 108 11 267]\n", + " [ 888 867 1136 3 1036]\n", + " [ 156 327 11494 392 992]\n", + " [ 23 0 452 1814 0]\n", + " [ 222 632 796 4 3185]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.84 0.88 0.86 7829\n", + " N1 0.36 0.22 0.27 3930\n", + " N2 0.82 0.86 0.84 13361\n", + " N3 0.82 0.79 0.80 2289\n", + " REM 0.58 0.66 0.62 4839\n", + "\n", + " accuracy 0.75 32248\n", + " macro avg 0.69 0.68 0.68 32248\n", + "weighted avg 0.73 0.75 0.74 32248\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.6553464447399939 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 2137 2138 2139 ... 
151913 151914 151915] \n", + "\n", + "\n", + "[[7954 616 219 19 606]\n", + " [ 855 984 1223 10 1704]\n", + " [ 567 701 9904 181 1698]\n", + " [ 41 0 216 767 0]\n", + " [ 384 511 661 12 2462]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.81 0.84 0.83 9414\n", + " N1 0.35 0.21 0.26 4776\n", + " N2 0.81 0.76 0.78 13051\n", + " N3 0.78 0.75 0.76 1024\n", + " REM 0.38 0.61 0.47 4030\n", + "\n", + " accuracy 0.68 32295\n", + " macro avg 0.63 0.63 0.62 32295\n", + "weighted avg 0.69 0.68 0.68 32295\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.5601422740587234 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 4057 4058 4059 ... 121623 121624 121625] \n", + "\n", + "\n", + "[[ 6661 321 99 4 189]\n", + " [ 791 549 1154 14 873]\n", + " [ 216 192 11469 359 656]\n", + " [ 41 0 687 2567 2]\n", + " [ 386 498 1039 6 3351]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.82 0.92 0.87 7274\n", + " N1 0.35 0.16 0.22 3381\n", + " N2 0.79 0.89 0.84 12892\n", + " N3 0.87 0.78 0.82 3297\n", + " REM 0.66 0.63 0.65 5280\n", + "\n", + " accuracy 0.77 32124\n", + " macro avg 0.70 0.68 0.68 32124\n", + "weighted avg 0.74 0.77 0.75 32124\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.6754525828472742 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 13884 13885 13886 ... 
156772 156773 156774] \n", + "\n", + "\n", + "[[ 6545 612 379 27 325]\n", + " [ 674 750 1561 11 939]\n", + " [ 238 313 10710 242 626]\n", + " [ 39 0 729 1887 1]\n", + " [ 355 829 1147 7 2414]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.83 0.83 0.83 7888\n", + " N1 0.30 0.19 0.23 3935\n", + " N2 0.74 0.88 0.80 12129\n", + " N3 0.87 0.71 0.78 2656\n", + " REM 0.56 0.51 0.53 4752\n", + "\n", + " accuracy 0.71 31360\n", + " macro avg 0.66 0.62 0.64 31360\n", + "weighted avg 0.69 0.71 0.70 31360\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.5996710408224084 \n", + "\n", + "Mean accuracy : 0.72 ± 0.033\n", + "Mean macro F1-score : 0.65 ± 0.027\n", + "Mean weighted F1-score : 0.71 ± 0.029\n", + "Mean Kappa's agreement : 0.62 ± 0.043\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[8.80220686e-01, 5.72434281e-02, 2.28674139e-02, 1.57275882e-03,\n", + " 3.80957136e-02],\n", + " [2.25087390e-01, 1.86635595e-01, 3.28182785e-01, 1.97578398e-03,\n", + " 2.58118446e-01],\n", + " [3.37108325e-02, 3.06349132e-02, 8.42475649e-01, 2.38306070e-02,\n", + " 6.93479983e-02],\n", + " [2.45142248e-02, 5.73911618e-04, 2.22677708e-01, 7.51578257e-01,\n", + " 6.55898992e-04],\n", + " [8.94674459e-02, 1.38022643e-01, 1.99288838e-01, 1.38962684e-03,\n", + " 5.71831446e-01]])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "validation_pipeline = get_random_forest_model()\n", + "validation_pipeline.set_params(\n", + " classifier__max_depth=24,\n", + " classifier__n_estimators=100,\n", + ")\n", + "\n", + "cross_validate_pipeline(validation_pipeline)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Random forest training and testing\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 41s, sys: 2.12 s, total: 3min 43s\n", 
+ "Wall time: 1min 13s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "testing_pipeline = get_random_forest_model()\n", + "testing_pipeline.set_params(\n", + " classifier__max_depth=24,\n", + " classifier__n_estimators=100,\n", + ")\n", + "\n", + "testing_pipeline.fit(X_train_valid[:, 2:], y_train_valid);" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Categorical features: [0 1]\n", + "Time domain features: [ 2 3 4 5 6 7 8 25 26 27 28 29 30 31]\n", + "Frequency domain features: [ 9 10 11 12 13 14 15 16 17 18 19 32 33 34 35 36 37 38 39 40 41 42]\n", + "Subband time domain features: [20 21 22 23 24 43 44 45 46 47]\n", + "\n", + "Top 5 features: [(41, 0.0627), (29, 0.0487), (18, 0.0421), (20, 0.0411), (47, 0.0403)]\n", + "Bottom 5 features: [(11, 0.0108), (27, 0.0093), (4, 0.0091), (42, 0.0066), (0, 0.0031)]\n", + "\n", + "Fpz-Cz feature importances: 0.4553\n", + "Pz-Oz feature importances: 0.5284\n", + "\n", + "Category feature importances: 0.0162\n", + "Time domain feature importances: 0.2843\n", + "Frequency domain feature importances: 0.4711\n", + "Subband time domain feature importances: 0.2283\n" + ] + } + ], + "source": [ + "feature_importance_indexes = [\n", + " (idx, round(importance,4))\n", + " for idx, importance in enumerate(testing_pipeline.steps[1][1].feature_importances_)\n", + "]\n", + "feature_importance_indexes.sort(reverse=True, key=lambda x: x[1])\n", + "\n", + "category_feature_range = np.array([2, 3]) - 2\n", + "time_domaine_feature_range = np.array([4, 5, 6, 7, 8, 9, 10, 27, 28, 29, 30, 31, 32, 33]) - 2\n", + "freq_domain_feature_range = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]) - 2\n", + "subband_domain_feature_range = np.array([22, 23, 24, 25, 26, 45, 46, 47, 48, 49]) - 2\n", + "fpz_cz_feature_range = np.array(range(2, 25))\n", + "pz_oz_feature_range = 
np.array(range(25, 48))\n", + "\n", + "def get_feature_range_importance(indexes):\n", + " return np.sum([feature[1] for feature in feature_importance_indexes if feature[0] in indexes])\n", + "\n", + "print(f\"Categorical features: {category_feature_range}\")\n", + "print(f\"Time domain features: {time_domaine_feature_range}\")\n", + "print(f\"Frequency domain features: {freq_domain_feature_range}\")\n", + "print(f\"Subband time domain features: {subband_domain_feature_range}\\n\")\n", + "\n", + "print(f\"Top 5 features: {[feature for feature in feature_importance_indexes[:5]]}\")\n", + "print(f\"Bottom 5 features: {[feature for feature in feature_importance_indexes[-5:]]}\\n\")\n", + "\n", + "print(f\"Fpz-Cz feature importances: {get_feature_range_importance(fpz_cz_feature_range):.4f}\")\n", + "print(f\"Pz-Oz feature importances: {get_feature_range_importance(pz_oz_feature_range):.4f}\\n\")\n", + "\n", + "print(f\"Category feature importances: {get_feature_range_importance([0,1]):.4f}\")\n", + "print(f\"Time domain feature importances: {get_feature_range_importance(time_domaine_feature_range):.4f}\")\n", + "print(f\"Frequency domain feature importances: {get_feature_range_importance(freq_domain_feature_range):.4f}\")\n", + "print(f\"Subband time domain feature importances: {get_feature_range_importance(subband_domain_feature_range):.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1512 65 3 3 41]\n", + " [ 220 147 332 0 284]\n", + " [ 39 45 3212 194 113]\n", + " [ 4 0 32 575 0]\n", + " [ 49 81 284 0 888]]\n", + " precision recall f1-score support\n", + "\n", + " W 0.83 0.93 0.88 1624\n", + " N1 0.43 0.15 0.22 983\n", + " N2 0.83 0.89 0.86 3603\n", + " N3 0.74 0.94 0.83 611\n", + " REM 0.67 0.68 0.68 1302\n", + "\n", + " accuracy 0.78 8123\n", + " macro avg 0.70 0.72 0.69 8123\n", + "weighted avg 0.75 0.78 0.75 8123\n", + "\n", + "Agreement score 
(Cohen Kappa): 0.6879671218212182\n" + ] + } + ], + "source": [ + "y_test_pred = testing_pipeline.predict(X_test[:,2:])\n", + "\n", + "print(confusion_matrix(y_test, y_test_pred))\n", + "\n", + "print(classification_report(y_test, y_test_pred, target_names=SLEEP_STAGES_VALUES.keys()))\n", + "\n", + "print(\"Agreement score (Cohen Kappa): \", cohen_kappa_score(y_test, y_test_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving trained model\n", + "___\n", + "\n", + "We save the trained model with the postprocessing step, HMM. We will save only the matrix that defines it. We do not need to persist the median filter postprocessing step, because it is stateless." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "SAVED_DIR = \"trained_model\"\n", + "\n", + "if not os.path.exists(SAVED_DIR):\n", + " os.mkdir(SAVED_DIR); " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pipeline object size (Mbytes): 322.775421\n" + ] + } + ], + "source": [ + "if USE_CONTINUOUS_AGE: \n", + " joblib.dump(testing_pipeline, f\"{SAVED_DIR}/classifier_RF_continous_age.joblib\")\n", + "else:\n", + " fd = joblib.dump(testing_pipeline, f\"{SAVED_DIR}/classifier_RF_small.joblib\")\n", + " print(\n", + " \"Pipeline object size (Mbytes): \",\n", + " os.path.getsize(f\"{SAVED_DIR}/classifier_RF_small.joblib\")/1e6\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + 
"version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/backend/requirements.txt b/backend/requirements.txt index 246c7204..ee23dc74 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,4 +5,6 @@ waitress==1.4.4 mne==0.21.0 numpy==1.19.2 scipy==1.5.2 -scikit-learn==0.23.2 +scikit-learn==0.23.2 # TODO: move this to requirements-dev after removing pipeline code from predict.py +skl2onnx==1.7.0 +onnxruntime==1.5.2 From 3c80f91389439dab3a67c6c4dcbdaa70a9e806f0 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Mon, 26 Oct 2020 00:06:25 -0400 Subject: [PATCH 02/24] updated readme & fixed req --- backend/readme.md | 7 +++++++ backend/requirements.txt | 1 - 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/readme.md b/backend/readme.md index 256b487b..b7d97173 100644 --- a/backend/readme.md +++ b/backend/readme.md @@ -20,6 +20,13 @@ Install the required dependencies. pip install -r requirements.txt requirements-dev.txt ``` +If you are running on Linux or MacOS, you also have to install OpenMP with your package manager. It is a dependency of ONNX runtime, used to load our model and make predictions. + +```bash +apt-get install libgomp1 # on linux +brew install libomp # on macos +``` + ## Run it locally Activate your virtual environment. 
diff --git a/backend/requirements.txt b/backend/requirements.txt index d1b813a0..0f725c96 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -8,4 +8,3 @@ numpy==1.19.2 scipy==1.5.2 scikit-learn==0.23.2 # TODO: move this to requirements-dev after removing pipeline code from predict.py skl2onnx==1.7.0 -onnxruntime==1.5.2 From d58f4737e3233d02e2cb1694274ba369ae64d669 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Mon, 26 Oct 2020 00:39:52 -0400 Subject: [PATCH 03/24] print predictions --- backend/app.py | 4 +++- backend/classification/features/__init__.py | 2 +- backend/classification/load_model.py | 3 ++- backend/classification/predict.py | 11 ++++++++--- backend/requirements.txt | 3 ++- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/backend/app.py b/backend/app.py index 64c37bb3..d314c509 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,8 +7,10 @@ from classification.predict import predict from classification.exceptions import ClassificationError from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS +from classification.load_model import load_model app = Flask(__name__) +model = load_model() def allowed_file(filename): @@ -57,7 +59,7 @@ def analyze_sleep(): try: raw_array = get_raw_array(file) - predict(raw_array, info={ + predict(raw_array, model, info={ 'sex': sex, 'age': age, 'in_bed_seconds': bedtime - stream_start, diff --git a/backend/classification/features/__init__.py b/backend/classification/features/__init__.py index dc66d99b..c21ae168 100644 --- a/backend/classification/features/__init__.py +++ b/backend/classification/features/__init__.py @@ -26,4 +26,4 @@ def get_features(signal, info): X_eeg = get_eeg_features(signal, info['in_bed_seconds'], info['out_of_bed_seconds']) X_categorical = get_non_eeg_features(info['age'], info['sex'], X_eeg.shape[0]) - return np.append(X_categorical, X_eeg, axis=1) + return np.append(X_categorical, X_eeg, axis=1).astype(np.float32) diff --git 
a/backend/classification/load_model.py b/backend/classification/load_model.py index 0b20ba1c..42dbe8ba 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -8,7 +8,7 @@ MODEL_FILENAME = 'model.onnx' MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}' MODEL_REPO = 'polycortex/polydodo-model' -MODEL_URL = f'https://raw.githubusercontent.com/{MODEL_REPO}/master/{MODEL_FILENAME}' +MODEL_URL = f'https://github.com/{MODEL_REPO}/blob/master/{MODEL_FILENAME}?raw=true' def _download_file(url, output): @@ -45,5 +45,6 @@ def _get_file_githash(filepath): def load_model(): if not path.exists(MODEL_PATH) or _get_file_githash(MODEL_PATH) != _get_latest_model_githash(): + print("downloading model") _download_file(MODEL_URL, MODEL_PATH) return onnxruntime.InferenceSession(MODEL_PATH) diff --git a/backend/classification/predict.py b/backend/classification/predict.py index 07c6dbb7..8a18684f 100644 --- a/backend/classification/predict.py +++ b/backend/classification/predict.py @@ -3,11 +3,12 @@ from classification.validation import validate -def predict(raw_eeg, info): +def predict(raw_eeg, model, info): """ Input: - raw_eeg: instance of mne.io.RawArray Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) + - model: instance of InferenceSession - info: dict Should contain the following keys: - sex: instance of Sex enum @@ -18,6 +19,10 @@ def predict(raw_eeg, info): the subject started the recording and got out of bed """ validate(raw_eeg, info) - X_openbci = get_features(raw_eeg, info) + features = get_features(raw_eeg, info) - print(X_openbci[0], X_openbci.shape) + input_name = model.get_inputs()[0].name + predictions = model.run(None, {input_name: features})[0] + + print(features[0], features.shape) + print(predictions) diff --git a/backend/requirements.txt b/backend/requirements.txt index 0f725c96..49c340a6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,5 +6,6 @@ mne==0.21.0 onnxruntime==1.5.2 numpy==1.19.2 scipy==1.5.2 
-scikit-learn==0.23.2 # TODO: move this to requirements-dev after removing pipeline code from predict.py +scikit-learn==0.23.2 skl2onnx==1.7.0 +requests==2.7.0 From 5f34701094147b7068aba4678110d7263e0349f2 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Thu, 29 Oct 2020 11:10:33 -0400 Subject: [PATCH 04/24] revert model url --- backend/classification/load_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 42dbe8ba..d24a16e4 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -8,7 +8,7 @@ MODEL_FILENAME = 'model.onnx' MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}' MODEL_REPO = 'polycortex/polydodo-model' -MODEL_URL = f'https://github.com/{MODEL_REPO}/blob/master/{MODEL_FILENAME}?raw=true' +MODEL_URL = f'https://raw.githubusercontent.com/{MODEL_REPO}/master/{MODEL_FILENAME}' def _download_file(url, output): From a540c35863e292941ddcc66975f4b6b42069715c Mon Sep 17 00:00:00 2001 From: Anes Belfodil Date: Thu, 29 Oct 2020 11:13:53 -0400 Subject: [PATCH 05/24] Fix download --- backend/classification/load_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index d24a16e4..a6af19fc 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -8,7 +8,7 @@ MODEL_FILENAME = 'model.onnx' MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}' MODEL_REPO = 'polycortex/polydodo-model' -MODEL_URL = f'https://raw.githubusercontent.com/{MODEL_REPO}/master/{MODEL_FILENAME}' +MODEL_URL = f'https://github.com/{MODEL_REPO}/raw/master/{MODEL_FILENAME}' def _download_file(url, output): From cccdd2761759ef38be5959876ebea517807f38c6 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Thu, 29 Oct 2020 11:49:42 -0400 Subject: [PATCH 06/24] added log to debug on macos --- backend/classification/load_model.py | 2 ++ 1 
file changed, 2 insertions(+) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index a6af19fc..fd00fabb 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -44,6 +44,8 @@ def _get_file_githash(filepath): def load_model(): + if path.exists(MODEL_PATH): + print(f"Model is already downloaded at {MODEL_PATH}, checking if it's the latest.") if not path.exists(MODEL_PATH) or _get_file_githash(MODEL_PATH) != _get_latest_model_githash(): print("downloading model") _download_file(MODEL_URL, MODEL_PATH) From 1b9289c50f940dfa6732e641622dde8e9a24f416 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Fri, 30 Oct 2020 22:44:35 -0400 Subject: [PATCH 07/24] fixed bug; replaced uV by V scaling --- backend/classification/file_loading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/classification/file_loading.py b/backend/classification/file_loading.py index 9dc44683..c9307180 100644 --- a/backend/classification/file_loading.py +++ b/backend/classification/file_loading.py @@ -27,6 +27,7 @@ ADS1299_Vref = 4.5 ADS1299_gain = 24. 
SCALE_uV_PER_COUNT = ADS1299_Vref / ((2**23) - 1) / ADS1299_gain * 1000000 +SCALE_V_PER_COUNT = SCALE_uV_PER_COUNT / 1e6 FILE_COLUMN_OFFSET = 1 CYTON_TOTAL_NB_CHANNELS = 8 @@ -48,7 +49,7 @@ def get_raw_array(file): if len(line_splitted) >= CYTON_TOTAL_NB_CHANNELS: eeg_raw.append(_get_decimals_from_hexadecimal_strings(line_splitted)) - eeg_raw = SCALE_uV_PER_COUNT * np.array(eeg_raw, dtype='object') + eeg_raw = SCALE_V_PER_COUNT * np.array(eeg_raw, dtype='object') raw_object = RawArray( np.transpose(eeg_raw), From 05c7e3f2026ce381bb2c41545715653dded48500 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sat, 31 Oct 2020 22:06:48 -0400 Subject: [PATCH 08/24] added investigation notebook --- ai/.gitignore | 4 +- ai/investigation.ipynb | 767 +++++++++++++++++++++++++++++++++++++++++ backend/.gitignore | 1 + 3 files changed, 771 insertions(+), 1 deletion(-) create mode 100644 ai/investigation.ipynb diff --git a/ai/.gitignore b/ai/.gitignore index 310601c9..f72027b1 100644 --- a/ai/.gitignore +++ b/ai/.gitignore @@ -1,3 +1,5 @@ +investigation_data/* + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -109,4 +111,4 @@ venv.bak/ data/* .vscode/ -*.joblib \ No newline at end of file +*.joblib diff --git a/ai/investigation.ipynb b/ai/investigation.ipynb new file mode 100644 index 00000000..07e10bad --- /dev/null +++ b/ai/investigation.ipynb @@ -0,0 +1,767 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Investigate differences in the feature matrices\n", + "____\n", + "\n", + "Current bug: In the backend, we currently obtain invalid predictions (all Wake). 
When investigating, I saw that the feature matrix sent to the classifier was really different compared to the one generated in our notebooks.\n", + "\n", + "## Matrix feature differences\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import mne\n", + "from constants import EPOCH_DURATION" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "start = 1582418280\n", + "bed = 1582423980\n", + "wake = 1582452240\n", + "\n", + "nb_epoch_bed = int((bed - start) // EPOCH_DURATION)\n", + "nb_epoch_wake = int((wake - start) // EPOCH_DURATION)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((942, 48), (942, 48))" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_notebook = np.vstack(np.load(\"data/X_openbci_HP_PRIOR.npy\", allow_pickle=True))[nb_epoch_bed:nb_epoch_wake]\n", + "X_backend = np.load(\"investigation_data/feature.npy\", allow_pickle=True)\n", + "\n", + "X_notebook.shape, X_backend.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Feature 0 diff: 0.00\n", + "Feature 1 diff: 0.00\n", + "Feature 2 diff: 0.14\n", + "Feature 3 diff: 21.24\n", + "Feature 4 diff: 0.01\n", + "Feature 5 diff: 0.04\n", + "Feature 6 diff: 5.94\n", + "Feature 7 diff: 0.00\n", + "Feature 8 diff: 0.80\n", + "Feature 9 diff: 116.65\n", + "Feature 10 diff: 6.50\n", + "Feature 11 diff: 1.92\n", + "Feature 12 diff: 1.94\n", + "Feature 13 diff: 0.23\n", + "Feature 14 diff: 0.00\n", + "Feature 15 diff: 0.00\n", + "Feature 16 diff: 0.00\n", + "Feature 17 diff: 0.00\n", + "Feature 18 diff: 0.00\n", + "Feature 19 diff: 0.01\n", + "Feature 20 diff: 
1486454042448.59\n", + "Feature 21 diff: 71806204238.74\n", + "Feature 22 diff: 16018476987.50\n", + "Feature 23 diff: 22727447431.86\n", + "Feature 24 diff: 9513626021.77\n", + "Feature 25 diff: 0.08\n", + "Feature 26 diff: 12.56\n", + "Feature 27 diff: 0.01\n", + "Feature 28 diff: 0.04\n", + "Feature 29 diff: 6.80\n", + "Feature 30 diff: 0.00\n", + "Feature 31 diff: 3.70\n", + "Feature 32 diff: 51.18\n", + "Feature 33 diff: 2.82\n", + "Feature 34 diff: 1.20\n", + "Feature 35 diff: 0.55\n", + "Feature 36 diff: 0.10\n", + "Feature 37 diff: 0.00\n", + "Feature 38 diff: 0.00\n", + "Feature 39 diff: 0.00\n", + "Feature 40 diff: 0.00\n", + "Feature 41 diff: 0.00\n", + "Feature 42 diff: 0.02\n", + "Feature 43 diff: 670759560535.92\n", + "Feature 44 diff: 31305042058.57\n", + "Feature 45 diff: 10179418544.09\n", + "Feature 46 diff: 6343296375.84\n", + "Feature 47 diff: 4072542937.41\n" + ] + } + ], + "source": [ + "for feature_idx in range(X_notebook.shape[1]):\n", + " difference = np.mean(abs(X_notebook[:,feature_idx] - X_backend[:,feature_idx]))\n", + " print(f\"Feature {feature_idx} diff: {difference:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see there's a big difference in the subband time domain features. These features are computed by applying a subband filter (e.g. delta) and taking the mean energy of the filtered signal in the time domain. The energy consists of the sum of the squared samples.\n", + "\n", + "Consequently, if there are differences in the signal, they are highly amplified in the subband domain features.\n", + "\n", + "### Known errors\n", + "___\n", + "\n", + "We know that the signal we trained on had a different quantization than the original quantization of the OpenBCI Cyton. The EDF specification encodes the samples on 16 bits, whereas we currently keep the original encoding of 24 bits.\n", + "\n", + "## 1. 
Conversion from hexadecimal to decimal\n", + "\n", + "We will compare both files converted with the OpenBCI GUI and our own code. The file used will be the mini file." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Fpz-CzPz-Oz
0-11104.548205.09960
1-11065.186191.77795
2-11017.264180.24446
3-11041.448194.37076
4-11036.777194.03549
.........
899993-17148.0576566.42800
899994-17136.9266567.14360
899995-17134.3776574.34100
899996-17127.5826579.28030
899997-17121.3906581.44870
\n", + "

899998 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " Fpz-Cz Pz-Oz\n", + "0 -11104.548 205.09960\n", + "1 -11065.186 191.77795\n", + "2 -11017.264 180.24446\n", + "3 -11041.448 194.37076\n", + "4 -11036.777 194.03549\n", + "... ... ...\n", + "899993 -17148.057 6566.42800\n", + "899994 -17136.926 6567.14360\n", + "899995 -17134.377 6574.34100\n", + "899996 -17127.582 6579.28030\n", + "899997 -17121.390 6581.44870\n", + "\n", + "[899998 rows x 2 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "openbci_gui_data = pd.read_csv(\"investigation_data/SDconverted-2020-10-29_00-05-19_mini.csv\", skiprows=7, usecols=[1,2], names=[\"Fpz-Cz\", \"Pz-Oz\"])\n", + "openbci_gui_data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Fpz-CzPz-Oz
0-11104.547811205.099607
1-11065.186389191.777967
2-11017.264249180.244467
3-11041.448836194.370770
4-11036.777322194.035494
.........
899993-17148.0571806566.428431
899994-17136.9260126567.143687
899995-17134.3779136574.340948
899996-17127.5829826579.280684
899997-17121.3915496581.448803
\n", + "

899998 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " Fpz-Cz Pz-Oz\n", + "0 -11104.547811 205.099607\n", + "1 -11065.186389 191.777967\n", + "2 -11017.264249 180.244467\n", + "3 -11041.448836 194.370770\n", + "4 -11036.777322 194.035494\n", + "... ... ...\n", + "899993 -17148.057180 6566.428431\n", + "899994 -17136.926012 6567.143687\n", + "899995 -17134.377913 6574.340948\n", + "899996 -17127.582982 6579.280684\n", + "899997 -17121.391549 6581.448803\n", + "\n", + "[899998 rows x 2 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "script_data = pd.DataFrame(data=np.transpose(np.load('investigation_data/raw_converted.npy')), columns=[\"Fpz-Cz\", \"Pz-Oz\"])\n", + "script_data" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean difference between decimal raw data \n", + "Fpz-Cz 0.000550\n", + "Pz-Oz 0.000263\n", + "dtype: float64\n", + "\n", + "Median difference between decimal raw data \n", + "Fpz-Cz 0.000503\n", + "Pz-Oz 0.000212\n", + "dtype: float64\n", + "\n", + "Min difference between decimal raw data \n", + "Fpz-Cz 4.529284e-09\n", + "Pz-Oz 1.073204e-10\n", + "dtype: float64\n", + "\n", + "Max difference between decimal raw data \n", + "Fpz-Cz 0.002508\n", + "Pz-Oz 0.001277\n", + "dtype: float64\n" + ] + } + ], + "source": [ + "difference_df = abs(openbci_gui_data - script_data)\n", + "print(f'Mean difference between decimal raw data \\n{difference_df.mean()}\\n')\n", + "print(f'Median difference between decimal raw data \\n{difference_df.median()}\\n')\n", + "print(f'Min difference between decimal raw data \\n{difference_df.min()}\\n')\n", + "print(f'Max difference between decimal raw data \\n{difference_df.max()}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we see, the OpenBCI GUI stores only the five first decimals, whereas we do not round the number in our 
backend. We can check whether it's the cause of the problem." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean difference between decimal raw data \n", + "Fpz-Cz 0.000550\n", + "Pz-Oz 0.000263\n", + "dtype: float64\n", + "\n", + "Median difference between decimal raw data \n", + "Fpz-Cz 0.00050\n", + "Pz-Oz 0.00021\n", + "dtype: float64\n", + "\n", + "Min difference between decimal raw data \n", + "Fpz-Cz 0.0\n", + "Pz-Oz 0.0\n", + "dtype: float64\n", + "\n", + "Max difference between decimal raw data \n", + "Fpz-Cz 0.00251\n", + "Pz-Oz 0.00128\n", + "dtype: float64\n" + ] + } + ], + "source": [ + "script_data_chopped = script_data.apply(lambda x: np.round(x, decimals=5))\n", + "script_data_chopped\n", + "\n", + "difference_df = abs(openbci_gui_data - script_data_chopped)\n", + "print(f'Mean difference between decimal raw data \\n{difference_df.mean()}\\n')\n", + "print(f'Median difference between decimal raw data \\n{difference_df.median()}\\n')\n", + "print(f'Min difference between decimal raw data \\n{difference_df.min()}\\n')\n", + "print(f'Max difference between decimal raw data \\n{difference_df.max()}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the maximum difference is the same as before. The difference is therefore not caused by the number of decimals.\n", + "\n", + "## Quantization\n", + "____\n", + "\n", + "In the case of the notebooks, we first convert the decimal-converted file to another format, the EDF format. Since this format enforces quantization to 16 bits, the data is then requantized from 24 to 16 bits. This is not the case in the backend. We will compare both results." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[-0.01110396, -0.01106457, -0.01101625, ..., -0.01713194,\n", + " -0.01713614, -0.01712879],\n", + " [ 0.0002052 , 0.00019168, 0.00018013, ..., 0.00657419,\n", + " 0.00657165, 0.00656827]]), (2, 899975))" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "notebook_data = mne.io.read_raw_edf('investigation_data/william-recording-mini.edf', preload=True, stim_channel=None, verbose=False)\n", + "notebook_data.get_data(), notebook_data.get_data().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Fpz-CzPz-Oz
0-11104.548205.09960
1-11065.186191.77795
2-11017.264180.24446
3-11041.448194.37076
4-11036.777194.03549
.........
899993-17148.0576566.42800
899994-17136.9266567.14360
899995-17134.3776574.34100
899996-17127.5826579.28030
899997-17121.3906581.44870
\n", + "

899998 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " Fpz-Cz Pz-Oz\n", + "0 -11104.548 205.09960\n", + "1 -11065.186 191.77795\n", + "2 -11017.264 180.24446\n", + "3 -11041.448 194.37076\n", + "4 -11036.777 194.03549\n", + "... ... ...\n", + "899993 -17148.057 6566.42800\n", + "899994 -17136.926 6567.14360\n", + "899995 -17134.377 6574.34100\n", + "899996 -17127.582 6579.28030\n", + "899997 -17121.390 6581.44870\n", + "\n", + "[899998 rows x 2 columns]" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "openbci_gui_data = pd.read_csv(\"investigation_data/SDconverted-2020-10-29_00-05-19_mini.csv\", skiprows=7, usecols=[1,2], names=[\"Fpz-Cz\", \"Pz-Oz\"])\n", + "openbci_gui_data" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Fpz-CzPz-Oz
0-11104.547811205.099607
1-11065.186389191.777967
2-11017.264249180.244467
3-11041.448836194.370770
4-11036.777322194.035494
.........
899993-17148.0571806566.428431
899994-17136.9260126567.143687
899995-17134.3779136574.340948
899996-17127.5829826579.280684
899997-17121.3915496581.448803
\n", + "

899998 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " Fpz-Cz Pz-Oz\n", + "0 -11104.547811 205.099607\n", + "1 -11065.186389 191.777967\n", + "2 -11017.264249 180.244467\n", + "3 -11041.448836 194.370770\n", + "4 -11036.777322 194.035494\n", + "... ... ...\n", + "899993 -17148.057180 6566.428431\n", + "899994 -17136.926012 6567.143687\n", + "899995 -17134.377913 6574.340948\n", + "899996 -17127.582982 6579.280684\n", + "899997 -17121.391549 6581.448803\n", + "\n", + "[899998 rows x 2 columns]" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "script_data = pd.DataFrame(data=np.transpose(np.load('investigation_data/raw_converted.npy')), columns=[\"Fpz-Cz\", \"Pz-Oz\"])\n", + "script_data" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/backend/.gitignore b/backend/.gitignore index 6dfe2dc3..14ddb2ef 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -2,3 +2,4 @@ .vscode/ __pycache__/ *.onnx +.DS_Store From 26be93026bc8b1a890c23c59ac551f76ca236ebd Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 00:11:38 -0400 Subject: [PATCH 09/24] added check to see it local object modified date is after remote object update date --- backend/classification/load_model.py | 28 +++++++++++++++++++++------- backend/requirements.txt | 1 + 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 86724776..3134bbff 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -1,9 +1,12 @@ +from datetime import datetime 
from os import path +import re import sys +import xml.etree.ElementTree as ET + +from pytz import utc from requests import get import onnxruntime -import re -import xml.etree.ElementTree as ET SCRIPT_PATH = path.dirname(path.realpath(sys.argv[0])) MODEL_FILENAME = 'model.onnx' @@ -18,7 +21,7 @@ def _download_file(url, output): f.write(get(url).content) -def _get_latest_object_size(bucket_url, filename): +def _get_latest_object_information(bucket_url, filename): raw_result = get(bucket_url).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) @@ -26,13 +29,24 @@ def _get_latest_object_size(bucket_url, filename): objects_nodes = result_root_node.findall('Contents') object_node = [object_node for object_node in objects_nodes if object_node.find("Key").text == filename][0] object_size = int(object_node.find("Size").text) - return object_size + object_latest_update = datetime.strptime(object_node.find("LastModified").text, "%Y-%m-%dT%H:%M:%S.%f%z") + + return {'size': object_size, 'latest_update': object_latest_update} + + +def _has_latest_model(): + latest_model_information = _get_latest_object_information(BUCKET_URL, MODEL_FILENAME) + current_model_size = path.getsize(MODEL_PATH) + current_model_update = utc.localize(datetime.fromtimestamp(path.getmtime(MODEL_PATH))) + + return ( + current_model_update >= latest_model_information['latest_update'] + and current_model_size == latest_model_information['size'] + ) def load_model(): - if path.exists(MODEL_PATH): - print(f"Model is already downloaded at {MODEL_PATH}, checking if it's the latest.") - if not path.exists(MODEL_PATH) or path.getsize(MODEL_PATH) != _get_latest_object_size(BUCKET_URL, MODEL_FILENAME): + if not path.exists(MODEL_PATH) or not _has_latest_model(): print("Downloading latest model...") _download_file(MODEL_URL, MODEL_PATH) print("Loading model...") diff --git a/backend/requirements.txt b/backend/requirements.txt index 49c340a6..064baa99 100644 --- 
a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,3 +9,4 @@ scipy==1.5.2 scikit-learn==0.23.2 skl2onnx==1.7.0 requests==2.7.0 +pytz==2020.1 From 212003f7aca55898320e79360680a5c0e96425e5 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 00:38:03 -0400 Subject: [PATCH 10/24] changed hardcoded / to pathlib to handle Windows & POSIX paths --- backend/classification/load_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 3134bbff..609763eb 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -1,5 +1,6 @@ from datetime import datetime from os import path +from pathlib import Path import re import sys import xml.etree.ElementTree as ET @@ -8,11 +9,14 @@ from requests import get import onnxruntime -SCRIPT_PATH = path.dirname(path.realpath(sys.argv[0])) + +SCRIPT_PATH = Path(path.realpath(sys.argv[0])).parent + +BUCKET_NAME = 'polydodo' +BUCKET_URL = f'https://{BUCKET_NAME}.s3.amazonaws.com' + MODEL_FILENAME = 'model.onnx' -MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}' -MODEL_BUCKET = 'polydodo' -BUCKET_URL = f'https://{MODEL_BUCKET}.s3.amazonaws.com' +MODEL_PATH = SCRIPT_PATH / MODEL_FILENAME MODEL_URL = f'{BUCKET_URL}/{MODEL_FILENAME}' @@ -50,4 +54,4 @@ def load_model(): print("Downloading latest model...") _download_file(MODEL_URL, MODEL_PATH) print("Loading model...") - return onnxruntime.InferenceSession(MODEL_PATH) + return onnxruntime.InferenceSession(str(MODEL_PATH)) From 3f419b0d8d81fdb8e82a740910ce9d4697433dd5 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 01:54:38 -0400 Subject: [PATCH 11/24] added hmm files download --- backend/.gitignore | 1 + backend/app.py | 3 +- backend/classification/load_model.py | 43 ++++++++++++++++++++++------ backend/requirements.txt | 1 - 4 files changed, 37 insertions(+), 11 deletions(-) diff --git 
a/backend/.gitignore b/backend/.gitignore index 14ddb2ef..552623c5 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -3,3 +3,4 @@ __pycache__/ *.onnx .DS_Store +hmm_model/ diff --git a/backend/app.py b/backend/app.py index d314c509..7060d767 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,10 +7,11 @@ from classification.predict import predict from classification.exceptions import ClassificationError from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS -from classification.load_model import load_model +from classification.load_model import load_model, load_hmm app = Flask(__name__) model = load_model() +hmm_model = load_hmm() def allowed_file(filename): diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 609763eb..160f263f 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -1,11 +1,11 @@ from datetime import datetime -from os import path +from os import path, makedirs from pathlib import Path import re import sys import xml.etree.ElementTree as ET -from pytz import utc +import numpy as np from requests import get import onnxruntime @@ -19,14 +19,21 @@ MODEL_PATH = SCRIPT_PATH / MODEL_FILENAME MODEL_URL = f'{BUCKET_URL}/{MODEL_FILENAME}' +HMM_FOLDER = 'hmm_model' +HMM_FILENAMES = [ + ('emission', 'hmm_emission_probabilites.npy'), + ('start', 'hmm_start_probabilities.npy'), + ('transition', 'hmm_transition_probabilites.npy') +] + def _download_file(url, output): with open(output, 'wb') as f: f.write(get(url).content) -def _get_latest_object_information(bucket_url, filename): - raw_result = get(bucket_url).text +def _get_latest_object_information(filename): + raw_result = get(BUCKET_URL).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) result_root_node = ET.fromstring(raw_result) @@ -38,10 +45,10 @@ def _get_latest_object_information(bucket_url, filename): return {'size': object_size, 
'latest_update': object_latest_update} -def _has_latest_model(): - latest_model_information = _get_latest_object_information(BUCKET_URL, MODEL_FILENAME) - current_model_size = path.getsize(MODEL_PATH) - current_model_update = utc.localize(datetime.fromtimestamp(path.getmtime(MODEL_PATH))) +def _has_latest_object(filename, local_path): + latest_model_information = _get_latest_object_information(filename) + current_model_size = path.getsize(local_path) + current_model_update = datetime.fromtimestamp(path.getmtime(local_path)).astimezone() return ( current_model_update >= latest_model_information['latest_update'] @@ -50,8 +57,26 @@ def _has_latest_model(): def load_model(): - if not path.exists(MODEL_PATH) or not _has_latest_model(): + if not path.exists(MODEL_PATH) or not _has_latest_object(MODEL_FILENAME, MODEL_PATH): print("Downloading latest model...") _download_file(MODEL_URL, MODEL_PATH) print("Loading model...") return onnxruntime.InferenceSession(str(MODEL_PATH)) + + +def load_hmm(): + hmm_matrices = dict() + + if not path.exists(SCRIPT_PATH / HMM_FOLDER): + makedirs(SCRIPT_PATH / HMM_FOLDER) + + for hmm_object_name, hmm_file in HMM_FILENAMES: + model_path = SCRIPT_PATH / HMM_FOLDER / hmm_file + + if not path.exists(model_path) or not _has_latest_object(hmm_file, model_path): + print(f"Downloading latest {hmm_object_name} HMM matrix...") + _download_file(url=f"{BUCKET_URL}/{hmm_file}", output=model_path) + + hmm_matrices[hmm_object_name] = np.load(str(model_path)) + + return hmm_matrices diff --git a/backend/requirements.txt b/backend/requirements.txt index 064baa99..49c340a6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,4 +9,3 @@ scipy==1.5.2 scikit-learn==0.23.2 skl2onnx==1.7.0 requests==2.7.0 -pytz==2020.1 From b3c085089caddb5308a5a65019a74b6d455de0bd Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 01:25:24 -0500 Subject: [PATCH 12/24] added postprocessing --- backend/app.py | 6 ++++-- 
backend/classification/config/constants.py | 6 ++++++ backend/classification/features/constants.py | 5 +---- .../classification/features/preprocessing.py | 4 +++- backend/classification/load_model.py | 11 ++++++++--- backend/classification/postprocess.py | 18 ++++++++++++++++++ backend/classification/predict.py | 12 +++++++++--- backend/requirements.txt | 1 + 8 files changed, 50 insertions(+), 13 deletions(-) create mode 100644 backend/classification/postprocess.py diff --git a/backend/app.py b/backend/app.py index 7060d767..bef27c1d 100644 --- a/backend/app.py +++ b/backend/app.py @@ -10,8 +10,10 @@ from classification.load_model import load_model, load_hmm app = Flask(__name__) -model = load_model() -hmm_model = load_hmm() +model = { + 'classifier': load_model(), + 'postprocessing': load_hmm() +} def allowed_file(filename): diff --git a/backend/classification/config/constants.py b/backend/classification/config/constants.py index 80692c6d..b445c12b 100644 --- a/backend/classification/config/constants.py +++ b/backend/classification/config/constants.py @@ -29,3 +29,9 @@ class Sex(Enum): [85, 125] ] ACCEPTED_AGE_RANGE = [AGE_FEATURE_BINS[0][0], AGE_FEATURE_BINS[-1][-1]] + +N_STAGES = 5 + +HMM_EMISSION_MATRIX = 'emission' +HMM_START_PROBABILITIES = 'start' +HMM_TRANSITION_MATRIX = 'transition' diff --git a/backend/classification/features/constants.py b/backend/classification/features/constants.py index 4dd21316..171b3fe6 100644 --- a/backend/classification/features/constants.py +++ b/backend/classification/features/constants.py @@ -1,7 +1,4 @@ -from classification.config.constants import ( - DATASET_SAMPLE_RATE, - EPOCH_DURATION, -) +from classification.config.constants import DATASET_SAMPLE_RATE NYQUIST_FREQ = DATASET_SAMPLE_RATE / 2 diff --git a/backend/classification/features/preprocessing.py b/backend/classification/features/preprocessing.py index 321f02a8..0561baa7 100644 --- a/backend/classification/features/preprocessing.py +++ 
b/backend/classification/features/preprocessing.py @@ -1,8 +1,10 @@ import mne from scipy.signal import cheby1 -from classification.features.constants import ( +from classification.config.constants import ( EPOCH_DURATION, +) +from classification.features.constants import ( DATASET_SAMPLE_RATE, DATASET_HIGH_PASS_FREQ, HIGH_PASS_FILTER_ORDER, diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 160f263f..f554fb9f 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -9,6 +9,11 @@ from requests import get import onnxruntime +from classification.config.constants import ( + HMM_EMISSION_MATRIX, + HMM_START_PROBABILITIES, + HMM_TRANSITION_MATRIX, +) SCRIPT_PATH = Path(path.realpath(sys.argv[0])).parent @@ -21,9 +26,9 @@ HMM_FOLDER = 'hmm_model' HMM_FILENAMES = [ - ('emission', 'hmm_emission_probabilites.npy'), - ('start', 'hmm_start_probabilities.npy'), - ('transition', 'hmm_transition_probabilites.npy') + (HMM_EMISSION_MATRIX, 'hmm_emission_probabilites.npy'), + (HMM_START_PROBABILITIES, 'hmm_start_probabilities.npy'), + (HMM_TRANSITION_MATRIX, 'hmm_transition_probabilites.npy') ] diff --git a/backend/classification/postprocess.py b/backend/classification/postprocess.py new file mode 100644 index 00000000..ceace8f4 --- /dev/null +++ b/backend/classification/postprocess.py @@ -0,0 +1,18 @@ +from hmmlearn.hmm import MultinomialHMM + +from classification.config.constants import ( + HMM_EMISSION_MATRIX, + HMM_START_PROBABILITIES, + HMM_TRANSITION_MATRIX, + N_STAGES, +) + + +def postprocess(predictions, postprocessing_state): + hmm_model = MultinomialHMM(n_components=N_STAGES) + + hmm_model.emissionprob_ = postprocessing_state[HMM_EMISSION_MATRIX] + hmm_model.startprob_ = postprocessing_state[HMM_START_PROBABILITIES] + hmm_model.transmat_ = postprocessing_state[HMM_TRANSITION_MATRIX] + + return hmm_model.predict(predictions.reshape(-1, 1)) diff --git a/backend/classification/predict.py 
b/backend/classification/predict.py index 8a18684f..a91ba953 100644 --- a/backend/classification/predict.py +++ b/backend/classification/predict.py @@ -1,6 +1,7 @@ """defines functions to predict sleep stages based off EEG signals""" from classification.features import get_features from classification.validation import validate +from classification.postprocess import postprocess def predict(raw_eeg, model, info): @@ -8,7 +9,9 @@ def predict(raw_eeg, model, info): Input: - raw_eeg: instance of mne.io.RawArray Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) - - model: instance of InferenceSession + - model: dict + Contains an instance of InferenceSession and the matrices + needed for the postprocessing - info: dict Should contain the following keys: - sex: instance of Sex enum @@ -18,11 +21,14 @@ def predict(raw_eeg, model, info): - out_of_bed_seconds: timespan, in seconds, from which the subject started the recording and got out of bed """ + classifier, postprocessing_state = model['classifier'], model['postprocessing'] + validate(raw_eeg, info) features = get_features(raw_eeg, info) + input_name = classifier.get_inputs()[0].name - input_name = model.get_inputs()[0].name - predictions = model.run(None, {input_name: features})[0] + predictions = classifier.run(None, {input_name: features})[0] + predictions = postprocess(predictions, postprocessing_state) print(features[0], features.shape) print(predictions) diff --git a/backend/requirements.txt b/backend/requirements.txt index 49c340a6..a24f8c6d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,3 +9,4 @@ scipy==1.5.2 scikit-learn==0.23.2 skl2onnx==1.7.0 requests==2.7.0 +hmmlearn==0.2.4 From 56b4d7f625725506570813513c562cadd29f15bc Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 01:32:10 -0500 Subject: [PATCH 13/24] return predictions --- backend/classification/predict.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend/classification/predict.py 
b/backend/classification/predict.py index a91ba953..810bd8ce 100644 --- a/backend/classification/predict.py +++ b/backend/classification/predict.py @@ -20,6 +20,7 @@ def predict(raw_eeg, model, info): the subject started the recording and went to bed - out_of_bed_seconds: timespan, in seconds, from which the subject started the recording and got out of bed + Returns: array of predicted sleep stages """ classifier, postprocessing_state = model['classifier'], model['postprocessing'] @@ -32,3 +33,5 @@ def predict(raw_eeg, model, info): print(features[0], features.shape) print(predictions) + + return predictions From 26525cc879af9acf19668b30d253064ba1ff5252 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 14:34:15 -0500 Subject: [PATCH 14/24] removed investigation file --- ai/investigation.ipynb | 767 ----------------------------------------- 1 file changed, 767 deletions(-) delete mode 100644 ai/investigation.ipynb diff --git a/ai/investigation.ipynb b/ai/investigation.ipynb deleted file mode 100644 index 07e10bad..00000000 --- a/ai/investigation.ipynb +++ /dev/null @@ -1,767 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Investigate differences in the feature matrices\n", - "____\n", - "\n", - "Current bug: In the backend, we currently obtain invalid predictions (all Wake). 
When investigating, I saw that the feature matrix sent to the classifier was really different compared to the one generated in our notebooks.\n", - "\n", - "## Matrix feature differences\n", - "___" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import mne\n", - "from constants import EPOCH_DURATION" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [], - "source": [ - "start = 1582418280\n", - "bed = 1582423980\n", - "wake = 1582452240\n", - "\n", - "nb_epoch_bed = int((bed - start) // EPOCH_DURATION)\n", - "nb_epoch_wake = int((wake - start) // EPOCH_DURATION)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "((942, 48), (942, 48))" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X_notebook = np.vstack(np.load(\"data/X_openbci_HP_PRIOR.npy\", allow_pickle=True))[nb_epoch_bed:nb_epoch_wake]\n", - "X_backend = np.load(\"investigation_data/feature.npy\", allow_pickle=True)\n", - "\n", - "X_notebook.shape, X_backend.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Feature 0 diff: 0.00\n", - "Feature 1 diff: 0.00\n", - "Feature 2 diff: 0.14\n", - "Feature 3 diff: 21.24\n", - "Feature 4 diff: 0.01\n", - "Feature 5 diff: 0.04\n", - "Feature 6 diff: 5.94\n", - "Feature 7 diff: 0.00\n", - "Feature 8 diff: 0.80\n", - "Feature 9 diff: 116.65\n", - "Feature 10 diff: 6.50\n", - "Feature 11 diff: 1.92\n", - "Feature 12 diff: 1.94\n", - "Feature 13 diff: 0.23\n", - "Feature 14 diff: 0.00\n", - "Feature 15 diff: 0.00\n", - "Feature 16 diff: 0.00\n", - "Feature 17 diff: 0.00\n", - "Feature 18 diff: 0.00\n", - "Feature 19 diff: 0.01\n", - "Feature 20 diff: 
1486454042448.59\n", - "Feature 21 diff: 71806204238.74\n", - "Feature 22 diff: 16018476987.50\n", - "Feature 23 diff: 22727447431.86\n", - "Feature 24 diff: 9513626021.77\n", - "Feature 25 diff: 0.08\n", - "Feature 26 diff: 12.56\n", - "Feature 27 diff: 0.01\n", - "Feature 28 diff: 0.04\n", - "Feature 29 diff: 6.80\n", - "Feature 30 diff: 0.00\n", - "Feature 31 diff: 3.70\n", - "Feature 32 diff: 51.18\n", - "Feature 33 diff: 2.82\n", - "Feature 34 diff: 1.20\n", - "Feature 35 diff: 0.55\n", - "Feature 36 diff: 0.10\n", - "Feature 37 diff: 0.00\n", - "Feature 38 diff: 0.00\n", - "Feature 39 diff: 0.00\n", - "Feature 40 diff: 0.00\n", - "Feature 41 diff: 0.00\n", - "Feature 42 diff: 0.02\n", - "Feature 43 diff: 670759560535.92\n", - "Feature 44 diff: 31305042058.57\n", - "Feature 45 diff: 10179418544.09\n", - "Feature 46 diff: 6343296375.84\n", - "Feature 47 diff: 4072542937.41\n" - ] - } - ], - "source": [ - "for feature_idx in range(X_notebook.shape[1]):\n", - " difference = np.mean(abs(X_notebook[:,feature_idx] - X_backend[:,feature_idx]))\n", - " print(f\"Feature {feature_idx} diff: {difference:.2f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can see there's a big difference in the time subband domain. The related features was to apply a subband filter (i.e. delta) and take the mean signal energy of the signal in the time domain. It consists of the sum of each sample powered by two.\n", - "\n", - "Then, if there are differences in the signal, those are highly amplified in the subband domain features.\n", - "\n", - "### Known errors\n", - "___\n", - "\n", - "We know that the signal we've trained on had a different quantification than the original quantification of the OpenBCI Cyton quantification. The EDF specification encodes the samples on 16 bits, whereas we currently keep the original encoding of 24 bits.\n", - "\n", - "## 1. 
Conversion from hexadecimal to decimal\n", - "\n", - "We will compare both files converted with the OpenBCI GUI and our own code. The file used will be the mini file." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Fpz-CzPz-Oz
0-11104.548205.09960
1-11065.186191.77795
2-11017.264180.24446
3-11041.448194.37076
4-11036.777194.03549
.........
899993-17148.0576566.42800
899994-17136.9266567.14360
899995-17134.3776574.34100
899996-17127.5826579.28030
899997-17121.3906581.44870
\n", - "

899998 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " Fpz-Cz Pz-Oz\n", - "0 -11104.548 205.09960\n", - "1 -11065.186 191.77795\n", - "2 -11017.264 180.24446\n", - "3 -11041.448 194.37076\n", - "4 -11036.777 194.03549\n", - "... ... ...\n", - "899993 -17148.057 6566.42800\n", - "899994 -17136.926 6567.14360\n", - "899995 -17134.377 6574.34100\n", - "899996 -17127.582 6579.28030\n", - "899997 -17121.390 6581.44870\n", - "\n", - "[899998 rows x 2 columns]" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "openbci_gui_data = pd.read_csv(\"investigation_data/SDconverted-2020-10-29_00-05-19_mini.csv\", skiprows=7, usecols=[1,2], names=[\"Fpz-Cz\", \"Pz-Oz\"])\n", - "openbci_gui_data" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Fpz-CzPz-Oz
0-11104.547811205.099607
1-11065.186389191.777967
2-11017.264249180.244467
3-11041.448836194.370770
4-11036.777322194.035494
.........
899993-17148.0571806566.428431
899994-17136.9260126567.143687
899995-17134.3779136574.340948
899996-17127.5829826579.280684
899997-17121.3915496581.448803
\n", - "

899998 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " Fpz-Cz Pz-Oz\n", - "0 -11104.547811 205.099607\n", - "1 -11065.186389 191.777967\n", - "2 -11017.264249 180.244467\n", - "3 -11041.448836 194.370770\n", - "4 -11036.777322 194.035494\n", - "... ... ...\n", - "899993 -17148.057180 6566.428431\n", - "899994 -17136.926012 6567.143687\n", - "899995 -17134.377913 6574.340948\n", - "899996 -17127.582982 6579.280684\n", - "899997 -17121.391549 6581.448803\n", - "\n", - "[899998 rows x 2 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "script_data = pd.DataFrame(data=np.transpose(np.load('investigation_data/raw_converted.npy')), columns=[\"Fpz-Cz\", \"Pz-Oz\"])\n", - "script_data" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean difference between decimal raw data \n", - "Fpz-Cz 0.000550\n", - "Pz-Oz 0.000263\n", - "dtype: float64\n", - "\n", - "Median difference between decimal raw data \n", - "Fpz-Cz 0.000503\n", - "Pz-Oz 0.000212\n", - "dtype: float64\n", - "\n", - "Min difference between decimal raw data \n", - "Fpz-Cz 4.529284e-09\n", - "Pz-Oz 1.073204e-10\n", - "dtype: float64\n", - "\n", - "Max difference between decimal raw data \n", - "Fpz-Cz 0.002508\n", - "Pz-Oz 0.001277\n", - "dtype: float64\n" - ] - } - ], - "source": [ - "difference_df = abs(openbci_gui_data - script_data)\n", - "print(f'Mean difference between decimal raw data \\n{difference_df.mean()}\\n')\n", - "print(f'Median difference between decimal raw data \\n{difference_df.median()}\\n')\n", - "print(f'Min difference between decimal raw data \\n{difference_df.min()}\\n')\n", - "print(f'Max difference between decimal raw data \\n{difference_df.max()}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As we see, the OpenBCI GUI stores only the five first decimals, whereas we do not round the number in our 
backend. We can look if its the cause of the problem." - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean difference between decimal raw data \n", - "Fpz-Cz 0.000550\n", - "Pz-Oz 0.000263\n", - "dtype: float64\n", - "\n", - "Median difference between decimal raw data \n", - "Fpz-Cz 0.00050\n", - "Pz-Oz 0.00021\n", - "dtype: float64\n", - "\n", - "Min difference between decimal raw data \n", - "Fpz-Cz 0.0\n", - "Pz-Oz 0.0\n", - "dtype: float64\n", - "\n", - "Max difference between decimal raw data \n", - "Fpz-Cz 0.00251\n", - "Pz-Oz 0.00128\n", - "dtype: float64\n" - ] - } - ], - "source": [ - "script_data_chopped = script_data.apply(lambda x: np.round(x, decimals=5))\n", - "script_data_chopped\n", - "\n", - "difference_df = abs(openbci_gui_data - script_data_chopped)\n", - "print(f'Mean difference between decimal raw data \\n{difference_df.mean()}\\n')\n", - "print(f'Median difference between decimal raw data \\n{difference_df.median()}\\n')\n", - "print(f'Min difference between decimal raw data \\n{difference_df.min()}\\n')\n", - "print(f'Max difference between decimal raw data \\n{difference_df.max()}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see that the maximum difference is the same as before. It is then not caused by the number of decimals.\n", - "\n", - "## Quantification\n", - "____\n", - "\n", - "In the case of the notebooks, we first convert the decimal converted file to another format, the edf format. Since this format enforces quantification to 16 bits, the data is then requantified from 24 to 16 bits. It is not the case in the backend. We will compare both results." 
- ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([[-0.01110396, -0.01106457, -0.01101625, ..., -0.01713194,\n", - " -0.01713614, -0.01712879],\n", - " [ 0.0002052 , 0.00019168, 0.00018013, ..., 0.00657419,\n", - " 0.00657165, 0.00656827]]), (2, 899975))" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "notebook_data = mne.io.read_raw_edf('investigation_data/william-recording-mini.edf', preload=True, stim_channel=None, verbose=False)\n", - "notebook_data.get_data(), notebook_data.get_data().shape" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Fpz-CzPz-Oz
0-11104.548205.09960
1-11065.186191.77795
2-11017.264180.24446
3-11041.448194.37076
4-11036.777194.03549
.........
899993-17148.0576566.42800
899994-17136.9266567.14360
899995-17134.3776574.34100
899996-17127.5826579.28030
899997-17121.3906581.44870
\n", - "

899998 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " Fpz-Cz Pz-Oz\n", - "0 -11104.548 205.09960\n", - "1 -11065.186 191.77795\n", - "2 -11017.264 180.24446\n", - "3 -11041.448 194.37076\n", - "4 -11036.777 194.03549\n", - "... ... ...\n", - "899993 -17148.057 6566.42800\n", - "899994 -17136.926 6567.14360\n", - "899995 -17134.377 6574.34100\n", - "899996 -17127.582 6579.28030\n", - "899997 -17121.390 6581.44870\n", - "\n", - "[899998 rows x 2 columns]" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "openbci_gui_data = pd.read_csv(\"investigation_data/SDconverted-2020-10-29_00-05-19_mini.csv\", skiprows=7, usecols=[1,2], names=[\"Fpz-Cz\", \"Pz-Oz\"])\n", - "openbci_gui_data" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Fpz-CzPz-Oz
0-11104.547811205.099607
1-11065.186389191.777967
2-11017.264249180.244467
3-11041.448836194.370770
4-11036.777322194.035494
.........
899993-17148.0571806566.428431
899994-17136.9260126567.143687
899995-17134.3779136574.340948
899996-17127.5829826579.280684
899997-17121.3915496581.448803
\n", - "

899998 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " Fpz-Cz Pz-Oz\n", - "0 -11104.547811 205.099607\n", - "1 -11065.186389 191.777967\n", - "2 -11017.264249 180.244467\n", - "3 -11041.448836 194.370770\n", - "4 -11036.777322 194.035494\n", - "... ... ...\n", - "899993 -17148.057180 6566.428431\n", - "899994 -17136.926012 6567.143687\n", - "899995 -17134.377913 6574.340948\n", - "899996 -17127.582982 6579.280684\n", - "899997 -17121.391549 6581.448803\n", - "\n", - "[899998 rows x 2 columns]" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "script_data = pd.DataFrame(data=np.transpose(np.load('investigation_data/raw_converted.npy')), columns=[\"Fpz-Cz\", \"Pz-Oz\"])\n", - "script_data" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From f1b304e500bee6f46d4dc9ca7b96c5cd5e5e8a75 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 15:12:14 -0500 Subject: [PATCH 15/24] changed gitignore --- ai/.gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ai/.gitignore b/ai/.gitignore index f72027b1..1c3c84bf 100644 --- a/ai/.gitignore +++ b/ai/.gitignore @@ -1,5 +1,3 @@ -investigation_data/* - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -112,3 +110,4 @@ data/* .vscode/ *.joblib +trained_model/* From f645862cb86e7333dcb39fece6db569c1e1989bb Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 15:14:50 -0500 Subject: [PATCH 16/24] removed unnecessary requirement --- backend/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/requirements.txt b/backend/requirements.txt 
index a24f8c6d..81d90da1 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,6 +7,5 @@ onnxruntime==1.5.2 numpy==1.19.2 scipy==1.5.2 scikit-learn==0.23.2 -skl2onnx==1.7.0 requests==2.7.0 hmmlearn==0.2.4 From 130adf5e079a277b2b4cf22d56263f62c7663a87 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 15:31:47 -0500 Subject: [PATCH 17/24] added logs to debug error when running executable on macos --- backend/classification/load_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index f554fb9f..e17812b3 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -33,11 +33,15 @@ def _download_file(url, output): + print(f'downloading from {url}') + with open(output, 'wb') as f: f.write(get(url).content) def _get_latest_object_information(filename): + print(f'fetching bucket files info from {BUCKET_URL}') + raw_result = get(BUCKET_URL).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) From 533f4b62636dd680f9958630fde1f5add748ca1e Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 15:44:26 -0500 Subject: [PATCH 18/24] added requirements --- backend/requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/requirements.txt b/backend/requirements.txt index 81d90da1..4b8a7bad 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,5 +7,7 @@ onnxruntime==1.5.2 numpy==1.19.2 scipy==1.5.2 scikit-learn==0.23.2 -requests==2.7.0 +requests[security]==2.7.0 hmmlearn==0.2.4 +cryptography==3.2.1 +pyopenssl==19.1.0 From 6781be6d95d15977fe3366626f82dac937ba67ca Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 16:07:04 -0500 Subject: [PATCH 19/24] added verify param with certifi to indicate certification path --- backend/classification/load_model.py | 5 +++-- backend/requirements.txt | 5 
++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index e17812b3..c4c27973 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -5,6 +5,7 @@ import sys import xml.etree.ElementTree as ET +import certifi import numpy as np from requests import get import onnxruntime @@ -36,13 +37,13 @@ def _download_file(url, output): print(f'downloading from {url}') with open(output, 'wb') as f: - f.write(get(url).content) + f.write(get(url, verify=certifi.where()).content) def _get_latest_object_information(filename): print(f'fetching bucket files info from {BUCKET_URL}') - raw_result = get(BUCKET_URL).text + raw_result = get(BUCKET_URL, verify=certifi.where()).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) result_root_node = ET.fromstring(raw_result) diff --git a/backend/requirements.txt b/backend/requirements.txt index 4b8a7bad..684f25fc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,7 +7,6 @@ onnxruntime==1.5.2 numpy==1.19.2 scipy==1.5.2 scikit-learn==0.23.2 -requests[security]==2.7.0 +requests==2.7.0 hmmlearn==0.2.4 -cryptography==3.2.1 -pyopenssl==19.1.0 +certifi==2020.6.20 From 1d0a27372083aaa9a1243dfca7995990b203f3db Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Sun, 1 Nov 2020 16:21:38 -0500 Subject: [PATCH 20/24] removed temporary prints --- backend/classification/load_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index c4c27973..3df63bb3 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -34,15 +34,11 @@ def _download_file(url, output): - print(f'downloading from {url}') - with open(output, 'wb') as f: f.write(get(url, verify=certifi.where()).content) def _get_latest_object_information(filename): - print(f'fetching 
bucket files info from {BUCKET_URL}') - raw_result = get(BUCKET_URL, verify=certifi.where()).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) From 324d292c06b03d0c0a964545f31d9779ab29362f Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Tue, 3 Nov 2020 20:42:15 -0500 Subject: [PATCH 21/24] only check by latest update if object is up to date --- backend/classification/load_model.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 3df63bb3..70cab9c0 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -38,28 +38,23 @@ def _download_file(url, output): f.write(get(url, verify=certifi.where()).content) -def _get_latest_object_information(filename): +def _get_object_latest_update(filename): raw_result = get(BUCKET_URL, verify=certifi.where()).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) result_root_node = ET.fromstring(raw_result) objects_nodes = result_root_node.findall('Contents') object_node = [object_node for object_node in objects_nodes if object_node.find("Key").text == filename][0] - object_size = int(object_node.find("Size").text) object_latest_update = datetime.strptime(object_node.find("LastModified").text, "%Y-%m-%dT%H:%M:%S.%f%z") - return {'size': object_size, 'latest_update': object_latest_update} + return object_latest_update def _has_latest_object(filename, local_path): - latest_model_information = _get_latest_object_information(filename) - current_model_size = path.getsize(local_path) + latest_model_update = _get_object_latest_update(filename) current_model_update = datetime.fromtimestamp(path.getmtime(local_path)).astimezone() - return ( - current_model_update >= latest_model_information['latest_update'] - and current_model_size == latest_model_information['size'] - ) + return 
current_model_update >= latest_model_update def load_model(): From 8ccd1a0f82b478a8f14a0ade0d74548c79ff03bb Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Tue, 3 Nov 2020 21:25:37 -0500 Subject: [PATCH 22/24] converted HMM info to Enum --- backend/classification/config/constants.py | 11 ++++++++++- backend/classification/load_model.py | 17 ++++------------- backend/classification/postprocess.py | 10 ++++------ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/backend/classification/config/constants.py b/backend/classification/config/constants.py index b445c12b..05ea1914 100644 --- a/backend/classification/config/constants.py +++ b/backend/classification/config/constants.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, auto class Sex(Enum): @@ -8,6 +8,15 @@ class Sex(Enum): M = 2 +class HiddenMarkovModelProbability(Enum): + emission = auto() + start = auto() + transition = auto() + + def get_filename(self): + return f'hmm_{self.name}_probabilities.npy' + + ALLOWED_FILE_EXTENSIONS = ('.txt', '.csv') EEG_CHANNELS = [ diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 70cab9c0..f1426be5 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -10,11 +10,7 @@ from requests import get import onnxruntime -from classification.config.constants import ( - HMM_EMISSION_MATRIX, - HMM_START_PROBABILITIES, - HMM_TRANSITION_MATRIX, -) +from classification.config.constants import HiddenMarkovModelProbability SCRIPT_PATH = Path(path.realpath(sys.argv[0])).parent @@ -26,11 +22,6 @@ MODEL_URL = f'{BUCKET_URL}/{MODEL_FILENAME}' HMM_FOLDER = 'hmm_model' -HMM_FILENAMES = [ - (HMM_EMISSION_MATRIX, 'hmm_emission_probabilites.npy'), - (HMM_START_PROBABILITIES, 'hmm_start_probabilities.npy'), - (HMM_TRANSITION_MATRIX, 'hmm_transition_probabilites.npy') -] def _download_file(url, output): @@ -71,13 +62,13 @@ def load_hmm(): if not path.exists(SCRIPT_PATH / HMM_FOLDER): 
makedirs(SCRIPT_PATH / HMM_FOLDER) - for hmm_object_name, hmm_file in HMM_FILENAMES: + for hmm_probability in HiddenMarkovModelProbability: + hmm_file = hmm_probability.get_filename() model_path = SCRIPT_PATH / HMM_FOLDER / hmm_file if not path.exists(model_path) or not _has_latest_object(hmm_file, model_path): - print(f"Downloading latest {hmm_object_name} HMM matrix...") _download_file(url=f"{BUCKET_URL}/{hmm_file}", output=model_path) - hmm_matrices[hmm_object_name] = np.load(str(model_path)) + hmm_matrices[hmm_probability.name] = np.load(str(model_path)) return hmm_matrices diff --git a/backend/classification/postprocess.py b/backend/classification/postprocess.py index ceace8f4..5d943bdf 100644 --- a/backend/classification/postprocess.py +++ b/backend/classification/postprocess.py @@ -1,9 +1,7 @@ from hmmlearn.hmm import MultinomialHMM from classification.config.constants import ( - HMM_EMISSION_MATRIX, - HMM_START_PROBABILITIES, - HMM_TRANSITION_MATRIX, + HiddenMarkovModelProbability, N_STAGES, ) @@ -11,8 +9,8 @@ def postprocess(predictions, postprocessing_state): hmm_model = MultinomialHMM(n_components=N_STAGES) - hmm_model.emissionprob_ = postprocessing_state[HMM_EMISSION_MATRIX] - hmm_model.startprob_ = postprocessing_state[HMM_START_PROBABILITIES] - hmm_model.transmat_ = postprocessing_state[HMM_TRANSITION_MATRIX] + hmm_model.emissionprob_ = postprocessing_state[HiddenMarkovModelProbability.emission.name] + hmm_model.startprob_ = postprocessing_state[HiddenMarkovModelProbability.start.name] + hmm_model.transmat_ = postprocessing_state[HiddenMarkovModelProbability.transition.name] return hmm_model.predict(predictions.reshape(-1, 1)) From 2ef51763a9d543bdb88d28eb4b4586c9da0b5fc7 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Wed, 4 Nov 2020 00:07:49 -0500 Subject: [PATCH 23/24] modified predict module to model module, with SleepStagesClassifier --- backend/app.py | 10 ++--- backend/classification/model.py | 49 +++++++++++++++++++++++++ 
backend/classification/postprocess.py | 16 -------- backend/classification/postprocessor.py | 24 ++++++++++++ backend/classification/predict.py | 37 ------------------- 5 files changed, 76 insertions(+), 60 deletions(-) create mode 100644 backend/classification/model.py delete mode 100644 backend/classification/postprocess.py create mode 100644 backend/classification/postprocessor.py delete mode 100644 backend/classification/predict.py diff --git a/backend/app.py b/backend/app.py index bef27c1d..11e736d9 100644 --- a/backend/app.py +++ b/backend/app.py @@ -4,16 +4,12 @@ from http import HTTPStatus from classification.file_loading import get_raw_array -from classification.predict import predict from classification.exceptions import ClassificationError from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS -from classification.load_model import load_model, load_hmm +from classification.model import SleepStagesClassifier app = Flask(__name__) -model = { - 'classifier': load_model(), - 'postprocessing': load_hmm() -} +model = SleepStagesClassifier() def allowed_file(filename): @@ -62,7 +58,7 @@ def analyze_sleep(): try: raw_array = get_raw_array(file) - predict(raw_array, model, info={ + model.predict(raw_array, info={ 'sex': sex, 'age': age, 'in_bed_seconds': bedtime - stream_start, diff --git a/backend/classification/model.py b/backend/classification/model.py new file mode 100644 index 00000000..db1f2a00 --- /dev/null +++ b/backend/classification/model.py @@ -0,0 +1,49 @@ +"""defines models which predict sleep stages based off EEG signals""" + +from classification.features import get_features +from classification.validation import validate +from classification.postprocessor import get_hmm_model +from classification.load_model import load_model, load_hmm + + +class SleepStagesClassifier(): + def __init__(self): + self.model = load_model() + self.model_input_name = self.model.get_inputs()[0].name + + self.postprocessor_state = load_hmm() + 
self.postprocessor = get_hmm_model(self.postprocessor_state) + + def predict(self, raw_eeg, info): + """ + Input: + - raw_eeg: instance of mne.io.RawArray + Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) + - info: dict + Should contain the following keys: + - sex: instance of Sex enum + - age: indicates the subject's age + - in_bed_seconds: timespan, in seconds, from which + the subject started the recording and went to bed + - out_of_bed_seconds: timespan, in seconds, from which + the subject started the recording and got out of bed + Returns: array of predicted sleep stages + """ + + validate(raw_eeg, info) + features = get_features(raw_eeg, info) + + print(features, features.shape) + + predictions = self._get_predictions(features) + predictions = self._get_postprocessed_predictions(predictions) + + print(predictions) + + return predictions + + def _get_predictions(self, features): + return self.model.run(None, {self.model_input_name: features})[0] + + def _get_postprocessed_predictions(self, predictions): + return self.postprocessor.predict(predictions.reshape(-1, 1)) diff --git a/backend/classification/postprocess.py b/backend/classification/postprocess.py deleted file mode 100644 index 5d943bdf..00000000 --- a/backend/classification/postprocess.py +++ /dev/null @@ -1,16 +0,0 @@ -from hmmlearn.hmm import MultinomialHMM - -from classification.config.constants import ( - HiddenMarkovModelProbability, - N_STAGES, -) - - -def postprocess(predictions, postprocessing_state): - hmm_model = MultinomialHMM(n_components=N_STAGES) - - hmm_model.emissionprob_ = postprocessing_state[HiddenMarkovModelProbability.emission.name] - hmm_model.startprob_ = postprocessing_state[HiddenMarkovModelProbability.start.name] - hmm_model.transmat_ = postprocessing_state[HiddenMarkovModelProbability.transition.name] - - return hmm_model.predict(predictions.reshape(-1, 1)) diff --git a/backend/classification/postprocessor.py b/backend/classification/postprocessor.py new file mode 100644 
index 00000000..260fdae9 --- /dev/null +++ b/backend/classification/postprocessor.py @@ -0,0 +1,24 @@ +from hmmlearn.hmm import MultinomialHMM + +from classification.config.constants import ( + HiddenMarkovModelProbability, + N_STAGES, +) + + +def get_hmm_model(state): + """Creates an instance of MultinomialHMM, which follows sklearn interface + Input: + - state: dictionnary + where the keys are HiddenMarkovModelProbability choices + where the values are the probabilities matrices or arrays which + describes the according hidden markov model state + Returns: an instance of a trained MultinomialHMM + """ + hmm_model = MultinomialHMM(n_components=N_STAGES) + + hmm_model.emissionprob_ = state[HiddenMarkovModelProbability.emission.name] + hmm_model.startprob_ = state[HiddenMarkovModelProbability.start.name] + hmm_model.transmat_ = state[HiddenMarkovModelProbability.transition.name] + + return hmm_model diff --git a/backend/classification/predict.py b/backend/classification/predict.py deleted file mode 100644 index 810bd8ce..00000000 --- a/backend/classification/predict.py +++ /dev/null @@ -1,37 +0,0 @@ -"""defines functions to predict sleep stages based off EEG signals""" -from classification.features import get_features -from classification.validation import validate -from classification.postprocess import postprocess - - -def predict(raw_eeg, model, info): - """ - Input: - - raw_eeg: instance of mne.io.RawArray - Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) - - model: dict - Contains an instance of InferenceSession and the matrices - needed for the postprocessing - - info: dict - Should contain the following keys: - - sex: instance of Sex enum - - age: indicates the subject's age - - in_bed_seconds: timespan, in seconds, from which - the subject started the recording and went to bed - - out_of_bed_seconds: timespan, in seconds, from which - the subject started the recording and got out of bed - Returns: array of predicted sleep stages - """ - classifier, 
postprocessing_state = model['classifier'], model['postprocessing'] - - validate(raw_eeg, info) - features = get_features(raw_eeg, info) - input_name = classifier.get_inputs()[0].name - - predictions = classifier.run(None, {input_name: features})[0] - predictions = postprocess(predictions, postprocessing_state) - - print(features[0], features.shape) - print(predictions) - - return predictions From 54cc43b8039f4900e715084a0d94a63e3d3f26f2 Mon Sep 17 00:00:00 2001 From: Claudia Onorato Date: Wed, 4 Nov 2020 00:14:00 -0500 Subject: [PATCH 24/24] deleted unused constants --- backend/classification/config/constants.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/backend/classification/config/constants.py b/backend/classification/config/constants.py index 05ea1914..10414a75 100644 --- a/backend/classification/config/constants.py +++ b/backend/classification/config/constants.py @@ -40,7 +40,3 @@ def get_filename(self): ACCEPTED_AGE_RANGE = [AGE_FEATURE_BINS[0][0], AGE_FEATURE_BINS[-1][-1]] N_STAGES = 5 - -HMM_EMISSION_MATRIX = 'emission' -HMM_START_PROBABILITIES = 'start' -HMM_TRANSITION_MATRIX = 'transition'