diff --git a/multinet/api.py b/multinet/api.py index c503ccb1..3229be9a 100644 --- a/multinet/api.py +++ b/multinet/api.py @@ -100,6 +100,22 @@ def get_table_rows(workspace: str, table: str, offset: int = 0, limit: int = 30) return Workspace(workspace).table(table).rows(offset, limit) +@bp.route("/workspaces/<workspace>/tables/<table>/metadata", methods=["GET"]) +@require_reader +@swag_from("swagger/get_metadata.yaml") +def get_table_metadata(workspace: str, table: str) -> Any: + """Retrieve the metadata of a table, if it exists.""" + return Workspace(workspace).table(table).get_metadata().dict() + + +@bp.route("/workspaces/<workspace>/tables/<table>
/metadata", methods=["PUT"]) +@require_reader +@swag_from("swagger/set_metadata.yaml") +def set_table_metadata(workspace: str, table: str) -> Any: + """Set the metadata of a table.""" + return Workspace(workspace).table(table).set_metadata(request.json).dict() + + @bp.route("/workspaces/<workspace>/graphs", methods=["GET"]) @require_reader @swag_from("swagger/workspace_graphs.yaml") @@ -158,7 +174,7 @@ def get_node_edges( """Return the edges connected to a node.""" allowed = ["incoming", "outgoing", "all"] if direction not in allowed: - raise BadQueryArgument("direction", direction, allowed) + raise BadQueryArgument("direction", direction) return ( Workspace(workspace) diff --git a/multinet/db/models/table.py b/multinet/db/models/table.py index 751e224e..1f5778fa 100644 --- a/multinet/db/models/table.py +++ b/multinet/db/models/table.py @@ -1,11 +1,19 @@ """Operations that deal with tables.""" from __future__ import annotations  # noqa: T484 + from arango.collection import StandardCollection from arango.aql import AQL +from pydantic import ValidationError as PydanticValidationError from multinet import util -from multinet.types import EdgeTableProperties -from multinet.errors import ServerError, FlaskTuple +from multinet.db.models import workspace +from multinet.types import ( + EdgeTableProperties, + ArangoEntityDocument, + EntityMetadata, + TableMetadata, +) +from multinet.errors import ServerError, FlaskTuple, InvalidMetadata from typing import List, Set, Dict, Iterable, Union, Optional @@ -25,22 +33,23 @@ def flask_response(self) -> FlaskTuple: class Table: """Tables store tabular data, and are the root of all data storage in Multinet.""" - def __init__(self, name: str, workspace: str, handle: StandardCollection, aql: AQL): + def __init__(self, name: str, workspace: workspace.Workspace): """ Initialize all Table parameters, but make no requests. - The `workspace` parameter is the name of the workspace this table belongs to.
- - The `handle` parameter is the handle to the arangodb collection for which this - class instance is associated. - - The `aql` parameter is the AQL handle of the creating Workspace, so that this - class may make AQL requests when necessary. + The `name` parameter is the name of this table. + The `workspace` parameter is the workspace this table belongs to. """ self.name = name - self.workspace = workspace - self.handle = handle - self.aql = aql + + # Used for inserting/modifying table metadata + self.metadata_collection = workspace.entity_metadata_collection() + + # Used for querying table items + self.handle: StandardCollection = workspace.handle.collection(name) + + # Used for running AQL queries when necessary + self.aql: AQL = workspace.handle.aql def rows(self, offset: Optional[int] = None, limit: Optional[int] = None) -> Dict: """Return the desired rows in a table.""" @@ -72,6 +81,34 @@ def headers(self) -> List[str]: return keys + def get_metadata(self) -> ArangoEntityDocument: + """Retrieve metadata for this table, if it exists.""" + try: + doc = next(self.metadata_collection.find({"item_id": self.name}, limit=1)) + except StopIteration: + entity = EntityMetadata(item_id=self.name, table=TableMetadata()) + + # Return is just metadata, merge with entity to get full doc + doc = self.metadata_collection.insert(entity.dict()) + doc.update(entity.dict()) + + return ArangoEntityDocument(**doc) + + def set_metadata(self, raw_data: Dict) -> ArangoEntityDocument: + """Set metadata for this table.""" + try: + data = TableMetadata(**raw_data) + except PydanticValidationError: + raise InvalidMetadata(raw_data) + + entity = self.get_metadata() + entity.table = data + + new_doc = entity.dict() + new_doc.update(self.metadata_collection.insert(new_doc, overwrite=True)) + + return ArangoEntityDocument(**new_doc) + def rename(self, new_name: str) -> None: """Rename a table.""" self.handle.rename(new_name) diff --git a/multinet/db/models/workspace.py 
b/multinet/db/models/workspace.py index 975906d9..e9e100c0 100644 --- a/multinet/db/models/workspace.py +++ b/multinet/db/models/workspace.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from arango.exceptions import DatabaseCreateError, EdgeDefinitionCreateError from arango.cursor import Cursor +from arango.collection import StandardCollection from multinet import util from multinet.types import EdgeTableProperties, TableType @@ -194,6 +195,13 @@ def get_metadata(self) -> Dict: # Copy so modifications to return don't poison cache return copy.deepcopy(doc) + def entity_metadata_collection(self) -> StandardCollection: + """Return the collection handle for table/graph metadata.""" + if not self.readonly_handle.has_collection("_metadata"): + return self.handle.create_collection("_metadata", system=True) + + return self.handle.collection("_metadata") + def graphs(self) -> List[Dict]: """Return the graphs in this workspace.""" return self.readonly_handle.graphs() @@ -299,7 +307,10 @@ def is_node(x: Dict[str, Any]) -> bool: def table(self, name: str) -> Table: """Return a specific table.""" - return Table(name, self.name, self.handle.collection(name), self.handle.aql) + if not self.readonly_handle.has_collection(name): + raise TableNotFound(self.name, name) + + return Table(name, self) def has_table(self, name: str) -> bool: """Return if a specific table exists.""" @@ -307,8 +318,8 @@ def has_table(self, name: str) -> bool: def create_table(self, table: str, edge: bool, sync: bool = False) -> Table: """Create a table in this workspace.""" - table_handle = self.handle.create_collection(table, edge=edge, sync=sync) - return Table(table, self.name, table_handle, self.handle.aql) + self.handle.create_collection(table, edge=edge, sync=sync) + return Table(table, self) def create_aql_table(self, table: str, aql_query: str) -> Table: """Create a table in this workspace from an aql query.""" diff --git a/multinet/errors.py b/multinet/errors.py index bf3ec295..c40fda06 
100644 --- a/multinet/errors.py +++ b/multinet/errors.py @@ -1,12 +1,10 @@ """Exception objects representing Multinet-specific HTTP error conditions.""" -from typing import Tuple, Any, Union, List, Sequence -from typing_extensions import TypedDict +from typing import Tuple, Any, Union, Dict, List, Sequence from multinet.validation import ValidationFailure FlaskTuple = Tuple[Any, Union[int, str]] -Payload = TypedDict("Payload", {"argument": str, "value": str, "allowed": List[str]}) class ServerError(Exception): @@ -118,20 +116,14 @@ def __init__(self, table: str, node: str): class BadQueryArgument(ServerError): """Exception for illegal query argument value.""" - def __init__(self, argument: str, value: str, allowed: List[str]): + def __init__(self, argument: str, value: str): """Initialize the exception.""" self.argument = argument self.value = value - self.allowed = allowed def flask_response(self) -> FlaskTuple: """Generate a 400 error for the bad argument.""" - payload: Payload = { - "argument": self.argument, - "value": self.value, - "allowed": self.allowed, - } - + payload = {"argument": self.argument, "value": self.value} return (payload, "400 Bad Query Argument") @@ -160,6 +152,18 @@ def flask_response(self) -> FlaskTuple: return (self.body, "400 Malformed Request Body") +class InvalidMetadata(ServerError): + """Exception for specifying invalid metadata.""" + + def __init__(self, metadata: Dict): + """Initialize the exception.""" + self.metadata = metadata + + def flask_response(self) -> FlaskTuple: + """Generate a 400 error.""" + return (self.metadata, "400 Invalid Metadata") + + class RequiredParamsMissing(ServerError): """Exception for missing required parameters.""" diff --git a/multinet/swagger/get_metadata.yaml b/multinet/swagger/get_metadata.yaml new file mode 100644 index 00000000..4f6810a0 --- /dev/null +++ b/multinet/swagger/get_metadata.yaml @@ -0,0 +1,18 @@ +Retrieve the metadata of a table. 
+--- +parameters: + - $ref: "#/parameters/workspace" + - $ref: "#/parameters/table" + +responses: + 200: + description: The metadata for this table + + 404: + description: Specified workspace or table could not be found + schema: + type: string + example: workspace_or_table_that_doesnt_exist + +tags: + - table diff --git a/multinet/swagger/set_metadata.yaml b/multinet/swagger/set_metadata.yaml new file mode 100644 index 00000000..870db626 --- /dev/null +++ b/multinet/swagger/set_metadata.yaml @@ -0,0 +1,31 @@ +Set the metadata of a table. +--- +parameters: + - $ref: "#/parameters/workspace" + - $ref: "#/parameters/table" + - name: metadata + in: body + description: The metadata to set (overwrites existing data) + required: true + schema: + type: object + example: + columns: + - key: test + type: label + + - key: length + type: number + +responses: + 200: + description: The metadata for this table + + 404: + description: Specified workspace or table could not be found + schema: + type: string + example: workspace_or_table_that_doesnt_exist + +tags: + - table diff --git a/multinet/types.py b/multinet/types.py index 6159ad6e..1bb7ec6f 100644 --- a/multinet/types.py +++ b/multinet/types.py @@ -1,9 +1,62 @@ """Custom types for Multinet codebase.""" -from typing import Dict, Set +from pydantic import BaseModel, Field +from typing import List, Dict, Set, Optional, Any from typing_extensions import Literal, TypedDict EdgeDirection = Literal["all", "incoming", "outgoing"] + TableType = Literal["all", "node", "edge"] +ColumnType = Literal["label", "boolean", "category", "number", "date"] + + +class ColumnMetadata(BaseModel): + """Metadata for a table column.""" + + key: str + type: ColumnType + + +class TableMetadata(BaseModel): + """Metadata for a table.""" + + columns: List[ColumnMetadata] = Field(default_factory=list) + + +class GraphMetadata(BaseModel): + """Metadata for a graph.""" + + +class EntityMetadata(BaseModel): + """Metadata for a table or graph.""" + + item_id: 
str + table: Optional[TableMetadata] + graph: Optional[GraphMetadata] + + +class ArangoEntityDocument(EntityMetadata): + """An entity metadata document with arangodb metadata.""" + + def dict(self, **kwargs: Any) -> Dict: # noqa: A003 + """ + Overload existing dict function to use alias for dict serialization. + + Variable names with leading underscores aren't treated normally, and need to be + aliased to be properly specified. Since pydantic doesn't serialize with alias + names by default, this overload is needed. + """ + + kwargs["by_alias"] = True + return super().dict(**kwargs) + + id: str = Field(alias="_id") + key: str = Field(alias="_key") + rev: str = Field(alias="_rev") + + class Config: + """Model config.""" + + allow_population_by_field_name = True class EdgeTableProperties(TypedDict): diff --git a/multinet/uploaders/csv.py b/multinet/uploaders/csv.py index 9e156fcd..5f83a74f 100644 --- a/multinet/uploaders/csv.py +++ b/multinet/uploaders/csv.py @@ -1,12 +1,13 @@ """Multinet uploader for CSV files.""" import csv +import json from flasgger import swag_from from io import StringIO from multinet import util from multinet.db.models.workspace import Workspace from multinet.auth.util import require_writer -from multinet.errors import AlreadyExists, FlaskTuple, ServerError +from multinet.errors import AlreadyExists, FlaskTuple, ServerError, BadQueryArgument from multinet.util import decode_data from multinet.validation.csv import validate_csv @@ -16,7 +17,7 @@ from webargs.flaskparser import use_kwargs # Import types -from typing import Any, List, Dict +from typing import Any, List, Dict, Optional bp = Blueprint("csv", __name__) @@ -47,12 +48,17 @@ def set_table_key(rows: List[Dict[str, str]], key: str) -> List[Dict[str, str]]: { "key": webarg_fields.Str(location="query"), "overwrite": webarg_fields.Bool(location="query"), + "metadata": webarg_fields.Str(location="query"), } ) @require_writer @swag_from("swagger/csv.yaml") def upload( - workspace: str, table: 
str, key: str = "_key", overwrite: bool = False + workspace: str, + table: str, + key: str = "_key", + overwrite: bool = False, + metadata: Optional[str] = None, ) -> Any: """ Store a CSV file into the database as a node or edge table. @@ -95,6 +101,12 @@ def upload( # Create table and insert the data loaded_table = loaded_workspace.create_table(table, edges) - results = loaded_table.insert(rows) + if metadata: + try: + loaded_table.set_metadata(json.loads(metadata)) + except json.decoder.JSONDecodeError: + raise BadQueryArgument("metadata", metadata) + + results = loaded_table.insert(rows) return {"count": len(results)} diff --git a/multinet/uploaders/swagger/csv.yaml b/multinet/uploaders/swagger/csv.yaml index a0eaada2..07c40cc9 100644 --- a/multinet/uploaders/swagger/csv.yaml +++ b/multinet/uploaders/swagger/csv.yaml @@ -34,6 +34,21 @@ parameters: schema: type: boolean default: false + - + name: metadata + in: query + description: Metadata to set for the table + schema: + type: object + example: + columns: + - + key: city + type: label + - + key: size + type: number + responses: 200: diff --git a/mypy_stubs/arango/collection.pyi b/mypy_stubs/arango/collection.pyi index a8d905eb..6599fd0c 100644 --- a/mypy_stubs/arango/collection.pyi +++ b/mypy_stubs/arango/collection.pyi @@ -20,7 +20,6 @@ class Collection: class StandardCollection(Collection): name: str - def insert( self, document: Any, diff --git a/test/test_metadata.py b/test/test_metadata.py new file mode 100644 index 00000000..03bcef19 --- /dev/null +++ b/test/test_metadata.py @@ -0,0 +1,92 @@ +"""Testing operations on metadata.""" +import conftest +import pytest +import json + + +def test_set_valid_table_metadata(populated_workspace, managed_user, server): + """Test that setting valid metadata succeeds.""" + workspace, _, node_table, _ = populated_workspace + metadata = {"columns": [{"key": "test", "type": "label"}]} + + with conftest.login(managed_user, server): + resp = server.put( + 
f"/api/workspaces/{workspace.name}/tables/{node_table}/metadata", + json=metadata, + ) + + assert resp.status_code == 200 + assert resp.json["table"] == metadata + + +@pytest.mark.parametrize( + "metadata, expected", + [ + ({"columns": [{"key": "test", "type": "invalid"}]}, 400), + ({"columns": [{"foo": "bar"}]}, 400), + ({"foo": "bar"}, 200), + ], +) +def test_set_invalid_table_metadata( + populated_workspace, managed_user, server, metadata, expected +): + """Test that setting invalid metadata fails.""" + workspace, _, node_table, _ = populated_workspace + + with conftest.login(managed_user, server): + resp = server.put( + f"/api/workspaces/{workspace.name}/tables/{node_table}/metadata", + json=metadata, + ) + assert resp.status_code == expected + + +# NOTE: Including metadata in CSV uploads will likely be removed in a future API change +def test_csv_upload_with_valid_metadata( + managed_workspace, managed_user, server, data_directory +): + """Test that uploading a CSV file with metadata succeeds.""" + + table_name = "test" + metadata = {"columns": [{"key": "test", "type": "label"}]} + + with open(data_directory / "membership_with_keys.csv") as csv_file: + request_body = csv_file.read() + + with conftest.login(managed_user, server): + resp = server.post( + f"/api/csv/{managed_workspace.name}/{table_name}", + data=request_body, + query_string={"metadata": json.dumps(metadata)}, + ) + + assert resp.status_code == 200 + + resp = server.get( + f"/api/workspaces/{managed_workspace.name}/tables/{table_name}/metadata" + ) + + assert resp.status_code == 200 + assert resp.json["table"] == metadata + + +def test_csv_upload_with_invalid_metadata( + managed_workspace, managed_user, server, data_directory +): + """Test that uploading a CSV file with invalid json fails properly.""" + + table_name = "test" + metadata_str = "{" + with open(data_directory / "membership_with_keys.csv") as csv_file: + request_body = csv_file.read() + + with conftest.login(managed_user, server): + 
resp = server.post( + f"/api/csv/{managed_workspace.name}/{table_name}", + data=request_body, + query_string={"metadata": metadata_str}, + ) + + assert resp.status_code == 400 + assert resp.json["argument"] == "metadata" + assert resp.json["value"] == metadata_str