diff --git a/src/database/setups.py b/src/database/setups.py new file mode 100644 index 00000000..6ac0c83f --- /dev/null +++ b/src/database/setups.py @@ -0,0 +1,49 @@ +from sqlalchemy import Connection, text +from sqlalchemy.engine import Row + + +def get(id_: int, connection: Connection) -> Row | None: + """Get the setup by its ID.""" + row = connection.execute( + text( + """ + SELECT * + FROM algorithm_setup + WHERE sid = :setup_id + """, + ), + parameters={"setup_id": id_}, + ) + return row.one_or_none() + + +def get_tags_for(id_: int, connection: Connection) -> list[str]: + """Get all tags for a specific setup.""" + rows = connection.execute( + text( + """ + SELECT tag + FROM setup_tag + WHERE id = :setup_id + """, + ), + parameters={"setup_id": id_}, + ) + return [row.tag for row in rows] + + +def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None: + """Insert a new tag for the setup.""" + connection.execute( + text( + """ + INSERT INTO setup_tag(`id`, `tag`, `uploader`) + VALUES (:setup_id, :tag, :user_id) + """, + ), + parameters={ + "setup_id": id_, + "user_id": user_id, + "tag": tag_, + }, + ) diff --git a/src/main.py b/src/main.py index 560b4c50..dca6fc73 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.setups import router as setups_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router @@ -47,6 +48,7 @@ def create_api() -> FastAPI: app = FastAPI(**fastapi_kwargs) app.include_router(datasets_router) + app.include_router(setups_router) app.include_router(qualities_router) app.include_router(mldcat_ap_router) app.include_router(ttype_router) diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py new file mode 100644 index 00000000..8af937cf --- /dev/null +++ b/src/routers/openml/setups.py @@ -0,0 +1,68 @@ +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import APIRouter, Body, Depends, HTTPException +from sqlalchemy import Connection + +import database.setups +from database.users import User, UserGroup +from routers.dependencies import expdb_connection, fetch_user +from routers.types import SystemString64 + +router = APIRouter(prefix="/setup", tags=["setups"]) + + +def create_authentication_failed_error() -> HTTPException: + return HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "103", "message": "Authentication failed"}, + ) + + +def create_tag_exists_error(setup_id: int, tag: str) -> HTTPException: + return HTTPException( + # Changed from INTERNAL_SERVER_ERROR (500) to CONFLICT (409) + status_code=HTTPStatus.CONFLICT, + detail={ + "code": "473", + "message": "Entity already tagged by this tag.", + "additional_information": f"id={setup_id}; tag={tag}", + }, + ) + + +@router.post("/tag") +def tag_setup( + setup_id: Annotated[int, Body()], + tag: Annotated[str, Body(..., embed=False), SystemString64], + user: Annotated[User | None, Depends(fetch_user)] = None, + expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, +) -> dict[str, dict[str, Any]]: + # 1. AUTHENTICATE FIRST + if user is None: + raise create_authentication_failed_error() + + # 2. VERIFY EXISTENCE + setup = database.setups.get(setup_id, expdb_db) + if not setup: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Setup not found") + + # 3. VERIFY OWNERSHIP / PERMISSIONS + # (Fixes the crash by not looking for a Dataset 'visibility' column) + is_admin = UserGroup.ADMIN in user.groups + is_owner = getattr(setup, "uploader", None) == user.user_id + + if not (is_admin or is_owner): + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="No access granted") + + # 4. CHECK IF TAG EXISTS + tags = database.setups.get_tags_for(setup_id, expdb_db) + if tag.casefold() in [t.casefold() for t in tags]: + raise create_tag_exists_error(setup_id, tag) + + # 5. APPLY THE TAG + database.setups.tag(setup_id, tag, user_id=user.user_id, connection=expdb_db) + + return { + "setup_tag": {"id": str(setup_id), "tag": [*tags, tag]}, + }