diff --git a/multinet/uploaders/csv.py b/multinet/uploaders/csv.py index 3dc50895..15d41e9c 100644 --- a/multinet/uploaders/csv.py +++ b/multinet/uploaders/csv.py @@ -57,16 +57,46 @@ def flask_response(self) -> FlaskTuple: return ("Could not read CSV data", "415 Unsupported Media Type") -def validate_csv( - rows: Sequence[MutableMapping], key_field: str = "_key", overwrite: bool = False -) -> None: - """Perform any necessary CSV validation, and return appropriate errors.""" +def is_edge_table(rows: Sequence[MutableMapping]) -> bool: + """Determine if this table should be treated as an edge table.""" + fieldnames = rows[0].keys() + return "_from" in fieldnames and "_to" in fieldnames + + +def is_node_table(rows: Sequence[MutableMapping], key_field: str) -> bool: + """Determine if this table should be treated as a node table.""" + fieldnames = rows[0].keys() + return key_field != "_key" or "_key" in fieldnames + + +def validate_edge_table(rows: Sequence[MutableMapping]) -> None: + """Validate that the given table is a valid edge table.""" data_errors: List[ValidationFailure] = [] - if not rows: - raise ValidationFailed([MissingBody()]) + # Checks that a cell has the form table_name/key + valid_cell = re.compile("[^/]+/[^/]+") + + for i, row in enumerate(rows): + fields: List[str] = [] + if not valid_cell.match(row["_from"]): + fields.append("_from") + if not valid_cell.match(row["_to"]): + fields.append("_to") + + if fields: + # i+2 -> +1 for index offset, +1 due to header row + data_errors.append(InvalidRow(fields=fields, row=i + 2)) + + if len(data_errors) > 0: + raise ValidationFailed(data_errors) + +def validate_node_table( + rows: Sequence[MutableMapping], key_field: str, overwrite: bool +) -> None: + """Validate that the given table is a valid node table.""" fieldnames = rows[0].keys() + data_errors: List[ValidationFailure] = [] if key_field != "_key" and key_field not in fieldnames: data_errors.append(KeyFieldDoesNotExist(key=key_field)) @@ -76,39 +106,33 @@ def validate_csv( data_errors.append(KeyFieldAlreadyExists(key=key_field)) raise ValidationFailed(data_errors) - if key_field in fieldnames: - # Node Table, check for key uniqueness - keys = [row[key_field] for row in rows] - unique_keys: Set[str] = set() - for key in keys: - if key in unique_keys: - data_errors.append(DuplicateKey(key=key)) - else: - unique_keys.add(key) - - elif "_from" in fieldnames and "_to" in fieldnames: - # Edge Table, check that each cell has the correct format - valid_cell = re.compile("[^/]+/[^/]+") - - for i, row in enumerate(rows): - fields: List[str] = [] - if not valid_cell.match(row["_from"]): - fields.append("_from") - if not valid_cell.match(row["_to"]): - fields.append("_to") - - if fields: - # i+2 -> +1 for index offset, +1 due to header row - data_errors.append(InvalidRow(fields=fields, row=i + 2)) - - else: - # Unsupported Table, error since we don't know what's coming in - data_errors.append(UnsupportedTable()) + keys = (row[key_field] for row in rows) + unique_keys: Set[str] = set() + for key in keys: + if key in unique_keys: + data_errors.append(DuplicateKey(key=key)) + else: + unique_keys.add(key) if len(data_errors) > 0: raise ValidationFailed(data_errors) +def validate_csv( + rows: Sequence[MutableMapping], key_field: str, overwrite: bool +) -> None: + """Perform any necessary CSV validation, and return appropriate errors.""" + if not rows: + raise ValidationFailed([MissingBody()]) + + if is_node_table(rows, key_field): + validate_node_table(rows, key_field, overwrite) + elif is_edge_table(rows): + validate_edge_table(rows) + else: + raise ValidationFailed([UnsupportedTable()]) + + def set_table_key(rows: List[Dict[str, str]], key: str) -> List[Dict[str, str]]: """Update the _key field in each row.""" new_rows: List[Dict[str, str]] = [] diff --git a/test/test_csv_uploader.py b/test/test_csv_uploader.py index 5be5a3c2..758cc24e 100644 --- a/test/test_csv_uploader.py +++ b/test/test_csv_uploader.py @@ -30,7 +30,7 @@ def test_missing_key_field(): correct = UnsupportedTable().asdict() with pytest.raises(ValidationFailed) as v_error: - validate_csv(rows) + validate_csv(rows, key_field="_key", overwrite=False) validation_resp = v_error.value.errors assert len(validation_resp) == 1 @@ -44,7 +44,7 @@ def test_invalid_key_field(): correct = KeyFieldDoesNotExist(key=invalid_key).asdict() with pytest.raises(ValidationFailed) as v_error: - validate_csv(rows, key_field=invalid_key) + validate_csv(rows, key_field=invalid_key, overwrite=False) validation_resp = v_error.value.errors assert len(validation_resp) == 1 @@ -83,7 +83,7 @@ def test_duplicate_keys(): """Test that duplicate keys are handled properly.""" rows = read_csv("clubs_invalid_duplicate_keys.csv") with pytest.raises(ValidationFailed) as v_error: - validate_csv(rows) + validate_csv(rows, key_field="_key", overwrite=False) validation_resp = v_error.value.errors correct = [err.asdict() for err in [DuplicateKey(key="2"), DuplicateKey(key="5")]] @@ -94,7 +94,7 @@ def test_invalid_headers(): """Test that invalid headers are handled properly.""" rows = read_csv("membership_invalid_syntax.csv") with pytest.raises(ValidationFailed) as v_error: - validate_csv(rows) + validate_csv(rows, key_field="_key", overwrite=False) validation_resp = v_error.value.errors correct = [