From 334640f6ae74624ccb1c2bcafb636d101be774ad Mon Sep 17 00:00:00 2001 From: Anes Belfodil Date: Fri, 30 Oct 2020 12:58:25 -0400 Subject: [PATCH] Replace GitHub LFS with S3 bucket --- backend/classification/load_model.py | 46 ++++++++++------------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/backend/classification/load_model.py b/backend/classification/load_model.py index 0b20ba1c..4731235a 100644 --- a/backend/classification/load_model.py +++ b/backend/classification/load_model.py @@ -1,14 +1,16 @@ from os import path -import hashlib import sys from requests import get import onnxruntime +import re +import xml.etree.ElementTree as ET SCRIPT_PATH = path.dirname(path.realpath(sys.argv[0])) MODEL_FILENAME = 'model.onnx' MODEL_PATH = f'{SCRIPT_PATH}/{MODEL_FILENAME}' -MODEL_REPO = 'polycortex/polydodo-model' -MODEL_URL = f'https://raw.githubusercontent.com/{MODEL_REPO}/master/{MODEL_FILENAME}' +MODEL_BUCKET = 'polydodo' +BUCKET_URL = f'https://{MODEL_BUCKET}.s3.amazonaws.com' +MODEL_URL = f'{BUCKET_URL}/{MODEL_FILENAME}' def _download_file(url, output): @@ -16,34 +18,20 @@ def _download_file(url, output): f.write(get(url).content) -def _get_latest_model_githash(): - request = f'https://api.github.com/repos/{MODEL_REPO}/commits/master' - repo = get(request).json() - model_info = [file_info for file_info in repo['files'] if file_info['filename'] == MODEL_FILENAME][0] - return model_info["sha"] - - -def _get_file_githash(filepath): - BUF_SIZE = 65536 - - sha1 = hashlib.sha1() - - # https://stackoverflow.com/a/1869911 - filesize = path.getsize(filepath) - git_hash_header = f'blob {filesize}\0'.encode('utf-8') - sha1.update(git_hash_header) - - # https://stackoverflow.com/a/22058673 - with open(filepath, 'rb') as f: - while True: - data = f.read(BUF_SIZE) - if not data: - break - sha1.update(data) - return sha1.hexdigest() +def _get_latest_object_size(bucket_url, filename): + raw_result = get(bucket_url).text + # https://stackoverflow.com/a/15641319 + raw_result = re.sub(' xmlns="[^"]+"', '', raw_result) + result_root_node = ET.fromstring(raw_result) + objects_nodes = result_root_node.findall('Contents') + object_node = [object_node for object_node in objects_nodes if object_node.find("Key").text == filename][0] + object_size = int(object_node.find("Size").text) + return object_size def load_model(): - if not path.exists(MODEL_PATH) or _get_file_githash(MODEL_PATH) != _get_latest_model_githash(): + if not path.exists(MODEL_PATH) or path.getsize(MODEL_PATH) != _get_latest_object_size(BUCKET_URL, MODEL_FILENAME): + print("Downloading latest model...") _download_file(MODEL_URL, MODEL_PATH) + print("Loading model...") return onnxruntime.InferenceSession(MODEL_PATH)