Skip to content
This repository was archived by the owner on Feb 23, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions multinet/uploaders/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

from flask import Blueprint, request
from flask import current_app as app
from webargs import fields as webarg_fields
from webargs.flaskparser import use_kwargs

# Import types
from typing import Set, MutableMapping, Sequence, Any, List
from typing import Set, MutableMapping, Sequence, Any, List, Dict


bp = Blueprint("csv", __name__)
Expand All @@ -30,21 +32,45 @@ class InvalidRow(ValidationFailure):


@dataclass
class KeyFieldAlreadyExists(ValidationFailure):
    """CSV file has both existing _key field and specified key field."""

    # The user-specified key field that conflicts with the CSV's existing _key column.
    key: str


@dataclass
class KeyFieldDoesNotExist(ValidationFailure):
    """The specified key field does not exist."""

    # The key field name that was requested but not found among the CSV's columns.
    key: str


class MissingBody(ValidationFailure):
    """Missing body in a CSV file.

    Reported by validate_csv when the parsed CSV contains no rows at all.
    """


def validate_csv(rows: Sequence[MutableMapping]) -> None:
def validate_csv(
rows: Sequence[MutableMapping], key_field: str = "_key", overwrite: bool = False
) -> None:
"""Perform any necessary CSV validation, and return appropriate errors."""
data_errors: List[ValidationFailure] = []

if not rows:
raise ValidationFailed([MissingBody()])

fieldnames = rows[0].keys()
if "_key" in fieldnames:

if key_field != "_key" and key_field not in fieldnames:
data_errors.append(KeyFieldDoesNotExist(key=key_field))
raise ValidationFailed(data_errors)

if "_key" in fieldnames and key_field != "_key" and not overwrite:
data_errors.append(KeyFieldAlreadyExists(key=key_field))
raise ValidationFailed(data_errors)

if key_field in fieldnames:
# Node Table, check for key uniqueness
keys = [row["_key"] for row in rows]
keys = [row[key_field] for row in rows]
unique_keys: Set[str] = set()
for key in keys:
if key in unique_keys:
Expand Down Expand Up @@ -75,9 +101,28 @@ def validate_csv(rows: Sequence[MutableMapping]) -> None:
raise ValidationFailed(data_errors)


def set_table_key(rows: List[Dict[str, str]], key: str) -> List[Dict[str, str]]:
    """Return copies of ``rows`` with ``_key`` populated from the ``key`` column.

    The input rows are left untouched: each returned row is a shallow copy
    whose ``_key`` entry is set (or overwritten) to that row's value for
    ``key``.
    """
    return [{**row, "_key": row[key]} for row in rows]


@bp.route("/<workspace>/<table>", methods=["POST"])
@use_kwargs(
{
"key": webarg_fields.Str(location="query"),
"overwrite": webarg_fields.Bool(location="query"),
}
)
@swag_from("swagger/csv.yaml")
def upload(workspace: str, table: str) -> Any:
def upload(
workspace: str, table: str, key: str = "_key", overwrite: bool = False
) -> Any:
"""
Store a CSV file into the database as a node or edge table.

Expand All @@ -91,10 +136,14 @@ def upload(workspace: str, table: str) -> Any:
# Read the request body into CSV format
body = decode_data(request.data)

rows = list(csv.DictReader(StringIO(body)))
# Type to a Dict rather than an OrderedDict
rows: List[Dict[str, str]] = list(csv.DictReader(StringIO(body)))

# Perform validation.
validate_csv(rows)
validate_csv(rows, key, overwrite)

if key != "_key" and overwrite:
rows = set_table_key(rows, key)

# Set the collection, paying attention to whether the data contains
# _from/_to fields.
Expand Down
20 changes: 19 additions & 1 deletion multinet/uploaders/swagger/csv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ consumes:
parameters:
- $ref: "#/parameters/workspace"
- $ref: "#/parameters/table"
- name: data
-
name: data
in: body
description: Raw CSV text
schema:
Expand All @@ -16,6 +17,23 @@ parameters:
0,picard,captain
1,riker,commander
2,data,lieutenant commander
-
name: key
in: query
description: Key Field
schema:
type: string
example: _key
-
name: overwrite
in: query
description: Overwrites the default key field if it exists
enum:
- true
- false
schema:
type: boolean
default: false

responses:
200:
Expand Down
4 changes: 3 additions & 1 deletion mypy_stubs/webargs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ class fields:
@staticmethod
def Int() -> Any: ...
@staticmethod
def Str() -> Any: ...
def Str(required: bool = False, location: str = "json") -> Any: ...
@staticmethod
def List(t: Any) -> Any: ...
@staticmethod
def Bool(required: bool = False, location: str = "json") -> Any: ...
4 changes: 4 additions & 0 deletions test/data/startrek.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_key,name,rank
0,picard,captain
1,riker,commander
2,data,lieutenant commander
4 changes: 4 additions & 0 deletions test/data/startrek_no_key_field.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name,rank
picard,captain
riker,commander
data,lieutenant commander
94 changes: 75 additions & 19 deletions test/test_csv_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,94 @@
import pytest

from multinet.errors import ValidationFailed, DecodeFailed
from multinet.uploaders.csv import validate_csv, decode_data, InvalidRow
from multinet.validation import DuplicateKey
from multinet.uploaders.csv import (
validate_csv,
decode_data,
InvalidRow,
KeyFieldAlreadyExists,
KeyFieldDoesNotExist,
)
from multinet.validation import DuplicateKey, UnsupportedTable

TEST_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))


def test_validate_csv():
"""Tests the validate_csv function."""
duplicate_keys_file_path = os.path.join(
TEST_DATA_DIR, "clubs_invalid_duplicate_keys.csv"
)
def read_csv(filename: str):
    """Load a CSV fixture from the test data directory as a list of row dicts."""
    csv_path = os.path.join(TEST_DATA_DIR, filename)
    with open(csv_path) as csv_file:
        return list(csv.DictReader(csv_file))

invalid_headers_file_path = os.path.join(
TEST_DATA_DIR, "membership_invalid_syntax.csv"
)

# Test duplicate keys
with open(duplicate_keys_file_path) as test_file:
test_file = test_file.read()
def test_missing_key_field():
    """Test that a CSV without any key field is reported as unsupported."""
    rows = read_csv("startrek_no_key_field.csv")

    with pytest.raises(ValidationFailed) as excinfo:
        validate_csv(rows)

    # Exactly one error is reported, matching the UnsupportedTable failure.
    assert excinfo.value.errors == [UnsupportedTable().asdict()]


def test_invalid_key_field():
    """Test that naming a nonexistent column as the key field is an error."""
    rows = read_csv("startrek.csv")
    missing_field = "invalid"

    with pytest.raises(ValidationFailed) as excinfo:
        validate_csv(rows, key_field=missing_field)

    # Exactly one error is reported, matching the KeyFieldDoesNotExist failure.
    assert excinfo.value.errors == [KeyFieldDoesNotExist(key=missing_field).asdict()]

rows = list(csv.DictReader(StringIO(test_file)))

def test_key_field_already_exists_a():
    """
    Test that specifying a key when one already exists results in an error.

    (overwrite = False)
    """
    rows = read_csv("startrek.csv")
    conflicting_field = "name"

    with pytest.raises(ValidationFailed) as excinfo:
        validate_csv(rows, key_field=conflicting_field, overwrite=False)

    # Exactly one error is reported, matching the KeyFieldAlreadyExists failure.
    assert excinfo.value.errors == [KeyFieldAlreadyExists(key=conflicting_field).asdict()]


def test_key_field_already_exists_b():
    """
    Test that specifying a key when one already exists doesn't result in an error.

    (overwrite = True).
    """
    # Must not raise even though startrek.csv already has a _key column.
    validate_csv(read_csv("startrek.csv"), key_field="name", overwrite=True)


def test_duplicate_keys():
    """Test that duplicate keys are handled properly."""
    rows = read_csv("clubs_invalid_duplicate_keys.csv")

    with pytest.raises(ValidationFailed) as excinfo:
        validate_csv(rows)

    reported = excinfo.value.errors
    # Both duplicated keys in the fixture must be flagged.
    for duplicate in ("2", "5"):
        assert DuplicateKey(key=duplicate).asdict() in reported

# Test invalid syntax
with open(invalid_headers_file_path) as test_file:
test_file = test_file.read()

rows = list(csv.DictReader(StringIO(test_file)))
def test_invalid_headers():
"""Test that invalid headers are handled properly."""
rows = read_csv("membership_invalid_syntax.csv")
with pytest.raises(ValidationFailed) as v_error:
validate_csv(rows)

Expand All @@ -53,6 +107,8 @@ def test_validate_csv():
]
assert all(err in validation_resp for err in correct)

# Test unicode decode errors

def test_decode_failed():
    """Test that the DecodeFailed validation error is raised."""
    # Bytes starting with \xff\xfe (UTF-16-LE BOM pattern) — presumably
    # undecodable by decode_data's expected encoding; verify it raises.
    undecodable = b"\xff\xfe_\x00k\x00e\x00y\x00,\x00n\x00a\x00m\x00e\x00\n"
    with pytest.raises(DecodeFailed):
        decode_data(undecodable)