diff --git a/ai/.gitignore b/ai/.gitignore index 310601c9..1c3c84bf 100644 --- a/ai/.gitignore +++ b/ai/.gitignore @@ -109,4 +109,5 @@ venv.bak/ data/* .vscode/ -*.joblib \ No newline at end of file +*.joblib +trained_model/* diff --git a/ai/export_to_onnx.ipynb b/ai/export_to_onnx.ipynb new file mode 100644 index 00000000..908f88bd --- /dev/null +++ b/ai/export_to_onnx.ipynb @@ -0,0 +1,479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save & reload trained model with ONNX\n", + "___\n", + "\n", + "This notebook aims to save, reload and check if the model can be correctly serialized through ONNX, and the Scikit-learn ONNX package." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import numpy as np\n", + "\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", + "from sklearn.ensemble import (RandomForestClassifier,\n", + " VotingClassifier)\n", + "from sklearn.metrics import (confusion_matrix,\n", + " classification_report,\n", + " cohen_kappa_score)\n", + "from skl2onnx import convert_sklearn\n", + "from skl2onnx.common.data_types import FloatTensorType\n", + "from skl2onnx.helpers.onnx_helper import save_onnx_model\n", + "from onnxruntime import InferenceSession\n", + "\n", + "from models.model_utils import (train_test_split_according_to_age)\n", + "from constants import (SLEEP_STAGES_VALUES,)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate trained pipeline\n", + "____" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "SUBJECT_IDX = 0 \n", + "NIGHT_IDX = 1\n", + "USE_CONTINUOUS_AGE = False\n", + "DOWNSIZE_SET = False\n", + "TEST_SET_SUBJECTS = [0.0, 24.0, 49.0, 71.0]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(168954, 50)\n", + "(168954,)\n", + "Number of subjects: 78\n", + "Number of nights: 153\n", + "Subjects available: [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.\n", + " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.\n", + " 36. 37. 38. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.\n", + " 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 70. 71. 72. 73. 74.\n", + " 75. 76. 77. 80. 81. 82.]\n", + "Selected subjects for the test set are: [0.0, 24.0, 49.0, 71.0]\n", + "(8123, 50) (160831, 50) (8123,) (160831,)\n" + ] + } + ], + "source": [ + "def load_features():\n", + " if USE_CONTINUOUS_AGE:\n", + " X_file_name = \"data/x_features-age-continuous.npy\"\n", + " y_file_name = \"data/y_observations-age-continuous.npy\"\n", + " else:\n", + " X_file_name = \"data/x_features.npy\"\n", + " y_file_name = \"data/y_observations.npy\"\n", + "\n", + " X_init = np.load(X_file_name, allow_pickle=True)\n", + " y_init = np.load(y_file_name, allow_pickle=True)\n", + "\n", + " X_init = np.vstack(X_init)\n", + " y_init = np.hstack(y_init)\n", + "\n", + " print(X_init.shape)\n", + " print(y_init.shape)\n", + " print(\"Number of subjects: \", np.unique(X_init[:,SUBJECT_IDX]).shape[0]) # Some subject indexes are skipped, thus total number is below 83 (as we can see in https://physionet.org/content/sleep-edfx/1.0.0/)\n", + " print(\"Number of nights: \", len(np.unique([f\"{int(x[0])}-{int(x[1])}\" for x in X_init[:,SUBJECT_IDX:NIGHT_IDX+1]])))\n", + " print(\"Subjects available: \", np.unique(X_init[:,SUBJECT_IDX]))\n", + " \n", + " return X_init, y_init\n", + "\n", + "def split_data(X_init, y_init):\n", + " X_test, X_train_valid, y_test, y_train_valid = train_test_split_according_to_age(\n", + " X_init,\n", + " y_init,\n", + " use_continuous_age=USE_CONTINUOUS_AGE,\n", + " subjects_test=TEST_SET_SUBJECTS)\n", + " \n", + " print(X_test.shape, X_train_valid.shape, y_test.shape, y_train_valid.shape)\n", + " \n", + " return X_test, X_train_valid, y_test, y_train_valid\n", + "\n", + "X_init, y_init = load_features()\n", + "X_test, X_train_valid, y_test, y_train_valid = split_data(X_init, y_init)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1522 52 2 4 44]\n", + " [ 240 145 325 1 272]\n", + " [ 38 57 3210 188 110]\n", + " [ 4 0 31 576 0]\n", + " [ 57 83 264 0 898]]\n", + " precision recall f1-score support\n", + "\n", + " W 0.82 0.94 0.87 1624\n", + " N1 0.43 0.15 0.22 983\n", + " N2 0.84 0.89 0.86 3603\n", + " N3 0.75 0.94 0.83 611\n", + " REM 0.68 0.69 0.68 1302\n", + "\n", + " accuracy 0.78 8123\n", + " macro avg 0.70 0.72 0.70 8123\n", + "weighted avg 0.75 0.78 0.76 8123\n", + "\n", + "Agreement score (Cohen Kappa): 0.6913101923642638\n" + ] + } + ], + "source": [ + "def get_voting_classifier_pipeline():\n", + " NB_CATEGORICAL_FEATURES = 2\n", + " NB_FEATURES = 48\n", + "\n", + " estimator_list = [\n", + " ('random_forest', RandomForestClassifier(\n", + " random_state=42, # enables deterministic behaviour\n", + " n_jobs=-1\n", + " )),\n", + " ('knn', Pipeline([\n", + " ('knn_dim_red', LinearDiscriminantAnalysis()),\n", + " ('knn_clf', KNeighborsClassifier(\n", + " weights='uniform',\n", + " n_neighbors=300,\n", + " leaf_size=100,\n", + " metric='chebyshev',\n", + " n_jobs=-1\n", + " ))\n", + " ])),\n", + " ]\n", + " \n", + " return Pipeline([\n", + " ('scaling', ColumnTransformer([\n", + " ('pass-through-categorical', 'passthrough', list(range(NB_CATEGORICAL_FEATURES))),\n", + " ('scaling-continuous', StandardScaler(copy=False), list(range(NB_CATEGORICAL_FEATURES,NB_FEATURES)))\n", + " ])),\n", + " ('voting_clf', VotingClassifier(\n", + " estimators=estimator_list,\n", + " voting='soft',\n", + " weights=np.array([0.83756205, 0.16243795]),\n", + " flatten_transform=False,\n", + " n_jobs=-1,\n", + " ))\n", + " ])\n", + "\n", + "testing_pipeline = get_voting_classifier_pipeline()\n", + "testing_pipeline.fit(X_train_valid[:, 2:], y_train_valid)\n", + "y_test_pred = testing_pipeline.predict(X_test[:,2:])\n", + "\n", + "print(confusion_matrix(y_test, y_test_pred))\n", + "print(classification_report(y_test, y_test_pred, target_names=SLEEP_STAGES_VALUES.keys()))\n", + "print(\"Agreement score (Cohen Kappa): \", cohen_kappa_score(y_test, y_test_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving pipeline to ONNX\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m initial_types=[(\n\u001b[1;32m 4\u001b[0m \u001b[0;34m'float_input'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mFloatTensorType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_train_valid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m )]\n\u001b[1;32m 7\u001b[0m )\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/convert.py\u001b[0m in \u001b[0;36mconvert_sklearn\u001b[0;34m(model, name, initial_types, doc_string, target_opset, custom_conversion_functions, custom_shape_calculators, custom_parsers, options, dtype, intermediate, white_op, black_op, final_types)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;31m# Convert our Topology object into ONNX. The outcome is an ONNX model.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m onnx_model = convert_topology(topology, name, doc_string, target_opset,\n\u001b[0;32m--> 154\u001b[0;31m dtype=dtype, options=options)\n\u001b[0m\u001b[1;32m 155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0monnx_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtopology\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mintermediate\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0monnx_model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/common/_topology.py\u001b[0m in \u001b[0;36mconvert_topology\u001b[0;34m(topology, model_name, doc_string, target_opset, channel_first_inputs, dtype, options)\u001b[0m\n\u001b[1;32m 1052\u001b[0m type(getattr(operator, 'raw_model', None))))\n\u001b[1;32m 1053\u001b[0m \u001b[0mcontainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_options\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moperator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1054\u001b[0;31m \u001b[0mconv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1055\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1056\u001b[0m \u001b[0;31m# Create a graph from its main components\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/common/_registration.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_operator\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_allowed_options\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_operator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_allowed_options\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.6/site-packages/skl2onnx/operator_converters/voting_classifier.py\u001b[0m in \u001b[0;36mconvert_voting_classifier\u001b[0;34m(scope, operator, container)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0mop_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msklearn_operator_name_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mthis_operator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdeclare_local_operator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyError\u001b[0m: " + ] + } + ], + "source": [ + "onnx_pipeline = convert_sklearn(\n", + " testing_pipeline,\n", + " initial_types=[(\n", + " 'float_input',\n", + " FloatTensorType([None, X_train_valid[:,2:].shape[1]])\n", + " )]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see the Voting classifier conversion do not currently support a Pipeline typed estimator in its estimators list.\n", + "\n", + "Considering that;\n", + "- the option of adding a pipeline as an estimator in the voting classifier is not supported.\n", + "- the size of a KNearestNeighbor classifier would be too big without its LDA, and that the performance of the RandomForest would be significantly decreased with an LDA beforehand.\n", + "- the voting classifier had a Cohen Kappa agreement's score of 0.6913 on the testing set, whilst we obtained 0.6916 with the fat Random Forest, and we obtained 0.6879 with the skinny RF.\n", + "- the voting classifier had a Cohen Kappa agreement's score of 0.62 ± 0.043 on the validation set, whilst we obtained [TO DE DEFINED] with the fat Random Forest, and we obtained 0.62 ± 0.043 with the skinny RF (where the validation set is a CV of 5 partitions and considering subjects)\n", + "- the size of the small RF is 322.8 Mbytes\n", + "- the size of the fat RF is 1.91 Gbytes\n", + "- the size of the voting classifier is 376.8 Mbytes\n", + "\n", + "We have decided to temporaly choose to use the skinny random forest, as its performance is quite similar to the voting classifier's and the fat random forest's." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate RF trained pipeline\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1512 65 3 3 41]\n", + " [ 220 147 332 0 284]\n", + " [ 39 45 3212 194 113]\n", + " [ 4 0 32 575 0]\n", + " [ 49 81 284 0 888]]\n", + " precision recall f1-score support\n", + "\n", + " W 0.83 0.93 0.88 1624\n", + " N1 0.43 0.15 0.22 983\n", + " N2 0.83 0.89 0.86 3603\n", + " N3 0.74 0.94 0.83 611\n", + " REM 0.67 0.68 0.68 1302\n", + "\n", + " accuracy 0.78 8123\n", + " macro avg 0.70 0.72 0.69 8123\n", + "weighted avg 0.75 0.78 0.75 8123\n", + "\n", + "Agreement score (Cohen Kappa): 0.6879671218212182\n", + "CPU times: user 3min 41s, sys: 2.5 s, total: 3min 43s\n", + "Wall time: 1min 49s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "def get_random_forest_model():\n", + " NB_CATEGORICAL_FEATURES = 2\n", + " NB_FEATURES = 48\n", + " \n", + " return Pipeline([\n", + " ('scaling', ColumnTransformer([\n", + " ('pass-through-categorical', 'passthrough', list(range(NB_CATEGORICAL_FEATURES))),\n", + " ('scaling-continuous', StandardScaler(copy=False), list(range(NB_CATEGORICAL_FEATURES,NB_FEATURES)))\n", + " ])),\n", + " ('classifier', RandomForestClassifier(\n", + " n_estimators=100,\n", + " max_depth=24,\n", + " random_state=42, # enables deterministic behaviour\n", + " n_jobs=-1\n", + " ))\n", + " ])\n", + "\n", + "testing_pipeline = get_random_forest_model()\n", + "testing_pipeline.fit(X_train_valid[:, 2:], y_train_valid)\n", + "y_test_pred = testing_pipeline.predict(X_test[:,2:])\n", + "\n", + "print(confusion_matrix(y_test, y_test_pred))\n", + "print(classification_report(y_test, y_test_pred, target_names=SLEEP_STAGES_VALUES.keys()))\n", + "print(\"Agreement score (Cohen Kappa): \", cohen_kappa_score(y_test, y_test_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving with ONNX\n", + "____" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "onnx_pipeline = convert_sklearn(\n", + " testing_pipeline,\n", + " initial_types=[(\n", + " 'float_input',\n", + " FloatTensorType([None, X_train_valid[:,2:].shape[1]])\n", + " )]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_onnx_model(onnx_pipeline, 'trained_model/rf_pipeline.onnx')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing ONNX pipeline vs normal pipeline results\n", + "____" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "sess = InferenceSession('trained_model/rf_pipeline.onnx')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "y_test_pred_onnx = sess.run(None, {'float_input': X_test[:,2:].astype(np.float32)})[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.000738643358365136" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum(~(y_test_pred_onnx == y_test_pred))/len(y_test_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ONNX Pipepline drawing" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer\n", + "pydot_graph = GetPydotGraph(onnx_pipeline.graph, name=onnx_pipeline.graph.name, rankdir=\"TP\",\n", + " node_producer=GetOpNodeProducer(\"docstring\"))\n", + "pydot_graph.write_dot(\"graph.dot\")\n", + "\n", + "import os\n", + "os.system('dot -O -V -Tpng graph.dot')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/ai/models/RF_HMM_smaller.ipynb b/ai/models/RF_HMM_smaller.ipynb new file mode 100644 index 00000000..7e6f6838 --- /dev/null +++ b/ai/models/RF_HMM_smaller.ipynb @@ -0,0 +1,736 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sleep stage classification: Random Forest & Hidden Markov Model\n", + "____\n", + "\n", + "This model aims to classify sleep stages based on two EEG channel. We will use the features extracted in the `pipeline.ipynb` notebook as the input to a Random Forest. The output of this model will then be used as the input of a HMM. We will implement our HMM the same as in this paper (Malafeev et al., « Automatic Human Sleep Stage Scoring Using Deep Neural Networks »)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "import sys\n", + "\n", + "# Ensure parent folder is in PYTHONPATH\n", + "module_path = os.path.abspath(os.path.join('..'))\n", + "if module_path not in sys.path:\n", + " sys.path.append(module_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import sys\n", + "from itertools import groupby\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import joblib\n", + "\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import (GridSearchCV,\n", + " RandomizedSearchCV,\n", + " GroupKFold,\n", + " cross_validate)\n", + "from sklearn.metrics import (accuracy_score,\n", + " confusion_matrix,\n", + " classification_report,\n", + " f1_score,\n", + " cohen_kappa_score,\n", + " make_scorer)\n", + "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", + "from sklearn.decomposition import PCA\n", + "\n", + "from scipy.signal import medfilt\n", + "\n", + "from hmmlearn.hmm import MultinomialHMM\n", + "from constants import (SLEEP_STAGES_VALUES,\n", + " N_STAGES,\n", + " EPOCH_DURATION)\n", + "from model_utils import (print_hypnogram,\n", + " train_test_split_one_subject,\n", + " train_test_split_according_to_age)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the features\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# position of the subject information and night information in the X matrix\n", + "SUBJECT_IDX = 0 \n", + "NIGHT_IDX = 1\n", + "USE_CONTINUOUS_AGE = False\n", + "DOWNSIZE_SET = False\n", + "TEST_SET_SUBJECTS = [0.0, 24.0, 49.0, 71.0]\n", + "\n", + "if USE_CONTINUOUS_AGE:\n", + " X_file_name = \"../data/x_features-age-continuous.npy\"\n", + " y_file_name = \"../data/y_observations-age-continuous.npy\"\n", + "else:\n", + " X_file_name = \"../data/x_features.npy\"\n", + " y_file_name = \"../data/y_observations.npy\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "X_init = np.load(X_file_name, allow_pickle=True)\n", + "y_init = np.load(y_file_name, allow_pickle=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(168954, 50)\n", + "(168954,)\n" + ] + } + ], + "source": [ + "X_init = np.vstack(X_init)\n", + "y_init = np.hstack(y_init)\n", + "print(X_init.shape)\n", + "print(y_init.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of subjects: 78\n", + "Number of nights: 153\n" + ] + } + ], + "source": [ + "print(\"Number of subjects: \", np.unique(X_init[:,SUBJECT_IDX]).shape[0]) # Some subject indexes are skipped, thus total number is below 83 (as we can see in https://physionet.org/content/sleep-edfx/1.0.0/)\n", + "print(\"Number of nights: \", len(np.unique([f\"{int(x[0])}-{int(x[1])}\" for x in X_init[:,SUBJECT_IDX:NIGHT_IDX+1]])))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Downsizing sets\n", + "___\n", + "\n", + "We will use the same set for all experiments. It includes the first 20 subjects, and excludes the 13th, because it only has one night.\n", + "\n", + "The last subject will be put in the test set. " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "if DOWNSIZE_SET:\n", + " # Filtering to only keep first 20 subjects\n", + " X_20 = X_init[np.isin(X_init[:,SUBJECT_IDX], range(20))]\n", + " y_20 = y_init[np.isin(X_init[:,SUBJECT_IDX], range(20))]\n", + "\n", + " # Exclude the subject with only one night recording (13th)\n", + " MISSING_NIGHT_SUBJECT = 13\n", + "\n", + " X = X_20[X_20[:,SUBJECT_IDX] != MISSING_NIGHT_SUBJECT]\n", + " y = y_20[X_20[:,SUBJECT_IDX] != MISSING_NIGHT_SUBJECT]\n", + "\n", + " print(X.shape)\n", + " print(y.shape)\n", + "else:\n", + " X = X_init\n", + " y = y_init" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of subjects: 78\n", + "Subjects available: [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.\n", + " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.\n", + " 36. 37. 38. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.\n", + " 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 70. 71. 72. 73. 74.\n", + " 75. 76. 77. 80. 81. 82.]\n", + "Number of nights: 153\n" + ] + } + ], + "source": [ + "print(\"Number of subjects: \", np.unique(X[:,SUBJECT_IDX]).shape[0]) # Some subject indexes are skipped, thus total number is below 83 (as we can see in https://physionet.org/content/sleep-edfx/1.0.0/)\n", + "print(\"Subjects available: \", np.unique(X[:,SUBJECT_IDX]))\n", + "print(\"Number of nights: \", len(np.unique([f\"{int(x[0])}-{int(x[1])}\" for x in X[:,SUBJECT_IDX:NIGHT_IDX+1]])))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train, validation and test sets\n", + "___\n", + "\n", + "If we downsize the dataset, the test set will only contain the two nights recording of the last subject (no 19) will be the test set. The rest will be the train and validation sets.\n", + "\n", + "If we did not downsize the dataset, we will randomly pick a subject from each age group to be in the test set. Both nights (if there are two) are placed in the test set so that the classifier does not train on any recordings from a subject placed in the test set.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected subjects for the test set are: [0.0, 24.0, 49.0, 71.0]\n", + "(8123, 50) (160831, 50) (8123,) (160831,)\n" + ] + } + ], + "source": [ + "if DOWNSIZE_SET:\n", + " X_test, X_train_valid, y_test, y_train_valid = train_test_split_one_subject(X, y)\n", + "else:\n", + " X_test, X_train_valid, y_test, y_train_valid = train_test_split_according_to_age(X,\n", + " y,\n", + " subjects_test=TEST_SET_SUBJECTS,\n", + " use_continuous_age=USE_CONTINUOUS_AGE)\n", + " \n", + "print(X_test.shape, X_train_valid.shape, y_test.shape, y_train_valid.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Random forest validation\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "NB_KFOLDS = 5\n", + "NB_CATEGORICAL_FEATURES = 2\n", + "NB_FEATURES = 48\n", + "\n", + "CLASSIFIER_PIPELINE_KEY = 'classifier'\n", + "\n", + "def get_random_forest_model():\n", + " return Pipeline([\n", + " ('scaling', ColumnTransformer([\n", + " ('pass-through-categorical', 'passthrough', list(range(NB_CATEGORICAL_FEATURES))),\n", + " ('scaling-continuous', StandardScaler(copy=False), list(range(NB_CATEGORICAL_FEATURES,NB_FEATURES)))\n", + " ])),\n", + " (CLASSIFIER_PIPELINE_KEY, RandomForestClassifier(\n", + " n_estimators=100,\n", + " random_state=42, # enables deterministic behaviour\n", + " n_jobs=-1\n", + " ))\n", + " ])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the cross validation, we will use the `GroupKFold` technique. For each fold, we make sure to train and validate on different subjects, to avoid overfitting over subjects." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 9 µs, sys: 1 µs, total: 10 µs\n", + "Wall time: 14.1 µs\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "def cross_validate_pipeline(pipeline):\n", + " accuracies = []\n", + " macro_f1_scores = []\n", + " weighted_f1_scores = []\n", + " kappa_agreements = []\n", + " emission_matrix = np.zeros((N_STAGES,N_STAGES))\n", + "\n", + " for train_index, valid_index in GroupKFold(n_splits=5).split(X_train_valid, groups=X_train_valid[:,SUBJECT_IDX]):\n", + " # We drop the subject and night indexes\n", + " X_train, X_valid = X_train_valid[train_index, 2:], X_train_valid[valid_index, 2:]\n", + " y_train, y_valid = y_train_valid[train_index], y_train_valid[valid_index]\n", + "\n", + " pipeline.fit(X_train, y_train)\n", + " y_valid_pred = pipeline.predict(X_valid)\n", + "\n", + " print(\"----------------------------- FOLD RESULTS --------------------------------------\\n\")\n", + " current_kappa = cohen_kappa_score(y_valid, y_valid_pred)\n", + "\n", + " print(\"TRAIN:\", train_index, \"VALID:\", valid_index, \"\\n\\n\")\n", + " print(confusion_matrix(y_valid, y_valid_pred), \"\\n\")\n", + " print(classification_report(y_valid, y_valid_pred, target_names=SLEEP_STAGES_VALUES.keys()), \"\\n\")\n", + " print(\"Agreement score (Cohen Kappa): \", current_kappa, \"\\n\")\n", + "\n", + " accuracies.append(round(accuracy_score(y_valid, y_valid_pred),2))\n", + " macro_f1_scores.append(f1_score(y_valid, y_valid_pred, average=\"macro\"))\n", + " weighted_f1_scores.append(f1_score(y_valid, y_valid_pred, average=\"weighted\"))\n", + " kappa_agreements.append(current_kappa)\n", + "\n", + " for y_pred, y_true in zip(y_valid_pred, y_valid):\n", + " emission_matrix[y_true, y_pred] += 1\n", + "\n", + " emission_matrix = emission_matrix / emission_matrix.sum(axis=1, keepdims=True)\n", + " \n", + " print(f\"Mean accuracy : {np.mean(accuracies):0.2f} ± {np.std(accuracies):0.3f}\")\n", + " print(f\"Mean macro F1-score : {np.mean(macro_f1_scores):0.2f} ± {np.std(macro_f1_scores):0.3f}\")\n", + " print(f\"Mean weighted F1-score : {np.mean(weighted_f1_scores):0.2f} ± {np.std(weighted_f1_scores):0.3f}\")\n", + " print(f\"Mean Kappa's agreement : {np.mean(kappa_agreements):0.2f} ± {np.std(kappa_agreements):0.3f}\")\n", + "\n", + " return emission_matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 2137 2138 2139 ... 158843 158844 158845] VALID: [ 0 1 2 ... 160828 160829 160830] \n", + "\n", + "\n", + "[[ 7206 194 111 2 139]\n", + " [ 1235 534 1404 1 543]\n", + " [ 993 439 10654 360 492]\n", + " [ 155 7 632 2132 5]\n", + " [ 842 907 1233 5 2579]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.69 0.94 0.80 7652\n", + " N1 0.26 0.14 0.18 3717\n", + " N2 0.76 0.82 0.79 12938\n", + " N3 0.85 0.73 0.79 2931\n", + " REM 0.69 0.46 0.55 5566\n", + "\n", + " accuracy 0.70 32804\n", + " macro avg 0.65 0.62 0.62 32804\n", + "weighted avg 0.68 0.70 0.68 32804\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.5914311657565539 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 5807 5808 5809 ... 158843 158844 158845] \n", + "\n", + "\n", + "[[ 6893 550 108 11 267]\n", + " [ 888 867 1136 3 1036]\n", + " [ 156 327 11494 392 992]\n", + " [ 23 0 452 1814 0]\n", + " [ 222 632 796 4 3185]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.84 0.88 0.86 7829\n", + " N1 0.36 0.22 0.27 3930\n", + " N2 0.82 0.86 0.84 13361\n", + " N3 0.82 0.79 0.80 2289\n", + " REM 0.58 0.66 0.62 4839\n", + "\n", + " accuracy 0.75 32248\n", + " macro avg 0.69 0.68 0.68 32248\n", + "weighted avg 0.73 0.75 0.74 32248\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.6553464447399939 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 2137 2138 2139 ... 151913 151914 151915] \n", + "\n", + "\n", + "[[7954 616 219 19 606]\n", + " [ 855 984 1223 10 1704]\n", + " [ 567 701 9904 181 1698]\n", + " [ 41 0 216 767 0]\n", + " [ 384 511 661 12 2462]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.81 0.84 0.83 9414\n", + " N1 0.35 0.21 0.26 4776\n", + " N2 0.81 0.76 0.78 13051\n", + " N3 0.78 0.75 0.76 1024\n", + " REM 0.38 0.61 0.47 4030\n", + "\n", + " accuracy 0.68 32295\n", + " macro avg 0.63 0.63 0.62 32295\n", + "weighted avg 0.69 0.68 0.68 32295\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.5601422740587234 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 4057 4058 4059 ... 121623 121624 121625] \n", + "\n", + "\n", + "[[ 6661 321 99 4 189]\n", + " [ 791 549 1154 14 873]\n", + " [ 216 192 11469 359 656]\n", + " [ 41 0 687 2567 2]\n", + " [ 386 498 1039 6 3351]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.82 0.92 0.87 7274\n", + " N1 0.35 0.16 0.22 3381\n", + " N2 0.79 0.89 0.84 12892\n", + " N3 0.87 0.78 0.82 3297\n", + " REM 0.66 0.63 0.65 5280\n", + "\n", + " accuracy 0.77 32124\n", + " macro avg 0.70 0.68 0.68 32124\n", + "weighted avg 0.74 0.77 0.75 32124\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.6754525828472742 \n", + "\n", + "----------------------------- FOLD RESULTS --------------------------------------\n", + "\n", + "TRAIN: [ 0 1 2 ... 160828 160829 160830] VALID: [ 13884 13885 13886 ... 156772 156773 156774] \n", + "\n", + "\n", + "[[ 6545 612 379 27 325]\n", + " [ 674 750 1561 11 939]\n", + " [ 238 313 10710 242 626]\n", + " [ 39 0 729 1887 1]\n", + " [ 355 829 1147 7 2414]] \n", + "\n", + " precision recall f1-score support\n", + "\n", + " W 0.83 0.83 0.83 7888\n", + " N1 0.30 0.19 0.23 3935\n", + " N2 0.74 0.88 0.80 12129\n", + " N3 0.87 0.71 0.78 2656\n", + " REM 0.56 0.51 0.53 4752\n", + "\n", + " accuracy 0.71 31360\n", + " macro avg 0.66 0.62 0.64 31360\n", + "weighted avg 0.69 0.71 0.70 31360\n", + " \n", + "\n", + "Agreement score (Cohen Kappa): 0.5996710408224084 \n", + "\n", + "Mean accuracy : 0.72 ± 0.033\n", + "Mean macro F1-score : 0.65 ± 0.027\n", + "Mean weighted F1-score : 0.71 ± 0.029\n", + "Mean Kappa's agreement : 0.62 ± 0.043\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[8.80220686e-01, 5.72434281e-02, 2.28674139e-02, 1.57275882e-03,\n", + " 3.80957136e-02],\n", + " [2.25087390e-01, 1.86635595e-01, 3.28182785e-01, 1.97578398e-03,\n", + " 2.58118446e-01],\n", + " [3.37108325e-02, 3.06349132e-02, 8.42475649e-01, 2.38306070e-02,\n", + " 6.93479983e-02],\n", + " [2.45142248e-02, 5.73911618e-04, 2.22677708e-01, 7.51578257e-01,\n", + " 6.55898992e-04],\n", + " [8.94674459e-02, 1.38022643e-01, 1.99288838e-01, 1.38962684e-03,\n", + " 5.71831446e-01]])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "validation_pipeline = get_random_forest_model()\n", + "validation_pipeline.set_params(\n", + " classifier__max_depth=24,\n", + " classifier__n_estimators=100,\n", + ")\n", + "\n", + "cross_validate_pipeline(validation_pipeline)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Random forest training and testing\n", + "___" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 41s, sys: 2.12 s, total: 3min 43s\n", + "Wall time: 1min 13s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "testing_pipeline = get_random_forest_model()\n", + "testing_pipeline.set_params(\n", + " classifier__max_depth=24,\n", + " classifier__n_estimators=100,\n", + ")\n", + "\n", + "testing_pipeline.fit(X_train_valid[:, 2:], y_train_valid);" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Categorical features: [0 1]\n", + "Time domain features: [ 2 3 4 5 6 7 8 25 26 27 28 29 30 31]\n", + "Frequency domain features: [ 9 10 11 12 13 14 15 16 17 18 19 32 33 34 35 36 37 38 39 40 41 42]\n", + "Subband time domain features: [20 21 22 23 24 43 44 45 46 47]\n", + "\n", + "Top 5 features: [(41, 0.0627), (29, 0.0487), (18, 0.0421), (20, 0.0411), (47, 0.0403)]\n", + "Bottom 5 features: [(11, 0.0108), (27, 0.0093), (4, 0.0091), (42, 0.0066), (0, 0.0031)]\n", + "\n", + "Fpz-Cz feature importances: 0.4553\n", + "Pz-Oz feature importances: 0.5284\n", + "\n", + "Category feature importances: 0.0162\n", + "Time domain feature importances: 0.2843\n", + "Frequency domain feature importances: 0.4711\n", + "Subband time domain feature importances: 0.2283\n" + ] + } + ], + "source": [ + "feature_importance_indexes = [\n", + " (idx, round(importance,4))\n", + " for idx, importance in enumerate(testing_pipeline.steps[1][1].feature_importances_)\n", + "]\n", + "feature_importance_indexes.sort(reverse=True, key=lambda x: x[1])\n", + "\n", + "category_feature_range = np.array([2, 3]) - 2\n", + "time_domaine_feature_range = np.array([4, 5, 6, 7, 8, 9, 10, 27, 28, 29, 30, 31, 32, 33]) - 2\n", + "freq_domain_feature_range = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]) - 2\n", + "subband_domain_feature_range = np.array([22, 23, 24, 25, 26, 45, 46, 47, 48, 49]) - 2\n", + "fpz_cz_feature_range = np.array(range(2, 25))\n", + "pz_oz_feature_range = np.array(range(25, 48))\n", + "\n", + "def get_feature_range_importance(indexes):\n", + " return np.sum([feature[1] for feature in feature_importance_indexes if feature[0] in indexes])\n", + "\n", + "print(f\"Categorical features: {category_feature_range}\")\n", + "print(f\"Time domain features: {time_domaine_feature_range}\")\n", + "print(f\"Frequency domain features: {freq_domain_feature_range}\")\n", + "print(f\"Subband time domain features: {subband_domain_feature_range}\\n\")\n", + "\n", + "print(f\"Top 5 features: {[feature for feature in feature_importance_indexes[:5]]}\")\n", + "print(f\"Bottom 5 features: {[feature for feature in feature_importance_indexes[-5:]]}\\n\")\n", + "\n", + "print(f\"Fpz-Cz feature importances: {get_feature_range_importance(fpz_cz_feature_range):.4f}\")\n", + "print(f\"Pz-Oz feature importances: {get_feature_range_importance(pz_oz_feature_range):.4f}\\n\")\n", + "\n", + "print(f\"Category feature importances: {get_feature_range_importance([0,1]):.4f}\")\n", + "print(f\"Time domain feature importances: {get_feature_range_importance(time_domaine_feature_range):.4f}\")\n", + "print(f\"Frequency domain feature importances: {get_feature_range_importance(freq_domain_feature_range):.4f}\")\n", + "print(f\"Subband time domain feature importances: {get_feature_range_importance(subband_domain_feature_range):.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1512 65 3 3 41]\n", + " [ 220 147 332 0 284]\n", + " [ 39 45 3212 194 113]\n", + " [ 4 0 32 575 0]\n", + " [ 49 81 284 0 888]]\n", + " precision recall f1-score support\n", + "\n", + " W 0.83 0.93 0.88 1624\n", + " N1 0.43 0.15 0.22 983\n", + " N2 0.83 0.89 0.86 3603\n", + " N3 0.74 0.94 0.83 611\n", + " REM 0.67 0.68 0.68 1302\n", + "\n", + " accuracy 0.78 8123\n", + " macro avg 0.70 0.72 0.69 8123\n", + "weighted avg 0.75 0.78 0.75 8123\n", + "\n", + "Agreement score (Cohen Kappa): 0.6879671218212182\n" + ] + } + ], + "source": [ + "y_test_pred = testing_pipeline.predict(X_test[:,2:])\n", + "\n", + "print(confusion_matrix(y_test, y_test_pred))\n", + "\n", + "print(classification_report(y_test, y_test_pred, target_names=SLEEP_STAGES_VALUES.keys()))\n", + "\n", + "print(\"Agreement score (Cohen Kappa): \", cohen_kappa_score(y_test, y_test_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving trained model\n", + "___\n", + "\n", + "We save the trained model with the postprocessing step, HMM. We will save only the matrix that define it. We do not need to persist the median filter postprocessing step, because it is stateless." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "SAVED_DIR = \"trained_model\"\n", + "\n", + "if not os.path.exists(SAVED_DIR):\n", + " os.mkdir(SAVED_DIR); " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pipeline object size (Mbytes): 322.775421\n" + ] + } + ], + "source": [ + "if USE_CONTINUOUS_AGE: \n", + " joblib.dump(testing_pipeline, f\"{SAVED_DIR}/classifier_RF_continous_age.joblib\")\n", + "else:\n", + " fd = joblib.dump(testing_pipeline, f\"{SAVED_DIR}/classifier_RF_small.joblib\")\n", + " print(\n", + " \"Pipeline object size (Mbytes): \",\n", + " os.path.getsize(f\"{SAVED_DIR}/classifier_RF_small.joblib\")/1e6\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/backend/.gitignore b/backend/.gitignore index 6dfe2dc3..552623c5 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -2,3 +2,5 @@ .vscode/ __pycache__/ *.onnx +.DS_Store +hmm_model/ diff --git a/backend/app.py b/backend/app.py index 64c37bb3..11e736d9 100644 --- a/backend/app.py +++ b/backend/app.py @@ -4,11 +4,12 @@ from http import HTTPStatus from classification.file_loading import get_raw_array -from classification.predict import predict from classification.exceptions import ClassificationError from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS +from classification.model import SleepStagesClassifier app = Flask(__name__) +model = SleepStagesClassifier() def allowed_file(filename): @@ -57,7 +58,7 @@ def analyze_sleep(): try: raw_array = get_raw_array(file) - predict(raw_array, info={ + model.predict(raw_array, info={ 'sex': sex, 'age': age, 'in_bed_seconds': bedtime - stream_start, diff --git a/backend/classification/config/constants.py b/backend/classification/config/constants.py index 80692c6d..10414a75 100644 --- a/backend/classification/config/constants.py +++ b/backend/classification/config/constants.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, auto class Sex(Enum): @@ -8,6 +8,15 @@ class Sex(Enum): M = 2 +class HiddenMarkovModelProbability(Enum): + emission = auto() + start = auto() + transition = auto() + + def get_filename(self): + return f'hmm_{self.name}_probabilities.npy' + + ALLOWED_FILE_EXTENSIONS = ('.txt', '.csv') EEG_CHANNELS = [ @@ -29,3 +38,5 @@ class Sex(Enum): [85, 125] ] ACCEPTED_AGE_RANGE = [AGE_FEATURE_BINS[0][0], AGE_FEATURE_BINS[-1][-1]] + +N_STAGES = 5 diff --git a/backend/classification/features/__init__.py b/backend/classification/features/__init__.py index dc66d99b..c21ae168 100644 --- a/backend/classification/features/__init__.py +++ b/backend/classification/features/__init__.py @@ -26,4 +26,4 @@ def get_features(signal, info): X_eeg = get_eeg_features(signal, info['in_bed_seconds'], info['out_of_bed_seconds']) X_categorical = get_non_eeg_features(info['age'], info['sex'], X_eeg.shape[0]) - return np.append(X_categorical, X_eeg, axis=1) + return np.append(X_categorical, X_eeg, axis=1).astype(np.float32) diff --git a/backend/classification/features/constants.py b/backend/classification/features/constants.py index 4dd21316..171b3fe6 100644 --- a/backend/classification/features/constants.py +++ b/backend/classification/features/constants.py @@ -1,7 +1,4 @@ -from classification.config.constants import ( - DATASET_SAMPLE_RATE, - EPOCH_DURATION, -) +from classification.config.constants import DATASET_SAMPLE_RATE NYQUIST_FREQ = DATASET_SAMPLE_RATE / 2 diff --git a/backend/classification/features/preprocessing.py b/backend/classification/features/preprocessing.py index 321f02a8..0561baa7 100644 --- a/backend/classification/features/preprocessing.py +++ b/backend/classification/features/preprocessing.py @@ -1,8 +1,10 @@ import mne from scipy.signal import cheby1 -from classification.features.constants import ( +from classification.config.constants import ( EPOCH_DURATION, +) +from classification.features.constants import ( DATASET_SAMPLE_RATE, DATASET_HIGH_PASS_FREQ, HIGH_PASS_FILTER_ORDER, diff --git a/backend/classification/file_loading.py b/backend/classification/file_loading.py index 9dc44683..c9307180 100644 --- a/backend/classification/file_loading.py +++ b/backend/classification/file_loading.py @@ -27,6 +27,7 @@ ADS1299_Vref = 4.5 ADS1299_gain = 24. SCALE_uV_PER_COUNT = ADS1299_Vref / ((2**23) - 1) / ADS1299_gain * 1000000 +SCALE_V_PER_COUNT = SCALE_uV_PER_COUNT / 1e6 FILE_COLUMN_OFFSET = 1 CYTON_TOTAL_NB_CHANNELS = 8 @@ -48,7 +49,7 @@ def get_raw_array(file): if len(line_splitted) >= CYTON_TOTAL_NB_CHANNELS: eeg_raw.append(_get_decimals_from_hexadecimal_strings(line_splitted)) - eeg_raw = SCALE_uV_PER_COUNT * np.array(eeg_raw, dtype='object') + eeg_raw = SCALE_V_PER_COUNT * np.array(eeg_raw, dtype='object') raw_object = RawArray( np.transpose(eeg_raw), diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 4731235a..f1426be5 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -1,37 +1,74 @@ -from os import path +from datetime import datetime +from os import path, makedirs +from pathlib import Path +import re import sys +import xml.etree.ElementTree as ET + +import certifi +import numpy as np from requests import get import onnxruntime -import re -import xml.etree.ElementTree as ET -SCRIPT_PATH = path.dirname(path.realpath(sys.argv[0])) +from classification.config.constants import HiddenMarkovModelProbability + +SCRIPT_PATH = Path(path.realpath(sys.argv[0])).parent + +BUCKET_NAME = 'polydodo' +BUCKET_URL = f'https://{BUCKET_NAME}.s3.amazonaws.com' + MODEL_FILENAME = 'model.onnx' -MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}' -MODEL_BUCKET = 'polydodo' -BUCKET_URL = f'https://{MODEL_BUCKET}.s3.amazonaws.com' +MODEL_PATH = SCRIPT_PATH / MODEL_FILENAME MODEL_URL = f'{BUCKET_URL}/{MODEL_FILENAME}' +HMM_FOLDER = 'hmm_model' + def _download_file(url, output): with open(output, 'wb') as f: - f.write(get(url).content) + f.write(get(url, verify=certifi.where()).content) -def _get_latest_object_size(bucket_url, filename): - raw_result = get(bucket_url).text +def _get_object_latest_update(filename): + raw_result = get(BUCKET_URL, verify=certifi.where()).text # https://stackoverflow.com/a/15641319 raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) result_root_node = ET.fromstring(raw_result) objects_nodes = result_root_node.findall('Contents') object_node = [object_node for object_node in objects_nodes if object_node.find("Key").text == filename][0] - object_size = int(object_node.find("Size").text) - return object_size + object_latest_update = datetime.strptime(object_node.find("LastModified").text, "%Y-%m-%dT%H:%M:%S.%f%z") + + return object_latest_update + + +def _has_latest_object(filename, local_path): + latest_model_update = _get_object_latest_update(filename) + current_model_update = datetime.fromtimestamp(path.getmtime(local_path)).astimezone() + + return current_model_update >= latest_model_update def load_model(): - if not path.exists(MODEL_PATH) or path.getsize(MODEL_PATH) != _get_latest_object_size(BUCKET_URL, MODEL_FILENAME): + if not path.exists(MODEL_PATH) or not _has_latest_object(MODEL_FILENAME, MODEL_PATH): print("Downloading latest model...") _download_file(MODEL_URL, MODEL_PATH) print("Loading model...") - return onnxruntime.InferenceSession(MODEL_PATH) + return onnxruntime.InferenceSession(str(MODEL_PATH)) + + +def load_hmm(): + hmm_matrices = dict() + + if not path.exists(SCRIPT_PATH / HMM_FOLDER): + makedirs(SCRIPT_PATH / HMM_FOLDER) + + for hmm_probability in HiddenMarkovModelProbability: + hmm_file = hmm_probability.get_filename() + model_path = SCRIPT_PATH / HMM_FOLDER / hmm_file + + if not path.exists(model_path) or not _has_latest_object(hmm_file, model_path): + _download_file(url=f"{BUCKET_URL}/{hmm_file}", output=model_path) + + hmm_matrices[hmm_probability.name] = np.load(str(model_path)) + + return hmm_matrices diff --git a/backend/classification/model.py b/backend/classification/model.py new file mode 100644 index 00000000..db1f2a00 --- /dev/null +++ b/backend/classification/model.py @@ -0,0 +1,49 @@ +"""defines models which predict sleep stages based off EEG signals""" + +from classification.features import get_features +from classification.validation import validate +from classification.postprocessor import get_hmm_model +from classification.load_model import load_model, load_hmm + + +class SleepStagesClassifier(): + def __init__(self): + self.model = load_model() + self.model_input_name = self.model.get_inputs()[0].name + + self.postprocessor_state = load_hmm() + self.postprocessor = get_hmm_model(self.postprocessor_state) + + def predict(self, raw_eeg, info): + """ + Input: + - raw_eeg: instance of mne.io.RawArray + Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) + - info: dict + Should contain the following keys: + - sex: instance of Sex enum + - age: indicates the subject's age + - in_bed_seconds: timespan, in seconds, from which + the subject started the recording and went to bed + - out_of_bed_seconds: timespan, in seconds, from which + the subject started the recording and got out of bed + Returns: array of predicted sleep stages + """ + + validate(raw_eeg, info) + features = get_features(raw_eeg, info) + + print(features, features.shape) + + predictions = self._get_predictions(features) + predictions = self._get_postprocessed_predictions(predictions) + + print(predictions) + + return predictions + + def _get_predictions(self, features): + return self.model.run(None, {self.model_input_name: features})[0] + + def _get_postprocessed_predictions(self, predictions): + return self.postprocessor.predict(predictions.reshape(-1, 1)) diff --git a/backend/classification/postprocessor.py b/backend/classification/postprocessor.py new file mode 100644 index 00000000..260fdae9 --- /dev/null +++ b/backend/classification/postprocessor.py @@ -0,0 +1,24 @@ +from hmmlearn.hmm import MultinomialHMM + +from classification.config.constants import ( + HiddenMarkovModelProbability, + N_STAGES, +) + + +def get_hmm_model(state): + """Creates an instance of MultinomialHMM, which follows sklearn interface + Input: + - state: dictionnary + where the keys are HiddenMarkovModelProbability choices + where the values are the probabilities matrices or arrays which + describes the according hidden markov model state + Returns: an instance of a trained MultinomialHMM + """ + hmm_model = MultinomialHMM(n_components=N_STAGES) + + hmm_model.emissionprob_ = state[HiddenMarkovModelProbability.emission.name] + hmm_model.startprob_ = state[HiddenMarkovModelProbability.start.name] + hmm_model.transmat_ = state[HiddenMarkovModelProbability.transition.name] + + return hmm_model diff --git a/backend/classification/predict.py b/backend/classification/predict.py deleted file mode 100644 index 07c6dbb7..00000000 --- a/backend/classification/predict.py +++ /dev/null @@ -1,23 +0,0 @@ -"""defines functions to predict sleep stages based off EEG signals""" -from classification.features import get_features -from classification.validation import validate - - -def predict(raw_eeg, info): - """ - Input: - - raw_eeg: instance of mne.io.RawArray - Should contain 2 channels (1: FPZ-CZ, 2: PZ-OZ) - - info: dict - Should contain the following keys: - - sex: instance of Sex enum - - age: indicates the subject's age - - in_bed_seconds: timespan, in seconds, from which - the subject started the recording and went to bed - - out_of_bed_seconds: timespan, in seconds, from which - the subject started the recording and got out of bed - """ - validate(raw_eeg, info) - X_openbci = get_features(raw_eeg, info) - - print(X_openbci[0], X_openbci.shape) diff --git a/backend/readme.md b/backend/readme.md index 256b487b..b7d97173 100644 --- a/backend/readme.md +++ b/backend/readme.md @@ -20,6 +20,13 @@ Install the required dependencies. pip install -r requirements.txt requirements-dev.txt ``` +If you are running on Linux or MacOS, you also have to install OpenMP with your package manager. It is a dependency of ONNX runtime, used to load our model and make predictions. + +```bash +apt-get install libgomp1 # on linux +brew install libomp # on macos +``` + ## Run it locally Activate your virtual environment. diff --git a/backend/requirements.txt b/backend/requirements.txt index e062c41a..684f25fc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,3 +7,6 @@ onnxruntime==1.5.2 numpy==1.19.2 scipy==1.5.2 scikit-learn==0.23.2 +requests==2.7.0 +hmmlearn==0.2.4 +certifi==2020.6.20