Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 17 additions & 29 deletions backend/classification/load_model.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,37 @@
from os import path
import hashlib
import sys
from requests import get
import onnxruntime
import re
import xml.etree.ElementTree as ET

# Directory of the script named by sys.argv[0], with symlinks resolved.
SCRIPT_PATH = path.dirname(path.realpath(sys.argv[0]))
# File name of the ONNX model, used both for the local copy and in the remote URLs.
MODEL_FILENAME = 'model.onnx'
# Local path of the model file, inside SCRIPT_PATH.
MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}'
# GitHub repository queried by _get_latest_model_githash.
MODEL_REPO = 'polycortex/polydodo-model'
# NOTE(review): this raw-GitHub URL is rebound by the second MODEL_URL
# assignment below, so it is never the value seen by the functions in this file.
MODEL_URL = f'https://raw.githubusercontent.com/{MODEL_REPO}/master/{MODEL_FILENAME}'
# Name of the S3 bucket and the base URL built from it.
MODEL_BUCKET = 'polydodo'
BUCKET_URL = f'https://{MODEL_BUCKET}.s3.amazonaws.com'
# Effective download URL of the model: the object MODEL_FILENAME in the bucket.
MODEL_URL = f'{BUCKET_URL}/{MODEL_FILENAME}'


def _download_file(url, output):
    """Download `url` and write the response body to the file `output`.

    The request is issued and its HTTP status checked *before* `output` is
    opened, so a failed request raises instead of truncating an existing file
    or saving an error page in its place. The body is streamed to disk in
    chunks rather than held in memory as a whole.

    Args:
        url: Address of the resource to fetch.
        output: Path of the file to create or overwrite (binary mode).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    response = get(url, stream=True)
    try:
        response.raise_for_status()
        with open(output, 'wb') as f:
            for chunk in response.iter_content(chunk_size=65536):
                f.write(chunk)
    finally:
        # With stream=True the connection stays open until explicitly released.
        response.close()


def _get_latest_model_githash():
    """Return the "sha" reported for MODEL_FILENAME by the latest master commit.

    Queries the GitHub commits API for MODEL_REPO and picks, from the
    response's "files" list, the first entry whose "filename" equals
    MODEL_FILENAME.

    Raises:
        IndexError: If no entry of "files" matches MODEL_FILENAME.
    """
    commit_url = f'https://api.github.com/repos/{MODEL_REPO}/commits/master'
    commit = get(commit_url).json()
    matching_entries = [
        entry for entry in commit['files']
        if entry['filename'] == MODEL_FILENAME
    ]
    return matching_entries[0]["sha"]


def _get_file_githash(filepath):
    """Return the git blob SHA-1 of the file at `filepath` as a hex string.

    The digest is computed over the header "blob <size>" followed by a NUL
    byte and then the file's bytes, which is how git hashes blob objects.

    Args:
        filepath: Path of the file to hash.

    Returns:
        The 40-character hexadecimal SHA-1 digest.
    """
    BUF_SIZE = 65536

    sha1 = hashlib.sha1()

    with open(filepath, 'rb') as f:
        # Take the size from the handle that is about to be read, so the
        # header and the hashed content always describe the same file.
        filesize = f.seek(0, 2)
        f.seek(0)

        # https://stackoverflow.com/a/1869911
        sha1.update(f'blob {filesize}\0'.encode('utf-8'))

        # https://stackoverflow.com/a/22058673
        # Read in fixed-size chunks until read() returns an empty bytes object.
        for data in iter(lambda: f.read(BUF_SIZE), b''):
            sha1.update(data)
    return sha1.hexdigest()
def _get_latest_object_size(bucket_url, filename):
    """Return the size listed for `filename` in the XML listing at `bucket_url`.

    Fetches `bucket_url`, parses the response as XML and reads the "Size"
    child of the first "Contents" element whose "Key" equals `filename`.

    Args:
        bucket_url: URL whose response body is the bucket's XML listing.
        filename: Object key to look up in the listing.

    Returns:
        The text of the object's "Size" element converted to int.

    Raises:
        IndexError: If no "Contents" element has a matching "Key".
    """
    listing_xml = get(bucket_url).text
    # https://stackoverflow.com/a/15641319
    # Strip the default namespace so elements can be found by their bare tag.
    listing_xml = re.sub(' xmlns="[^"]+"', '', listing_xml)
    listing_root = ET.fromstring(listing_xml)
    matching_entries = [
        entry for entry in listing_root.findall('Contents')
        if entry.find("Key").text == filename
    ]
    return int(matching_entries[0].find("Size").text)


def load_model():
    """Return an ONNX Runtime inference session for the model, downloading it first if needed.

    The model is downloaded from MODEL_URL when no file exists at MODEL_PATH
    or when the local file's size differs from the size listed for
    MODEL_FILENAME at BUCKET_URL. The freshness check uses the same bucket
    the download comes from; the earlier git-hash check against the GitHub
    repository is dropped because MODEL_URL no longer points there.

    Returns:
        An onnxruntime.InferenceSession created from MODEL_PATH.
    """
    # NOTE(review): equal size does not prove equal content; a same-size
    # replacement of the remote model would not trigger a download.
    needs_download = (
        not path.exists(MODEL_PATH)
        or path.getsize(MODEL_PATH) != _get_latest_object_size(BUCKET_URL, MODEL_FILENAME)
    )
    if needs_download:
        print("Downloading latest model...")
        _download_file(MODEL_URL, MODEL_PATH)
    print("Loading model...")
    return onnxruntime.InferenceSession(MODEL_PATH)