diff --git a/multinet/errors.py b/multinet/errors.py index 38a0b51e..a6329e26 100644 --- a/multinet/errors.py +++ b/multinet/errors.py @@ -153,3 +153,15 @@ class DatabaseNotLive(ServerError): def flask_response(self): """Generate a 500 error.""" return ("", "500 Database Not Live") + + +class DecodeFailed(ServerError): + """Exception for reporting decoding errors.""" + + def __init__(self, error): + """Initialize the exception.""" + self.error = error + + def flask_response(self): + """Generate a 400 error.""" + return (self.error, "400 Decode Failed") diff --git a/multinet/uploaders/csv.py b/multinet/uploaders/csv.py index 45b40945..fbb413ed 100644 --- a/multinet/uploaders/csv.py +++ b/multinet/uploaders/csv.py @@ -5,6 +5,7 @@ from .. import db, util from ..errors import ValidationFailed +from ..util import decode_data from flask import Blueprint, request from flask import current_app as app @@ -15,6 +16,8 @@ def validate_csv(rows): """Perform any necessary CSV validation, and return appropriate errors.""" + data_errors = [] + fieldnames = rows[0].keys() if "_key" in fieldnames: # Node Table, check for key uniqueness @@ -28,7 +31,7 @@ def validate_csv(rows): unique_keys.add(key) if len(duplicates) > 0: - return {"error": "duplicate", "detail": list(duplicates)} + data_errors.append({"error": "duplicate", "detail": list(duplicates)}) elif "_from" in fieldnames and "_to" in fieldnames: # Edge Table, check that each cell has the correct format valid_cell = re.compile("[^/]+/[^/]+") @@ -47,9 +50,15 @@ def validate_csv(rows): detail.append({"fields": fields, "row": i + 2}) if detail: - return {"error": "syntax", "detail": detail} + data_errors.append({"error": "syntax", "detail": detail}) + else: + # Unsupported Table, error since we don't know what's coming in + data_errors.append({"error": "unsupported"}) - return None + if len(data_errors) > 0: + raise ValidationFailed(data_errors) + else: + return None @bp.route("//", methods=["POST"]) @@ -65,13 +74,12 @@ def upload(workspace, table): app.logger.info("Bulk Loading") # Read the request body into CSV format - body = request.data.decode("utf8") + body = decode_data(request.data) + rows = list(csv.DictReader(StringIO(body))) # Perform validation. - result = validate_csv(rows) - if result: - raise ValidationFailed(result) + validate_csv(rows) # Set the collection, paying attention to whether the data contains # _from/_to fields. diff --git a/multinet/uploaders/newick.py b/multinet/uploaders/newick.py index 5086dc92..1a1d02ff 100644 --- a/multinet/uploaders/newick.py +++ b/multinet/uploaders/newick.py @@ -3,6 +3,8 @@ import newick from .. import db, util +from ..errors import ValidationFailed +from ..util import decode_data from flask import Blueprint, request from flask import current_app as app @@ -11,6 +13,51 @@ bp.before_request(util.require_db) +def validate_newick(tree): + """Validate newick tree.""" + data_errors = [] + unique_keys = [] + duplicate_keys = [] + unique_edges = [] + duplicate_edges = [] + + def read_tree(parent, node): + key = node.name or uuid.uuid4().hex + + if key not in unique_keys: + unique_keys.append(key) + elif key not in duplicate_keys: + duplicate_keys.append(key) + + for desc in node.descendants: + read_tree(key, desc) + + if parent: + edge = { + "_from": "table/%s" % (parent), + "_to": "table/%s" % (key), + "length": node.length, + } + + if edge not in unique_edges: + unique_edges.append(edge) + elif edge not in duplicate_edges: + duplicate_edges.append(edge) + + read_tree(None, tree[0]) + + if len(duplicate_keys) > 0: + data_errors.append({"error": "duplicate", "detail": duplicate_keys}) + + if len(duplicate_edges) > 0: + data_errors.append({"error": "duplicate", "detail": duplicate_edges}) + + if len(data_errors) > 0: + raise ValidationFailed(data_errors) + else: + return + + @bp.route("//
", methods=["POST"]) def upload(workspace, table): """ @@ -21,7 +68,13 @@ def upload(workspace, table): `data` - the newick data, passed in the request body. """ app.logger.info("newick tree") - tree = newick.loads(request.data.decode("utf8")) + + body = decode_data(request.data) + + tree = newick.loads(body) + + validate_newick(tree) + workspace = db.db(workspace) edgetable_name = "%s_edges" % table nodetable_name = "%s_nodes" % table diff --git a/multinet/util.py b/multinet/util.py index 21675b00..455caf86 100644 --- a/multinet/util.py +++ b/multinet/util.py @@ -4,7 +4,7 @@ from flask import Response from . import db -from .errors import DatabaseNotLive +from .errors import DatabaseNotLive, DecodeFailed def generate(iterator): @@ -28,3 +28,13 @@ def require_db(): """Check if the db is live.""" if not db.check_db(): raise DatabaseNotLive() + + +def decode_data(input): + """Decode the request data assuming utf8 encoding.""" + try: + body = input.decode("utf8") + except UnicodeDecodeError as e: + raise DecodeFailed({"error": "utf8", "detail": str(e)}) + + return body diff --git a/test/data/basic_newick.tree b/test/data/basic_newick.tree new file mode 100644 index 00000000..e7548cac --- /dev/null +++ b/test/data/basic_newick.tree @@ -0,0 +1 @@ +(B,(A,C,E),D); diff --git a/test/data/basic_newick_duplicates.tree b/test/data/basic_newick_duplicates.tree new file mode 100644 index 00000000..7886bd9c --- /dev/null +++ b/test/data/basic_newick_duplicates.tree @@ -0,0 +1 @@ +(B,(A,C,A),D); diff --git a/test/data/basic_newick_utf16.tree b/test/data/basic_newick_utf16.tree new file mode 100644 index 00000000..ae988151 Binary files /dev/null and b/test/data/basic_newick_utf16.tree differ diff --git a/test/data/clubs_utf16.csv b/test/data/clubs_utf16.csv new file mode 100644 index 00000000..a2550bb3 Binary files /dev/null and b/test/data/clubs_utf16.csv differ diff --git a/test/test_csv_uploader.py b/test/test_csv_uploader.py index f740f4e0..34e1f873 100644 --- a/test/test_csv_uploader.py +++ b/test/test_csv_uploader.py @@ -2,8 +2,10 @@ import csv from io import StringIO import os +import pytest -from multinet.uploaders.csv import validate_csv +from multinet.errors import ValidationFailed, DecodeFailed +from multinet.uploaders.csv import validate_csv, decode_data TEST_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "data")) @@ -23,19 +25,31 @@ def test_validate_csv(): test_file = test_file.read() rows = list(csv.DictReader(StringIO(test_file))) - validation_resp = validate_csv(rows) - assert "error" in validation_resp.keys() - assert "5" in validation_resp["detail"] - assert "2" in validation_resp["detail"] + + with pytest.raises(ValidationFailed) as v_error: + validate_csv(rows) + + validation_resp = v_error.value.errors[0] + assert "error" in validation_resp + duplicate_keys = validation_resp["detail"] + assert "5" in duplicate_keys + assert "2" in duplicate_keys # Test invalid syntax with open(invalid_headers_file_path) as test_file: test_file = test_file.read() rows = list(csv.DictReader(StringIO(test_file))) - validation_resp = validate_csv(rows) + with pytest.raises(ValidationFailed) as v_error: + validate_csv(rows) + + validation_resp = v_error.value.errors[0] invalid_rows = [x["row"] for x in validation_resp["detail"]] - assert "error" in validation_resp.keys() + assert "error" in validation_resp assert 3 in invalid_rows assert 4 in invalid_rows assert 5 in invalid_rows + + # Test unicode decode errors + test_data = b"\xff\xfe_\x00k\x00e\x00y\x00,\x00n\x00a\x00m\x00e\x00\n" + pytest.raises(DecodeFailed, decode_data, test_data) diff --git a/test/test_newick_uploader.py b/test/test_newick_uploader.py new file mode 100644 index 00000000..25688755 --- /dev/null +++ b/test/test_newick_uploader.py @@ -0,0 +1,35 @@ +"""Tests functions in the Neick Uploader Flask Blueprint.""" +import newick +import os +import pytest + +from multinet.errors import ValidationFailed, DecodeFailed +from multinet.uploaders.newick import validate_newick, decode_data + +TEST_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "data")) + + +def test_validate_newick(): + """Tests the validate_csv function.""" + duplicate_keys_file_path = os.path.join( + TEST_DATA_DIR, "basic_newick_duplicates.tree" + ) + + # Test duplicate keys + with open(duplicate_keys_file_path) as test_file: + test_file = test_file.read() + + body = newick.loads(test_file) + + with pytest.raises(ValidationFailed) as v_error: + validate_newick(body) + + validation_resp = v_error.value.errors[0] + assert "error" in validation_resp.keys() + + # Test unicode decode errors + test_data = ( + b"\xff\xfe(\x00B\x00,\x00(\x00A\x00," + b"\x00C\x00,\x00E\x00)\x00,\x00D\x00)\x00;\x00\n\x00" + ) + pytest.raises(DecodeFailed, decode_data, test_data)