Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.venv/
.vscode/
__pycache__/
*.onnx
49 changes: 49 additions & 0 deletions backend/classification/load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from os import path
import hashlib
import sys
from requests import get
import onnxruntime

SCRIPT_PATH = path.dirname(path.realpath(sys.argv[0]))
MODEL_FILENAME = 'model.onnx'
MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}'
MODEL_REPO = 'polycortex/polydodo-model'
MODEL_URL = f'https://raw.githubusercontent.com/{MODEL_REPO}/master/{MODEL_FILENAME}'


def _download_file(url, output):
    """Download the file at *url* and write it to *output*.

    Streams the response in fixed-size chunks so a large model file is
    never held entirely in memory, and fails fast on HTTP errors instead
    of silently writing an error page to disk as if it were the model.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout / requests.ConnectionError: on network failure.
    """
    # Timeout keeps a hung connection from blocking startup forever.
    response = get(url, stream=True, timeout=30)
    # Without this, a 404 body would be saved to disk as the model file.
    response.raise_for_status()
    with open(output, 'wb') as f:
        for chunk in response.iter_content(chunk_size=65536):
            f.write(chunk)


def _get_latest_model_githash():
    """Return the git blob SHA of the model file on the repo's master branch.

    Uses the GitHub contents API, which reports the blob SHA of the file as
    it currently exists on the branch. (The previous implementation read the
    ``files`` list of the HEAD commit — that list only contains files
    *changed by that commit*, so it raised ``IndexError`` whenever the most
    recent commit did not touch the model file.)

    Raises:
        requests.HTTPError: if the GitHub API responds with an error status.
        requests.Timeout / requests.ConnectionError: on network failure.
    """
    url = f'https://api.github.com/repos/{MODEL_REPO}/contents/{MODEL_FILENAME}?ref=master'
    response = get(url, timeout=30)
    response.raise_for_status()
    return response.json()['sha']


def _get_file_githash(filepath):
BUF_SIZE = 65536

sha1 = hashlib.sha1()

# https://stackoverflow.com/a/1869911
filesize = path.getsize(filepath)
git_hash_header = f'blob {filesize}\0'.encode('utf-8')
sha1.update(git_hash_header)

# https://stackoverflow.com/a/22058673
with open(filepath, 'rb') as f:
while True:
data = f.read(BUF_SIZE)
if not data:
break
sha1.update(data)
return sha1.hexdigest()


def load_model():
    """Return an ONNX Runtime inference session for the classification model.

    The model is downloaded from GitHub when it is missing locally or when
    its git blob hash differs from the latest version on master. If the
    update check fails (e.g. no network access) but a local copy exists,
    the cached copy is used instead of crashing at startup; the download
    error is only propagated when no local model is available at all.
    """
    needs_download = not path.exists(MODEL_PATH)

    if not needs_download:
        try:
            needs_download = _get_file_githash(MODEL_PATH) != _get_latest_model_githash()
        except Exception:
            # Best effort: the update check requires network access. Fall
            # back to the cached model rather than failing at startup.
            needs_download = False

    if needs_download:
        _download_file(MODEL_URL, MODEL_PATH)

    return onnxruntime.InferenceSession(MODEL_PATH)
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Flask-Cors==1.10.3
waitress==1.4.4

mne==0.21.0
onnxruntime==1.5.2
numpy==1.19.2
scipy==1.5.2
scikit-learn==0.23.2