-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAPI_server.py
More file actions
145 lines (104 loc) · 4 KB
/
API_server.py
File metadata and controls
145 lines (104 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#! /usr/bin/env python3
# coding: utf-8
import os
import io
import pathlib
from flask import Flask, flash, request, redirect, jsonify
import numpy as np
from PIL import Image
# import tflite_runtime.interpreter as tflite
import onnxruntime as rt
# Log the onnxruntime device string at start-up.
print("ONX:", rt.get_device())

# ########## API ##########
# --- Model input geometry and model selection ---
# Uploaded images must be exactly base_W pixels wide and base_H pixels high.
base_W, base_H = 512, 256
# Human-readable "WIDTHxHEIGHT" string, e.g. "512x256".
base_resolution = "x".join((str(base_W), str(base_H)))
print("Load Semantic-segmentation Model")
# The exported model's file name ends with the resolution it was built for.
model_name = "FPN-efficientnetb7_with_data_augmentation_2_diceLoss_" + base_resolution
# --- Inference back-end ---
# Earlier back-ends, kept for reference:
#   Keras:
#     model = keras.models.load_model(
#         f"models/{model_name}.keras",
#         custom_objects={
#             "iou_score": sm.metrics.iou_score,
#             "f1-score": sm.metrics.f1_score,
#             "dice_loss": sm.losses.DiceLoss(),
#         },
#     )
#   TF-Lite:
#     interpreter = tflite.Interpreter(model_path=f"models/{model_name}.tflite")
#     interpreter.resize_tensor_input(0, [1, base_H, base_W, 3])
#     interpreter.allocate_tensors()
#     input_index = interpreter.get_input_details()[0]["index"]
#     output_index = interpreter.get_output_details()[0]["index"]
#
# Active back-end: ONNX Runtime. Execution providers are listed in decreasing
# order of precedence; use ['CPUExecutionProvider'] alone to force the CPU.
providers = [
    'TensorrtExecutionProvider',
    'CUDAExecutionProvider',
    'CPUExecutionProvider',
]
# The model file is resolved relative to the current working directory.
m = rt.InferenceSession(
    str(pathlib.Path('models') / f"{model_name}.onnx"),
    providers=providers,
)
# --- API Flask app ---
app = Flask(__name__)
# Key used to sign the session cookie (needed by flash()). A key committed to
# source lets anyone forge session cookies, so it is read from the
# FLASK_SECRET_KEY environment variable. When unset, a random per-process key
# is generated: sessions then do not survive a restart, which is acceptable
# here because flash messages are the only session data this file uses.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(32)
# Not referenced in this file: uploads are decoded in memory, never saved.
UPLOAD_FOLDER = "/uploads"
# Lower-case file extensions (without the dot) accepted by allowed_file().
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg"}
def allowed_file(filename, extensions=None):
    """Return True if *filename* ends in an accepted image extension.

    Only the text after the last dot is checked, case-insensitively, so
    "photo.tar.PNG" is accepted while "png" (no dot) and "photo." are not.

    Args:
        filename: Client-supplied upload name. ``None`` or ``""`` is rejected
            (returns False) instead of raising ``TypeError``.
        extensions: Lower-case extensions, without the dot, to accept.
            Defaults to the module-level ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: whether the extension is in the accepted set.
    """
    if not filename:
        return False
    if extensions is None:
        extensions = ALLOWED_EXTENSIONS
    _, dot, ext = filename.rpartition(".")
    return bool(dot) and ext.lower() in extensions
@app.route("/")
def index():
    """Answer the root URL with a fixed message showing the server is running."""
    status_message = "The 'CityScape Semantic-segmentation API' server is up."
    return status_message
def _decode_upload(file_storage):
    """Decode an uploaded file into a float32 batch holding one RGB image.

    Returns an array of shape (1, H, W, 3) with raw 0-255 pixel values; no
    normalisation is applied because the model's preprocessing layers are
    included in the exported model.

    Raises:
        OSError: the bytes are not a decodable image (Pillow raises
            ``UnidentifiedImageError``, an ``OSError`` subclass) or are truncated.
        ValueError: Pillow cannot convert the image mode to RGB.
    """
    with Image.open(io.BytesIO(file_storage.read())) as image:
        # Grayscale, palette and RGBA uploads would otherwise produce arrays
        # with the wrong rank or channel count for the model.
        rgb_image = image.convert("RGB")
    return np.asarray(rgb_image, dtype=np.float32)[np.newaxis, ...]


@app.route("/predict/", methods=["GET", "POST"])
def upload_file():
    """Segment an uploaded image, or serve a minimal HTML upload form.

    POST with a multipart ``file`` field holding a PNG/JPEG image of exactly
    ``base_W`` x ``base_H`` pixels returns, as JSON, the per-pixel argmax
    class ids over the model output's last axis (presumably a list of
    ``base_H`` rows of ``base_W`` ints — depends on the model's output layout).

    Error handling:
        * missing file part or empty file name: flash a message and redirect
          back to this URL;
        * undecodable image or wrong size: HTTP 400 with ``{"error": ...}``;
        * disallowed extension, or any GET: the HTML upload form is returned.
    """
    if request.method == "POST":
        # check if the post request has the file part
        if "file" not in request.files:
            flash("No file part")
            return redirect(request.url)
        file = request.files["file"]
        # If the user does not select a file, the browser submits an
        # empty file without a filename.
        if file.filename == "":
            flash("No selected file")
            return redirect(request.url)
        # Some programmatic clients send the literal name 'file'; accept it too.
        if file and (allowed_file(file.filename) or file.filename == 'file'):
            try:
                img = _decode_upload(file)
            except (OSError, ValueError):
                return jsonify(error="Uploaded file is not a readable image."), 400
            # Only the pixel dimensions can vary: _decode_upload fixes the
            # batch size to 1 and the channel count to 3.
            if img.shape[1:3] != (base_H, base_W):
                return jsonify(
                    error=(
                        f"Wrong image size {img.shape[2]}x{img.shape[1]}: "
                        f"{base_resolution} (width x height) required."
                    )
                ), 400
            print("--- Predict")
            print(img.shape)
            # 'model_6' / 'input' are the output / input names in the
            # exported ONNX graph.
            pred = m.run(['model_6'], {'input': img})[0]
            # Class id per pixel for the single image in the batch; assumes
            # pred is laid out (batch, H, W, n_classes) — TODO confirm.
            mask = np.argmax(pred, axis=3)[0]
            return jsonify(mask.tolist())
    return """
<!doctype html>
<html>
<head>
<title>Upload new File</title>
</head>
<body>
<h1>Upload new File</h1>
<form method=post enctype=multipart/form-data>
<input type=file name=file>
<input type=submit value=Upload>
</form>
</body>
</html>
"""
# ########## START API ##########
if __name__ == "__main__":
current_port = int(os.environ.get("PORT") or 5000)
app.run(debug=True, host="0.0.0.0", port=current_port)