Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
sudo apt install -y g++-11
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 90
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 90
sudo apt install libzip-dev

- name: install-cpprest
run: sudo apt install libcpprest-dev
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Build/
build/

# Large-files
zip-examples
model_data

# IDE
.vscode
.idea
Expand All @@ -10,6 +14,7 @@ build/
.DS_Store
.AppleDouble
.LSOverride
*.entitlements

# Javascript
package-lock.json
Expand Down
102 changes: 71 additions & 31 deletions client/scripts/requests.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,60 @@ async function deleteConnection(sending_object) {
return response
}

function trainRequest() {
if (!train_data) {
errorNotification("No training data was set.")
} else {
fetch(`http://${py_server_address}/train/${user_id}/${model_id}/0`, {
method: "PUT",
mode: "cors",
headers: {"Content-Type": "text/csv"},
body: train_data,
}).then(response => {
showBuildNotification(response.ok)
onTrainShowPredict(response.ok)
if (response.ok) {
setModelView("success")
} else {
setModelView("error")
}
function uploadRequest() {
if (data_upload.files.length == 0) return
const file = data_upload.files[0]

fetch(`http://${py_server_address}/${user_id}/${model_id}`, {
method: "PATCH",
mode: "cors",
body: file,
}).then(response => {
if (!response.ok) {
Swal.fire({
position: "top-end",
icon: "error",
title: "Failed to upload data",
showConfirmButton: false,
timer: 1500,
})
console.error(`Failed to upload data for ${file.name}`)
return
}
Swal.fire({
position: "top-end",
icon: "success",
title: "Successfully uploaded",
showConfirmButton: false,
timer: 1500,
})
}
})
setModelView("irrelevant")
// allow user to press a train button from now on
button_wrapper = document.getElementById("train-button")
button_wrapper.getElementsByTy
train_button = button_wrapper.children[0]
button_wrapper.removeAttribute("disabled")
train_button.removeAttribute("disabled")
}

function trainRequest() {
fetch(`http://${py_server_address}/train/${user_id}/${model_id}/0`, {
method: "PUT",
mode: "cors",
}).then(response => {
showBuildNotification(response.ok)
onTrainShowPredict(response.ok)
if (response.ok) {
setModelView("success")
} else {
setModelView("error")
}
})
}

async function predictRequest() {
if (csv_predict.files.length == 0) {
if (predict_button.files.length == 0) {
errorNotification("Empty predict file.")
return
}
Expand All @@ -159,25 +190,34 @@ async function predictRequest() {
showConfirmButton: true,
})
}
const file = csv_predict.files[0]
const text = await file.text()
const response = await fetch(
const file = predict_button.files[0]
let response = await fetch(
`http://${py_server_address}/predict/${user_id}/${model_id}`,
{
method: "PUT",
mode: "cors",
headers: {"Content-Type": "text/csv"},
body: text,
body: file,
},
)
if (!response.ok) {
Swal.fire("Error!", "Failed to upload the png image", "error")
console.error(
`Failed to upload the png with ${response.statusText}: ${responseJson.error}`,
)
return
}
response = await fetch(
`http://${py_server_address}/predict/${user_id}/${model_id}`,
{method: "GET", mode: "cors"},
)
const responseJson = await response.json()
hideResult() // hide previous predict result
if (response.ok) onPredictShowResult(responseJson)
else {
Swal.fire("Error!", "Server is not responding now.", "error")
errorNotification("Failed to predict.\n" + responseJson.error)
console.log(
`Predict failed with ${response.statusText}: ${responseJson.error}`,
if (!response.ok) {
Swal.fire("Error!", "Failed to predict", "error")
console.error(
`Failed to predict with with ${response.statusText}: ${responseJson.error}`,
)
return
}
hideResult() // hide previous predict result
onPredictShowResult(responseJson)
}
34 changes: 11 additions & 23 deletions client/templates/main.html
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ <h2 onmousedown="return false">MLCraft</h2>

<label class="button-wrapper">
Upload
<input type="file" id="csv-upload-button" accept=".csv" />
<input type="file" id="dataset-upload-button" accept=".csv,.zip" />
</label>

<label class="button-wrapper" id="train-button" disabled>
Expand All @@ -123,7 +123,7 @@ <h2 onmousedown="return false">MLCraft</h2>

<label class="button-wrapper" disabled>
Predict
<input type="file" accept=".csv" id="csv-predict-button" disabled />
<input type="file" accept=".png" id="predict-button" disabled />
</label>
<div class="predict-result-wrapper" id="predict-result-wrapper">
<output class="form-control-res" id="res-value-out" name="res-value" for="x-value y-value"></output>
Expand Down Expand Up @@ -1082,12 +1082,12 @@ <h3></h3>

<!--JSON requests handling-->
<script>
const csv_predict = document.getElementById("csv-predict-button")
csv_predict.onclick = () => {
csv_predict.value = null
const predict_button = document.getElementById("predict-button")
predict_button.onclick = () => {
predict_button.value = null
}

csv_predict.addEventListener("input", predictRequest)
predict_button.addEventListener("input", predictRequest)

function buildJsonFormData(form) {
const jsonFormData = {}
Expand Down Expand Up @@ -1251,26 +1251,14 @@ <h3></h3>
const showBuildNotification = build_status => {
build_status ? buildSuccessful() : buildUnsuccessful()
}
async function csvUploaded() {
if (csv_upload.files.length == 0) return
const file = csv_upload.files[0]
const text = await file.text()
// do whatever with the `text`
setModelView("irrelevant")
console.log(`Got goods: ${text}`)
train_data = text
// allow user to press a train button from now on
button_wrapper = document.getElementById("train-button")
button_wrapper.getElementsByTy
train_button = button_wrapper.children[0]
button_wrapper.removeAttribute("disabled")
train_button.removeAttribute("disabled")
const data_upload = document.getElementById("dataset-upload-button")
data_upload.onclick = () => {
data_upload.value = null
}
const csv_upload = document.getElementById("csv-upload-button")
csv_upload.addEventListener("change", csvUploaded)
data_upload.addEventListener("change", uploadRequest)

function onTrainShowPredict(training_successful) {
const predict_button = document.getElementById("csv-predict-button")
const predict_button = document.getElementById("predict-button")
hideResult() // hide result when new training results arrive
if (training_successful) {
predict_button.removeAttribute("disabled")
Expand Down
Binary file added documentation/png-examples/Ours.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added documentation/png-examples/Vectors.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions py_server/mlcraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def make_app(config=None):
app.register_error_handler(KeyError, key_error) # no json field
app.register_error_handler(ValueError, value_error) # json wrong type
app.register_error_handler(HTTPException, http_error)
app.register_error_handler(TimeoutError, timeout_error)
app.register_error_handler(ConnectTimeout, timeout_error)
app.register_error_handler(ConnectionError, connection_error)

Expand Down
48 changes: 0 additions & 48 deletions py_server/mlcraft/dataset.py

This file was deleted.

2 changes: 1 addition & 1 deletion py_server/mlcraft/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def http_error(e: HTTPException):
return {"error": str(e)}, HTTPStatus.BAD_REQUEST


def timeout_error(e: ConnectTimeout):
def timeout_error(e: TimeoutError | ConnectTimeout):
return {"error": "Request to c++ server timeout"}, HTTPStatus.INTERNAL_SERVER_ERROR


Expand Down
42 changes: 27 additions & 15 deletions py_server/mlcraft/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from json import dumps
from sqlite3 import IntegrityError
import datetime
import requests
from http import HTTPStatus
from flask import Blueprint, request, current_app, send_file
Expand All @@ -12,19 +9,21 @@
is_valid_model,
convert_model,
plot_metrics,
delete_file,
)
from .check_dimensions import assert_dimensions_match

from .errors import Error

from .db import sql_worker
from .dataset import extract_predict_data, extract_train_data


app = Blueprint("make a better name", __name__)


def cpp_url(method: str):
return current_app.config["CPP_SERVER"] + "/" + method


@app.route("/user", methods=["POST"])
def add_user():
json_data = request.json
Expand Down Expand Up @@ -52,7 +51,7 @@ def add_model(user_id: int):
return {"model_id": inserted_id}, HTTPStatus.CREATED


@app.route("/<int:user_id>/<int:model_id>", methods=["GET", "PUT", "DELETE"])
@app.route("/<int:user_id>/<int:model_id>", methods=["GET", "PUT", "PATCH", "DELETE"])
def model(user_id: int, model_id: int):
sql_worker.verify_access(user_id, model_id)
match request.method:
Expand All @@ -63,6 +62,14 @@ def model(user_id: int, model_id: int):
d: dict[str, str | None] = defaultdict(lambda: None, **request.json) # type: ignore
sql_worker.update_model(model_id, d["name"], d["raw"])
return "", HTTPStatus.OK
case "PATCH":
response = requests.post(
cpp_url(f"upload_data/{model_id}/0"),
data=request.data,
headers={"Content-Type": request.content_type},
timeout=10,
)
return "", response.status_code
case "DELETE":
sql_worker.delete_model(model_id)
return "", HTTPStatus.OK
Expand Down Expand Up @@ -166,7 +173,7 @@ def train_model(
model = {"graph": model}

response = requests.post(
current_app.config["CPP_SERVER"] + f"/train/{user_id}/{model_id}",
cpp_url(f"train/{user_id}/{model_id}"),
json=model,
timeout=3,
)
Expand All @@ -175,20 +182,25 @@ def train_model(
return response.text, response.status_code


@app.route("/predict/<int:user_id>/<int:model_id>", methods=["PUT"])
@app.route("/predict/<int:user_id>/<int:model_id>", methods=["GET", "PUT"])
def predict(user_id: int, model_id: int):
"""PUT method is for uploading the png, GET method is for receiving the result"""
sql_worker.verify_access(user_id, model_id)

model = sql_worker.get_graph_elements(model_id)
convert_model_parameters(model)

if not sql_worker.is_model_trained(model_id):
raise Error("Not trained", HTTPStatus.PRECONDITION_FAILED)

response = requests.put(
current_app.config["CPP_SERVER"] + f"/predict/{model_id}",
timeout=3,
)
match request.method:
case "GET":
response = requests.put(
cpp_url(f"predict/{model_id}"),
)
case "PUT":
response = requests.post(
cpp_url(f"upload_data/{model_id}/1"),
data=request.data,
headers={"Content-Type": "image/png"},
)

return response.text, response.status_code

Expand Down
Loading