diff --git a/pyproject.toml b/pyproject.toml index d3b013c7..e056595a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "pydantic", "uvicorn", "sqlalchemy", - "mysqlclient", + "aiomysql", "python_dotenv", "xmltodict", ] @@ -27,6 +27,7 @@ dev = [ "coverage", "pre-commit", "pytest", + "pytest-asyncio", "pytest-mock", "httpx", "hypothesis", @@ -77,6 +78,8 @@ plugins = [ ] [tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" pythonpath = [ "src" ] diff --git a/src/config.toml b/src/config.toml index 10d75534..79f7c557 100644 --- a/src/config.toml +++ b/src/config.toml @@ -11,7 +11,7 @@ root_path="" host="openml-test-database" port="3306" # SQLAlchemy `dialect` and `driver`: https://docs.sqlalchemy.org/en/20/dialects/index.html -drivername="mysql" +drivername="mysql+aiomysql" [databases.expdb] database="openml_expdb" diff --git a/src/core/access.py b/src/core/access.py index c44d97e6..72f9a73d 100644 --- a/src/core/access.py +++ b/src/core/access.py @@ -4,12 +4,17 @@ from schemas.datasets.openml import Visibility -def _user_has_access( +async def _user_has_access( dataset: Row, user: User | None = None, ) -> bool: """Determine if `user` has the right to view `dataset`.""" is_public = dataset.visibility == Visibility.PUBLIC - return is_public or ( - user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups) - ) + if is_public: + return True + if user is None: + return False + if user.user_id == dataset.uploader: + return True + groups = await user.get_groups() + return UserGroup.ADMIN in groups diff --git a/src/database/datasets.py b/src/database/datasets.py index f011a651..df05a549 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,14 +2,15 @@ import datetime -from sqlalchemy import Connection, text +from sqlalchemy import text from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection from schemas.datasets.openml import Feature -def get(id_: int, connection: Connection) -> Row | None: - row = connection.execute( +async def get(id_: int, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT * @@ -22,8 +23,8 @@ def get(id_: int, connection: Connection) -> Row | None: return row.one_or_none() -def get_file(*, file_id: int, connection: Connection) -> Row | None: - row = connection.execute( +async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT * @@ -36,8 +37,8 @@ def get_file(*, file_id: int, connection: Connection) -> Row | None: return row.one_or_none() -def get_tags_for(id_: int, connection: Connection) -> list[str]: - rows = connection.execute( +async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]: + rows = await connection.execute( text( """ SELECT * @@ -50,8 +51,8 @@ def get_tags_for(id_: int, connection: Connection) -> list[str]: return [row.tag for row in rows] -def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None: - connection.execute( +async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) -> None: + await connection.execute( text( """ INSERT INTO dataset_tag(`id`, `tag`, `uploader`) @@ -66,12 +67,12 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None: ) -def get_description( +async def get_description( id_: int, - connection: Connection, + connection: AsyncConnection, ) -> Row | None: """Get the most recent description for the dataset.""" - row = connection.execute( + row = await connection.execute( text( """ SELECT * @@ -85,9 +86,9 @@ def get_description( return row.first() -def get_status(id_: int, connection: Connection) -> Row | None: +async def get_status(id_: int, connection: AsyncConnection) -> Row | None: """Get most recent status for the dataset.""" - row = connection.execute( + row = await connection.execute( text( """ SELECT * @@ -101,8 +102,8 @@ def get_status(id_: int, connection: Connection) -> Row | None: return row.first() -def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None: - row = connection.execute( +async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT * @@ -116,8 +117,8 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row return row.one_or_none() -def get_features(dataset_id: int, connection: Connection) -> list[Feature]: - rows = connection.execute( +async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]: + rows = await connection.execute( text( """ SELECT `index`,`name`,`data_type`,`is_target`, @@ -131,8 +132,13 @@ def get_features(dataset_id: int, connection: Connection) -> list[Feature]: return [Feature(**row, nominal_values=None) for row in rows.mappings()] -def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]: - rows = connection.execute( +async def get_feature_values( + dataset_id: int, + *, + feature_index: int, + connection: AsyncConnection, +) -> list[str]: + rows = await connection.execute( text( """ SELECT `value` @@ -145,14 +151,14 @@ def get_feature_values(dataset_id: int, *, feature_index: int, connection: Conne return [row.value for row in rows] -def update_status( +async def update_status( dataset_id: int, status: str, *, user_id: int, - connection: Connection, + connection: AsyncConnection, ) -> None: - connection.execute( + await connection.execute( text( """ INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`) @@ -168,8 +174,8 @@ def update_status( ) -def remove_deactivated_status(dataset_id: int, connection: Connection) -> None: - connection.execute( +async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None: + await connection.execute( text( """ DELETE FROM dataset_status diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 799d4112..6f41a6ac 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,30 +1,33 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection from core.formatting import _str_to_bool from schemas.datasets.openml import EstimationProcedure -def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]: +async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]: return cast( "Sequence[Row]", - connection.execute( - text( - """ + ( + await connection.execute( + text( + """ SELECT * FROM math_function WHERE `functionType` = :function_type """, - ), - parameters={"function_type": function_type}, + ), + parameters={"function_type": function_type}, + ) ).all(), ) -def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]: - rows = connection.execute( +async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]: + rows = await connection.execute( text( """ SELECT `id` as 'id_', `ttid` as 'task_type_id', `name`, `type` as 'type_', diff --git a/src/database/flows.py b/src/database/flows.py index 3129e91e..55f05259 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,27 +1,26 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection -def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ - SELECT child as child_id, identifier - FROM implementation_component - WHERE parent = :flow_id - """, - ), - parameters={"flow_id": for_flow}, +async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: + result = await expdb.execute( + text( + """ + SELECT child as child_id, identifier + FROM implementation_component + WHERE parent = :flow_id + """, ), + parameters={"flow_id": for_flow}, ) + return cast("Sequence[Row]", result.all()) -def get_tags(flow_id: int, expdb: Connection) -> list[str]: - tag_rows = expdb.execute( +async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]: + tag_rows = await expdb.execute( text( """ SELECT tag @@ -34,25 +33,23 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]: return [tag.tag for tag in tag_rows] -def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ - SELECT *, defaultValue as default_value, dataType as data_type - FROM input - WHERE implementation_id = :flow_id - """, - ), - parameters={"flow_id": flow_id}, +async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]: + result = await expdb.execute( + text( + """ + SELECT *, defaultValue as default_value, dataType as data_type + FROM input + WHERE implementation_id = :flow_id + """, ), + parameters={"flow_id": flow_id}, ) + return cast("Sequence[Row]", result.all()) -def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None: +async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None: """Gets flow by name and external version.""" - return expdb.execute( + result = await expdb.execute( text( """ SELECT *, uploadDate as upload_date @@ -61,11 +58,12 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No """, ), parameters={"name": name, "external_version": external_version}, - ).one_or_none() + ) + return result.one_or_none() -def get(id_: int, expdb: Connection) -> Row | None: - return expdb.execute( +async def get(id_: int, expdb: AsyncConnection) -> Row | None: + result = await expdb.execute( text( """ SELECT *, uploadDate as upload_date @@ -74,4 +72,5 @@ def get(id_: int, expdb: Connection) -> Row | None: """, ), parameters={"flow_id": id_}, - ).one_or_none() + ) + return result.one_or_none() diff --git a/src/database/qualities.py b/src/database/qualities.py index 81499c1e..34fa2d9c 100644 --- a/src/database/qualities.py +++ b/src/database/qualities.py @@ -1,13 +1,14 @@ from collections import defaultdict from collections.abc import Iterable -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from schemas.datasets.openml import Quality -def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]: - rows = connection.execute( +async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]: + rows = await connection.execute( text( """ SELECT `quality`,`value` @@ -20,10 +21,10 @@ def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]: return [Quality(name=row.quality, value=row.value) for row in rows] -def get_for_datasets( +async def get_for_datasets( dataset_ids: Iterable[int], quality_names: Iterable[str], - connection: Connection, + connection: AsyncConnection, ) -> dict[int, list[Quality]]: """Don't call with user-provided input, as query is not parameterized.""" qualities_filter = ",".join(f"'{q}'" for q in quality_names) @@ -35,7 +36,7 @@ def get_for_datasets( WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter}) """, # noqa: S608 - dids and qualities are not user-provided ) - rows = connection.execute(qualities_query) + rows = await connection.execute(qualities_query) qualities_by_id = defaultdict(list) for did, quality, value in rows: if value is not None: @@ -43,10 +44,10 @@ def get_for_datasets( return dict(qualities_by_id) -def list_all_qualities(connection: Connection) -> list[str]: +async def list_all_qualities(connection: AsyncConnection) -> list[str]: # The current implementation only fetches *used* qualities, otherwise you should # query: SELECT `name` FROM `quality` WHERE `type`='DataQuality' - qualities_ = connection.execute( + qualities_ = await connection.execute( text( """ SELECT DISTINCT(`quality`) diff --git a/src/database/setup.py b/src/database/setup.py index 3a1be2f6..f615881f 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -1,5 +1,5 @@ -from sqlalchemy import Engine, create_engine from sqlalchemy.engine import URL +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from config import load_database_configuration @@ -7,25 +7,25 @@ _expdb_engine = None -def _create_engine(database_name: str) -> Engine: +def _create_engine(database_name: str) -> AsyncEngine: database_configuration = load_database_configuration() echo = database_configuration[database_name].pop("echo", False) db_url = URL.create(**database_configuration[database_name]) - return create_engine( + return create_async_engine( db_url, echo=echo, pool_recycle=3600, ) -def user_database() -> Engine: +def user_database() -> AsyncEngine: global _user_engine # noqa: PLW0603 if _user_engine is None: _user_engine = _create_engine("openml") return _user_engine -def expdb_database() -> Engine: +def expdb_database() -> AsyncEngine: global _expdb_engine # noqa: PLW0603 if _expdb_engine is None: _expdb_engine = _create_engine("expdb") diff --git a/src/database/studies.py b/src/database/studies.py index 35c1b790..cb4b90dd 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -3,39 +3,44 @@ from datetime import datetime from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User from schemas.study import CreateStudy, StudyType -def get_by_id(id_: int, connection: Connection) -> Row | None: - return connection.execute( - text( - """ +async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None: + return ( + await connection.execute( + text( + """ SELECT *, main_entity_type as type_ FROM study WHERE id = :study_id """, - ), - parameters={"study_id": id_}, + ), + parameters={"study_id": id_}, + ) ).one_or_none() -def get_by_alias(alias: str, connection: Connection) -> Row | None: - return connection.execute( - text( - """ +async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None: + return ( + await connection.execute( + text( + """ SELECT *, main_entity_type as type_ FROM study WHERE alias = :study_id """, - ), - parameters={"study_id": alias}, + ), + parameters={"study_id": alias}, + ) ).one_or_none() -def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: +async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: """Return data related to the study, content depends on the study type. For task studies: (task id, dataset id) @@ -44,22 +49,25 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: if study.type_ == StudyType.TASK: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + ( + await expdb.execute( + text( + """ SELECT ts.task_id as task_id, ti.value as data_id FROM task_study as ts LEFT JOIN task_inputs ti ON ts.task_id = ti.task_id WHERE ts.study_id = :study_id AND ti.input = 'source_data' """, - ), - parameters={"study_id": study.id}, + ), + parameters={"study_id": study.id}, + ) ).all(), ) return cast( "Sequence[Row]", - expdb.execute( - text( - """ + ( + await expdb.execute( + text( + """ SELECT rs.run_id as run_id, run.task_id as task_id, @@ -72,14 +80,15 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: JOIN task_inputs as ti ON ti.task_id = run.task_id WHERE rs.study_id = :study_id AND ti.input = 'source_data' """, - ), - parameters={"study_id": study.id}, + ), + parameters={"study_id": study.id}, + ) ).all(), ) -def create(study: CreateStudy, user: User, expdb: Connection) -> int: - expdb.execute( +async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int: + await expdb.execute( text( """ INSERT INTO study ( @@ -102,12 +111,12 @@ def create(study: CreateStudy, user: User, expdb: Connection) -> int: "benchmark_suite": study.benchmark_suite, }, ) - (study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one() + (study_id,) = (await expdb.execute(text("""SELECT LAST_INSERT_ID();"""))).one() return cast("int", study_id) -def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None: - expdb.execute( +async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None: + await expdb.execute( text( """ INSERT INTO task_study (study_id, task_id, uploader) @@ -118,8 +127,8 @@ def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> N ) -def attach_run(*, run_id: int, study_id: int, user: User, expdb: Connection) -> None: - expdb.execute( +async def attach_run(*, run_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None: + await expdb.execute( text( """ INSERT INTO run_study (study_id, run_id, uploader) @@ -130,16 +139,16 @@ def attach_run(*, run_id: int, study_id: int, user: User, expdb: Connection) -> ) -def attach_tasks( +async def attach_tasks( *, study_id: int, task_ids: list[int], user: User, - connection: Connection, + connection: AsyncConnection, ) -> None: to_link = [(study_id, task_id, user.user_id) for task_id in task_ids] try: - connection.execute( + await connection.execute( text( """ INSERT INTO task_study (study_id, task_id, uploader) @@ -162,10 +171,10 @@ def attach_tasks( raise ValueError(msg) from e -def attach_runs( +async def attach_runs( study_id: int, run_ids: list[int], user: User, - connection: Connection, + connection: AsyncConnection, ) -> None: raise NotImplementedError diff --git a/src/database/tasks.py b/src/database/tasks.py index 97caef3b..a6ea83ec 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,99 +1,115 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection -def get(id_: int, expdb: Connection) -> Row | None: - return expdb.execute( - text( - """ +async def get(id_: int, expdb: AsyncConnection) -> Row | None: + return ( + await expdb.execute( + text( + """ SELECT * FROM task WHERE `task_id` = :task_id """, - ), - parameters={"task_id": id_}, + ), + parameters={"task_id": id_}, + ) ).one_or_none() -def get_task_types(expdb: Connection) -> Sequence[Row]: +async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + ( + await expdb.execute( + text( + """ SELECT `ttid`, `name`, `description`, `creator` FROM task_type """, - ), + ), + ) ).all(), ) -def get_task_type(task_type_id: int, expdb: Connection) -> Row | None: - return expdb.execute( - text( - """ +async def get_task_type(task_type_id: int, expdb: AsyncConnection) -> Row | None: + return ( + await expdb.execute( + text( + """ SELECT * FROM task_type WHERE `ttid`=:ttid """, - ), - parameters={"ttid": task_type_id}, + ), + parameters={"ttid": task_type_id}, + ) ).one_or_none() -def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]: +async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + ( + await expdb.execute( + text( + """ SELECT * FROM task_type_inout WHERE `ttid`=:ttid AND `io`='input' """, - ), - parameters={"ttid": task_type_id}, + ), + parameters={"ttid": task_type_id}, + ) ).all(), ) -def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]: +async def get_input_for_task(id_: int, expdb: AsyncConnection) -> Sequence[Row]: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + ( + await expdb.execute( + text( + """ SELECT `input`, `value` FROM task_inputs WHERE task_id = :task_id """, - ), - parameters={"task_id": id_}, + ), + parameters={"task_id": id_}, + ) ).all(), ) -def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]: +async def get_task_type_inout_with_template( + task_type: int, + expdb: AsyncConnection, +) -> Sequence[Row]: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + ( + await expdb.execute( + text( + """ SELECT * FROM task_type_inout WHERE `ttid`=:ttid AND `template_api` IS NOT NULL """, - ), - parameters={"ttid": task_type}, + ), + parameters={"ttid": task_type}, + ) ).all(), ) -def get_tags(id_: int, expdb: Connection) -> list[str]: - tag_rows = expdb.execute( +async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: + tag_rows = await expdb.execute( text( """ SELECT `tag` diff --git a/src/database/users.py b/src/database/users.py index b439be7e..15167417 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -3,7 +3,8 @@ from typing import Annotated, Self from pydantic import StringConstraints -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from config import load_configuration @@ -26,22 +27,24 @@ class UserGroup(IntEnum): READ_ONLY = (3,) -def get_user_id_for(*, api_key: APIKey, connection: Connection) -> int | None: - user = connection.execute( - text( - """ +async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> int | None: + user = ( + await connection.execute( + text( + """ SELECT * FROM users WHERE session_hash = :api_key """, - ), - parameters={"api_key": api_key}, + ), + parameters={"api_key": api_key}, + ) ).one_or_none() return user.id if user else None -def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGroup]: - row = connection.execute( +async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[UserGroup]: + row = await connection.execute( text( """ SELECT group_id @@ -57,18 +60,17 @@ def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGro @dataclasses.dataclass class User: user_id: int - _database: Connection + _database: AsyncConnection _groups: list[UserGroup] | None = None @classmethod - def fetch(cls, api_key: APIKey, user_db: Connection) -> Self | None: - if user_id := get_user_id_for(api_key=api_key, connection=user_db): + async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: + if user_id := await get_user_id_for(api_key=api_key, connection=user_db): return cls(user_id, _database=user_db) return None - @property - def groups(self) -> list[UserGroup]: + async def get_groups(self) -> list[UserGroup]: if self._groups is None: - groups = get_user_groups_for(user_id=self.user_id, connection=self._database) + groups = await get_user_groups_for(user_id=self.user_id, connection=self._database) self._groups = [UserGroup(group_id) for group_id in groups] return self._groups diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index 2ddccf83..69ddd786 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -2,31 +2,31 @@ from fastapi import Depends from pydantic import BaseModel -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from database.setup import expdb_database, user_database from database.users import APIKey, User -def expdb_connection() -> Connection: +async def expdb_connection() -> AsyncConnection: engine = expdb_database() - with engine.connect() as connection: + async with engine.connect() as connection: yield connection - connection.commit() + await connection.commit() -def userdb_connection() -> Connection: +async def userdb_connection() -> AsyncConnection: engine = user_database() - with engine.connect() as connection: + async with engine.connect() as connection: yield connection - connection.commit() + await connection.commit() -def fetch_user( +async def fetch_user( api_key: APIKey | None = None, - user_data: Annotated[Connection, Depends(userdb_connection)] = None, + user_data: Annotated[AsyncConnection, Depends(userdb_connection)] = None, ) -> User | None: - return User.fetch(api_key, user_data) if api_key else None + return await User.fetch(api_key, user_data) if api_key else None class Pagination(BaseModel): diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py index db34e5ce..73902524 100644 --- a/src/routers/mldcat_ap/dataset.py +++ b/src/routers/mldcat_ap/dataset.py @@ -7,7 +7,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import config from database.users import User @@ -37,19 +37,19 @@ path="/distribution/{distribution_id}", description="Get meta-data for distribution with ID `distribution_id`.", ) -def get_mldcat_ap_distribution( +async def get_mldcat_ap_distribution( distribution_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[Connection, Depends(userdb_connection)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> JsonLDGraph: - oml_dataset = get_dataset( + oml_dataset = await get_dataset( dataset_id=distribution_id, user=user, user_db=user_db, expdb_db=expdb, ) - openml_features = get_dataset_features(distribution_id, user, expdb) + openml_features = await get_dataset_features(distribution_id, user, expdb) features = [ Feature( id_=f"{_server_url}/feature/{distribution_id}/{feature.index}", @@ -58,7 +58,7 @@ def get_mldcat_ap_distribution( ) for feature in openml_features ] - oml_qualities = get_qualities(distribution_id, user, expdb) + oml_qualities = await get_qualities(distribution_id, user, expdb) qualities = [ Quality( id_=f"{_server_url}/quality/{quality.name}/{distribution_id}", @@ -119,7 +119,7 @@ def get_mldcat_ap_distribution( path="/dataservice/{service_id}", description="Get meta-data for a specific data service.", ) -def get_dataservice(service_id: int) -> JsonLDGraph: +async def get_dataservice(service_id: int) -> JsonLDGraph: if service_id != 1: raise HTTPException(status_code=404, detail="Service not found.") return JsonLDGraph( @@ -138,13 +138,13 @@ def get_dataservice(service_id: int) -> JsonLDGraph: path="/quality/{quality_name}/{distribution_id}", description="Get meta-data for a specific quality and distribution.", ) -def get_distribution_quality( +async def get_distribution_quality( quality_name: str, distribution_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> JsonLDGraph: - qualities = get_qualities(distribution_id, user, expdb) + qualities = await get_qualities(distribution_id, user, expdb) quality = next(q for q in qualities if q.name == quality_name) example_quality = Quality( id_=f"{_server_url}/quality/{quality_name}/{distribution_id}", @@ -164,13 +164,13 @@ def get_distribution_quality( path="/feature/{distribution_id}/{feature_no}", description="Get meta-data for the n-th feature of a distribution.", ) -def get_distribution_feature( +async def get_distribution_feature( distribution_id: int, feature_no: int, - user: Annotated[Connection, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + user: Annotated[User | None, Depends(fetch_user)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> JsonLDGraph: - features = get_dataset_features( + features = await get_dataset_features( dataset_id=distribution_id, user=user, expdb=expdb, diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index dda25117..f1956610 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -5,8 +5,9 @@ from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends, HTTPException -from sqlalchemy import Connection, text +from sqlalchemy import text from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -29,20 +30,20 @@ @router.post( path="/tag", ) -def tag_dataset( +async def tag_dataset( data_id: Annotated[int, Body()], tag: Annotated[str, SystemString64], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> dict[str, dict[str, Any]]: - tags = database.datasets.get_tags_for(data_id, expdb_db) + tags = await database.datasets.get_tags_for(data_id, expdb_db) if tag.casefold() in [t.casefold() for t in tags]: raise create_tag_exists_error(data_id, tag) if user is None: raise create_authentication_failed_error() - database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) + await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) return { "data_tag": {"id": str(data_id), "tag": [*tags, tag]}, } @@ -75,7 +76,7 @@ class DatasetStatusFilter(StrEnum): @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") -def list_datasets( # noqa: PLR0913 +async def list_datasets( # noqa: PLR0913 pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[str | None, CasualString128] = None, tag: Annotated[str | None, SystemString64] = None, @@ -100,7 +101,7 @@ def list_datasets( # noqa: PLR0913 number_missing_values: Annotated[str | None, IntegerRange] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: current_status = text( """ @@ -126,10 +127,12 @@ def list_datasets( # noqa: PLR0913 where_status = ",".join(f"'{status}'" for status in statuses) if user is None: visible_to_user = "`visibility`='public'" - elif UserGroup.ADMIN in user.groups: - visible_to_user = "TRUE" else: - visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" + user_groups = await user.get_groups() + if UserGroup.ADMIN in user_groups: + visible_to_user = "TRUE" + else: + visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" where_name = "" if data_name is None else "AND `name`=:data_name" where_version = "" if data_version is None else "AND `version`=:data_version" @@ -190,7 +193,7 @@ def quality_clause(quality: str, range_: str | None) -> str: # subquery also has no user input. So I think this should be safe. ) columns = ["did", "name", "version", "format", "file_id", "status"] - rows = expdb_db.execute( + rows = await expdb_db.execute( matching_filter, parameters={ "tag": tag, @@ -230,7 +233,7 @@ def quality_clause(quality: str, range_: str | None) -> str: "NumberOfNumericFeatures", "NumberOfSymbolicFeatures", ] - qualities_by_dataset = database.qualities.get_for_datasets( + qualities_by_dataset = await database.qualities.get_for_datasets( dataset_ids=datasets.keys(), quality_names=qualities_to_show, connection=expdb_db, @@ -246,10 +249,16 @@ class ProcessingInformation(NamedTuple): error: str | None -def _get_processing_information(dataset_id: int, connection: Connection) -> ProcessingInformation: +async def _get_processing_information( + dataset_id: int, + connection: AsyncConnection, +) -> ProcessingInformation: """Return processing information, if any. Otherwise, all fields `None`.""" if not ( - data_processed := database.datasets.get_latest_processing_update(dataset_id, connection) + data_processed := await database.datasets.get_latest_processing_update( + dataset_id, + connection, + ) ): return ProcessingInformation(date=None, warning=None, error=None) @@ -259,20 +268,20 @@ def _get_processing_information(dataset_id: int, connection: Connection) -> Proc return ProcessingInformation(date=date_processed, warning=warning, error=error) -def _get_dataset_raise_otherwise( +async def _get_dataset_raise_otherwise( dataset_id: int, user: User | None, - expdb: Connection, + expdb: AsyncConnection, ) -> Row: """Fetches the dataset from the database if it exists and the user has permissions. Raises HTTPException if the dataset does not exist or the user can not access it. """ - if not (dataset := database.datasets.get(dataset_id, expdb)): + if not (dataset := await database.datasets.get(dataset_id, expdb)): error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset") raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=error) - if not _user_has_access(dataset=dataset, user=user): + if not await _user_has_access(dataset=dataset, user=user): error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted") raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=error) @@ -280,22 +289,22 @@ def _get_dataset_raise_otherwise( @router.get("/features/{dataset_id}", response_model_exclude_none=True) -def get_dataset_features( +async def get_dataset_features( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[Feature]: - _get_dataset_raise_otherwise(dataset_id, user, expdb) - features = database.datasets.get_features(dataset_id, expdb) + await _get_dataset_raise_otherwise(dataset_id, user, expdb) + features = await database.datasets.get_features(dataset_id, expdb) for feature in [f for f in features if f.data_type == FeatureType.NOMINAL]: - feature.nominal_values = database.datasets.get_feature_values( + feature.nominal_values = await database.datasets.get_feature_values( dataset_id, feature_index=feature.index, connection=expdb, ) if not features: - processing_state = database.datasets.get_latest_processing_update(dataset_id, expdb) + processing_state = await database.datasets.get_latest_processing_update(dataset_id, expdb) if processing_state is None: code, msg = ( 273, @@ -318,11 +327,11 @@ def get_dataset_features( @router.post( path="/status/update", ) -def update_dataset_status( +async def update_dataset_status( dataset_id: Annotated[int, Body()], status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()], user: Annotated[User | None, Depends(fetch_user)], - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, str | int]: if user is None: raise HTTPException( @@ -330,21 +339,22 @@ def update_dataset_status( detail="Updating dataset status required authorization", ) - dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb) + dataset = await _get_dataset_raise_otherwise(dataset_id, user, expdb) - can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user.groups + user_groups = await user.get_groups() + can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user_groups if status == DatasetStatus.DEACTIVATED and not can_deactivate: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail={"code": 693, "message": "Dataset is not owned by you"}, ) - if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups: + if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user_groups: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail={"code": 696, "message": "Only administrators can activate datasets."}, ) - current_status = database.datasets.get_status(dataset_id, expdb) + current_status = await database.datasets.get_status(dataset_id, expdb) if current_status and current_status.status == status: raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, @@ -358,9 +368,14 @@ def update_dataset_status( # - active => deactivated (add a row) # - deactivated => active (delete a row) if current_status is None or status == DatasetStatus.DEACTIVATED: - database.datasets.update_status(dataset_id, status, user_id=user.user_id, connection=expdb) + await database.datasets.update_status( + dataset_id, + status, + user_id=user.user_id, + connection=expdb, + ) elif current_status.status == DatasetStatus.DEACTIVATED: - database.datasets.remove_deactivated_status(dataset_id, expdb) + await database.datasets.remove_deactivated_status(dataset_id, expdb) else: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, @@ -374,15 +389,18 @@ def update_dataset_status( path="/{dataset_id}", description="Get meta-data for dataset with ID `dataset_id`.", ) -def get_dataset( +async def get_dataset( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[Connection, Depends(userdb_connection)] = None, - expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> DatasetMetadata: - dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db) + dataset = await _get_dataset_raise_otherwise(dataset_id, user, expdb_db) if not ( - dataset_file := database.datasets.get_file(file_id=dataset.file_id, connection=user_db) + dataset_file := await database.datasets.get_file( + file_id=dataset.file_id, + connection=user_db, + ) ): error = _format_error( code=DatasetError.NO_DATA_FILE, @@ -390,10 +408,10 @@ def get_dataset( ) raise HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=error) - tags = database.datasets.get_tags_for(dataset_id, expdb_db) - description = database.datasets.get_description(dataset_id, expdb_db) - processing_result = _get_processing_information(dataset_id, expdb_db) - status = database.datasets.get_status(dataset_id, expdb_db) + tags = await database.datasets.get_tags_for(dataset_id, expdb_db) + description = await database.datasets.get_description(dataset_id, expdb_db) + processing_result = await _get_processing_information(dataset_id, expdb_db) + status = await database.datasets.get_status(dataset_id, expdb_db) status_ = DatasetStatus(status.status) if status else DatasetStatus.IN_PREPARATION diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py index 1ebaf929..5df5c798 100644 --- a/src/routers/openml/estimation_procedure.py +++ b/src/routers/openml/estimation_procedure.py @@ -2,7 +2,7 @@ from typing import Annotated from fastapi import APIRouter, Depends -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection @@ -12,7 +12,7 @@ @router.get("/list", response_model_exclude_none=True) -def get_estimation_procedures( - expdb: Annotated[Connection, Depends(expdb_connection)], +async def get_estimation_procedures( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> Iterable[EstimationProcedure]: - return database.evaluations.get_estimation_procedures(expdb) + return await database.evaluations.get_estimation_procedures(expdb) diff --git a/src/routers/openml/evaluations.py b/src/routers/openml/evaluations.py index 49178b7a..f6650b36 100644 --- a/src/routers/openml/evaluations.py +++ b/src/routers/openml/evaluations.py @@ -1,7 +1,7 @@ from typing import Annotated from fastapi import APIRouter, Depends -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection @@ -10,8 +10,10 @@ @router.get("/list") -def get_evaluation_measures(expdb: Annotated[Connection, Depends(expdb_connection)]) -> list[str]: - functions = database.evaluations.get_math_functions( +async def get_evaluation_measures( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> list[str]: + functions = await database.evaluations.get_math_functions( function_type="EvaluationFunction", connection=expdb, ) diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index cb6df5d9..3c6c6d0d 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -1,8 +1,9 @@ +import asyncio from http import HTTPStatus from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from core.conversions import _str_to_num @@ -13,13 +14,17 @@ @router.get("/exists/{name}/{external_version}") -def flow_exists( +async def flow_exists( name: str, external_version: str, - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["flow_id"], int]: """Check if a Flow with the name and version exists, if so, return the flow id.""" - flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb) + flow = await database.flows.get_by_name( + name=name, + external_version=external_version, + expdb=expdb, + ) if flow is None: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, @@ -28,13 +33,20 @@ def flow_exists( return {"flow_id": flow.id} +async def _make_subflow(identifier: str, child_id: int, expdb: AsyncConnection) -> Subflow: + return Subflow(identifier=identifier, flow=await get_flow(flow_id=child_id, expdb=expdb)) + + @router.get("/{flow_id}") -def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow: - flow = database.flows.get(flow_id, expdb) +async def get_flow( + flow_id: int, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, +) -> Flow: + flow = await database.flows.get(flow_id, expdb) if not flow: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Flow not found") - parameter_rows = database.flows.get_parameters(flow_id, expdb) + parameter_rows = await database.flows.get_parameters(flow_id, expdb) parameters = [ Parameter( name=parameter.name, @@ -48,15 +60,16 @@ def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection for parameter in parameter_rows ] - tags = database.flows.get_tags(flow_id, expdb) - subflow_rows = database.flows.get_subflows(flow_id, expdb) - subflows = [ - Subflow( - identifier=subflow.identifier, - flow=get_flow(flow_id=subflow.child_id, expdb=expdb), - ) - for subflow in subflow_rows - ] + tags = await database.flows.get_tags(flow_id, expdb) + subflow_rows = await database.flows.get_subflows(flow_id, expdb) + subflows = list( + await asyncio.gather( + *[ + _make_subflow(subflow.identifier, subflow.child_id, expdb) + for subflow in subflow_rows + ], + ), + ) return Flow( id_=flow.id, diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py index 54181f8f..b71ca52f 100644 --- a/src/routers/openml/qualities.py +++ b/src/routers/openml/qualities.py @@ -2,7 +2,7 @@ from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -16,10 +16,10 @@ @router.get("/qualities/list") -def list_qualities( - expdb: Annotated[Connection, Depends(expdb_connection)], +async def list_qualities( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["data_qualities_list"], dict[Literal["quality"], list[str]]]: - qualities = database.qualities.list_all_qualities(connection=expdb) + qualities = await database.qualities.list_all_qualities(connection=expdb) return { "data_qualities_list": { "quality": qualities, @@ -28,18 +28,18 @@ def list_qualities( @router.get("/qualities/{dataset_id}") -def get_qualities( +async def get_qualities( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)], - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> list[Quality]: - dataset = database.datasets.get(dataset_id, expdb) - if not dataset or not _user_has_access(dataset, user): + dataset = await database.datasets.get(dataset_id, expdb) + if not dataset or not await _user_has_access(dataset, user): raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"}, ) from None - return database.qualities.get_for_dataset(dataset_id, expdb) + return await database.qualities.get_for_dataset(dataset_id, expdb) # The PHP API provided (sometime) helpful error messages # if not qualities: # check if dataset exists: error 360 diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 6fe1dcc6..0d724ce3 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -3,7 +3,8 @@ from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel -from sqlalchemy import Connection, Row +from sqlalchemy import Row +from sqlalchemy.ext.asyncio import AsyncConnection import database.studies from core.formatting import _str_to_bool @@ -15,11 +16,15 @@ router = APIRouter(prefix="/studies", tags=["studies"]) -def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: Connection) -> Row: +async def _get_study_raise_otherwise( + id_or_alias: int | str, + user: User | None, + expdb: AsyncConnection, +) -> Row: if isinstance(id_or_alias, int) or id_or_alias.isdigit(): - study = database.studies.get_by_id(int(id_or_alias), expdb) + study = await database.studies.get_by_id(int(id_or_alias), expdb) else: - study = database.studies.get_by_alias(id_or_alias, expdb) + study = await database.studies.get_by_alias(id_or_alias, expdb) if study is None: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Study not found.") @@ -29,7 +34,8 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: status_code=HTTPStatus.UNAUTHORIZED, detail="Must authenticate for private study.", ) - if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: + user_groups = await user.get_groups() + if study.creator != user.user_id and UserGroup.ADMIN not in user_groups: raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Study is private.") if _str_to_bool(study.legacy): raise HTTPException( @@ -45,17 +51,18 @@ class AttachDetachResponse(BaseModel): @router.post("/attach") -def attach_to_study( +async def attach_to_study( study_id: Annotated[int, Body()], entity_ids: Annotated[list[int], Body()], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: if user is None: raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User not found.") - study = _get_study_raise_otherwise(study_id, user, expdb) + study = await _get_study_raise_otherwise(study_id, user, expdb) # PHP lets *anyone* edit *any* study. We're not going to do that. - if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: + user_groups = await user.get_groups() + if study.creator != user.user_id and UserGroup.ADMIN not in user_groups: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="Study can only be edited by its creator.", @@ -75,9 +82,9 @@ def attach_to_study( } try: if study.type_ == StudyType.TASK: - database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs) + await database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs) else: - database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) + await database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) except ValueError as e: raise HTTPException( status_code=HTTPStatus.CONFLICT, @@ -87,10 +94,10 @@ def attach_to_study( @router.post("/") -def create_study( +async def create_study( study: CreateStudy, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> dict[Literal["study_id"], int]: if user is None: raise HTTPException( @@ -107,30 +114,35 @@ def create_study( status_code=HTTPStatus.BAD_REQUEST, detail="Cannot create a task study with runs.", ) - if study.alias and database.studies.get_by_alias(study.alias, expdb): + if study.alias and await database.studies.get_by_alias(study.alias, expdb): raise HTTPException( status_code=HTTPStatus.CONFLICT, detail="Study alias already exists.", ) - study_id = database.studies.create(study, user, expdb) + study_id = await database.studies.create(study, user, expdb) if study.main_entity_type == StudyType.TASK: for task_id in study.tasks: - database.studies.attach_task(task_id, study_id, user, expdb) + await database.studies.attach_task(task_id, study_id, user, expdb) if study.main_entity_type == StudyType.RUN: for run_id in study.runs: - database.studies.attach_run(run_id=run_id, study_id=study_id, user=user, expdb=expdb) + await database.studies.attach_run( + run_id=run_id, + study_id=study_id, + user=user, + expdb=expdb, + ) # Make sure that invalid fields raise an error (e.g., "task_ids") return {"study_id": study_id} @router.get("/{alias_or_id}") -def get_study( +async def get_study( alias_or_id: int | str, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Study: - study = _get_study_raise_otherwise(alias_or_id, user, expdb) - study_data = database.studies.get_study_data(study, expdb) + study = await _get_study_raise_otherwise(alias_or_id, user, expdb) + study_data = await database.studies.get_study_data(study, expdb) return Study( _legacy=_str_to_bool(study.legacy), id_=study.id, diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 8397f1da..92798348 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -5,7 +5,8 @@ import xmltodict from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection, RowMapping, text +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncConnection import config import database.datasets @@ -27,11 +28,11 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: return cast("dict[str, JSON]", json.loads(json_str)) -def fill_template( +async def fill_template( template: str, task: RowMapping, task_inputs: dict[str, str | int], - connection: Connection, + connection: AsyncConnection, ) -> dict[str, JSON]: """Fill in the XML template as used for task descriptions and return the result, converted to JSON. @@ -83,7 +84,7 @@ def fill_template( json_template = convert_template_xml_to_json(template) return cast( "dict[str, JSON]", - _fill_json_template( + await _fill_json_template( json_template, task, task_inputs, @@ -93,21 +94,22 @@ def fill_template( ) -def _fill_json_template( +async def _fill_json_template( template: JSON, task: RowMapping, task_inputs: dict[str, str | int], fetched_data: dict[str, str], - connection: Connection, + connection: AsyncConnection, ) -> JSON: if isinstance(template, dict): return { - k: _fill_json_template(v, task, task_inputs, fetched_data, connection) + k: await _fill_json_template(v, task, task_inputs, fetched_data, connection) for k, v in template.items() } if isinstance(template, list): return [ - _fill_json_template(v, task, task_inputs, fetched_data, connection) for v in template + await _fill_json_template(v, task, task_inputs, fetched_data, connection) + for v in template ] if not isinstance(template, str): msg = f"Unexpected type for `template`: {template=}, {type(template)=}" @@ -125,7 +127,7 @@ def _fill_json_template( (field,) = match.groups() if field not in fetched_data: table, _ = field.split(".") - rows = connection.execute( + rows = await connection.execute( text( f""" SELECT * @@ -150,13 +152,13 @@ def _fill_json_template( @router.get("/{task_id}") -def get_task( +async def get_task( task_id: int, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Task: - if not (task := database.tasks.get(task_id, expdb)): + if not (task := await database.tasks.get(task_id, expdb)): raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Task not found") - if not (task_type := database.tasks.get_task_type(task.ttid, expdb)): + if not (task_type := await database.tasks.get_task_type(task.ttid, expdb)): raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Task type not found", @@ -164,12 +166,12 @@ def get_task( task_inputs = { row.input: int(row.value) if row.value.isdigit() else row.value - for row in database.tasks.get_input_for_task(task_id, expdb) + for row in await database.tasks.get_input_for_task(task_id, expdb) } - ttios = database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb) + ttios = await database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb) templates = [(tt_io.name, tt_io.io, tt_io.requirement, tt_io.template_api) for tt_io in ttios] inputs = [ - fill_template(template, task, task_inputs, expdb) | {"name": name} + await fill_template(template, task, task_inputs, expdb) | {"name": name} for name, io, required, template in templates if io == "input" ] @@ -178,10 +180,10 @@ def get_task( for name, io, required, template in templates if io == "output" ] - tags = database.tasks.get_tags(task_id, expdb) + tags = await database.tasks.get_tags(task_id, expdb) name = f"Task {task_id} ({task_type.name})" dataset_id = task_inputs.get("source_data") - if isinstance(dataset_id, int) and (dataset := database.datasets.get(dataset_id, expdb)): + if isinstance(dataset_id, int) and (dataset := await database.datasets.get(dataset_id, expdb)): name = f"Task {task_id}: {dataset.name} ({task_type.name})" return Task( diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5213f177..3e810f5d 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -3,7 +3,8 @@ from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection, Row +from sqlalchemy import Row +from sqlalchemy.ext.asyncio import AsyncConnection from database.tasks import get_input_for_task_type, get_task_types from database.tasks import get_task_type as db_get_task_type @@ -26,24 +27,24 @@ def _normalize_task_type(task_type: Row) -> dict[str, str | None | list[Any]]: @router.get(path="/list") -def list_task_types( - expdb: Annotated[Connection, Depends(expdb_connection)] = None, +async def list_task_types( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> dict[ Literal["task_types"], dict[Literal["task_type"], list[dict[str, str | None | list[Any]]]], ]: task_types: list[dict[str, str | None | list[Any]]] = [ - _normalize_task_type(ttype) for ttype in get_task_types(expdb) + _normalize_task_type(ttype) for ttype in await get_task_types(expdb) ] return {"task_types": {"task_type": task_types}} @router.get(path="/{task_type_id}") -def get_task_type( +async def get_task_type( task_type_id: int, - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["task_type"], dict[str, str | None | list[str] | list[dict[str, str]]]]: - task_type_record = db_get_task_type(task_type_id, expdb) + task_type_record = await db_get_task_type(task_type_id, expdb) if task_type_record is None: raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, @@ -60,7 +61,7 @@ def get_task_type( creator.strip(' "') for creator in cast("str", contributors).split(",") ] task_type["creation_date"] = task_type.pop("creationDate") - task_type_inputs = get_input_for_task_type(task_type_id, expdb) + task_type_inputs = await get_input_for_task_type(task_type_id, expdb) input_types = [] for task_type_input in task_type_inputs: input_ = {} diff --git a/tests/conftest.py b/tests/conftest.py index eecc1288..53f369b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,18 @@ import contextlib import json -from collections.abc import Iterator +from collections.abc import AsyncIterator, Iterator from pathlib import Path from typing import Any, NamedTuple import _pytest.mark import httpx import pytest +import pytest_asyncio from _pytest.config import Config from _pytest.nodes import Item from fastapi.testclient import TestClient -from sqlalchemy import Connection, Engine, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from database.setup import expdb_database, user_database from main import create_api @@ -19,24 +21,24 @@ PHP_API_URL = "http://openml-php-rest-api:80/api/v1/json" -@contextlib.contextmanager -def automatic_rollback(engine: Engine) -> Iterator[Connection]: - with engine.connect() as connection: - transaction = connection.begin() +@contextlib.asynccontextmanager +async def automatic_rollback(engine: AsyncEngine) -> AsyncIterator[AsyncConnection]: + async with engine.connect() as connection: + transaction = await connection.begin() yield connection if transaction.is_active: - transaction.rollback() + await transaction.rollback() -@pytest.fixture -def expdb_test() -> Connection: - with automatic_rollback(expdb_database()) as connection: +@pytest_asyncio.fixture # type: ignore[untyped-decorator] +async def expdb_test() -> AsyncIterator[AsyncConnection]: + async with automatic_rollback(expdb_database()) as connection: yield connection -@pytest.fixture -def user_test() -> Connection: - with automatic_rollback(user_database()) as connection: +@pytest_asyncio.fixture # type: ignore[untyped-decorator] +async def user_test() -> AsyncIterator[AsyncConnection]: + async with automatic_rollback(user_database()) as connection: yield connection @@ -47,7 +49,7 @@ def php_api() -> httpx.Client: @pytest.fixture -def py_api(expdb_test: Connection, user_test: Connection) -> TestClient: +def py_api(expdb_test: AsyncConnection, user_test: AsyncConnection) -> TestClient: app = create_api() # We use the lambda definitions because fixtures may not be called directly. app.dependency_overrides[expdb_connection] = lambda: expdb_test @@ -75,9 +77,9 @@ class Flow(NamedTuple): external_version: str -@pytest.fixture -def flow(expdb_test: Connection) -> Flow: - expdb_test.execute( +@pytest_asyncio.fixture # type: ignore[untyped-decorator] +async def flow(expdb_test: AsyncConnection) -> Flow: + await expdb_test.execute( text( """ INSERT INTO implementation(fullname,name,version,external_version,uploadDate) @@ -85,19 +87,19 @@ def flow(expdb_test: Connection) -> Flow: """, ), ) - (flow_id,) = expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")).one() + (flow_id,) = (await expdb_test.execute(text("""SELECT LAST_INSERT_ID();"""))).one() return Flow(id=flow_id, name="name", external_version="external_version") -@pytest.fixture -def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]: - expdb_test.commit() +@pytest_asyncio.fixture # type: ignore[untyped-decorator] +async def persisted_flow(flow: Flow, expdb_test: AsyncConnection) -> AsyncIterator[Flow]: + await expdb_test.commit() yield flow # We want to ensure the commit below does not accidentally persist new # data to the database. - expdb_test.rollback() + await expdb_test.rollback() - expdb_test.execute( + await expdb_test.execute( text( """ DELETE FROM implementation @@ -106,7 +108,7 @@ def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]: ), parameters={"flow_id": flow.id}, ) - expdb_test.commit() + await expdb_test.commit() def pytest_collection_modifyitems(config: Config, items: list[Item]) -> None: # noqa: ARG001 diff --git a/tests/database/flows_test.py b/tests/database/flows_test.py index 7c952eaa..c3c6fa28 100644 --- a/tests/database/flows_test.py +++ b/tests/database/flows_test.py @@ -1,18 +1,21 @@ -from sqlalchemy import Connection +import pytest +from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from tests.conftest import Flow -def test_database_flow_exists(flow: Flow, expdb_test: Connection) -> None: - retrieved_flow = database.flows.get_by_name(flow.name, flow.external_version, expdb_test) +@pytest.mark.asyncio +async def test_database_flow_exists(flow: Flow, expdb_test: AsyncConnection) -> None: + retrieved_flow = await database.flows.get_by_name(flow.name, flow.external_version, expdb_test) assert retrieved_flow is not None assert retrieved_flow.id == flow.id # when using actual ORM, can instead ensure _all_ fields match. -def test_database_flow_exists_returns_none_if_no_match(expdb_test: Connection) -> None: - retrieved_flow = database.flows.get_by_name( +@pytest.mark.asyncio +async def test_database_flow_exists_returns_none_if_no_match(expdb_test: AsyncConnection) -> None: + retrieved_flow = await database.flows.get_by_name( name="foo", external_version="bar", expdb=expdb_test, diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 5449862a..3763c4dd 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,7 +1,7 @@ from http import HTTPStatus import pytest -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from database.datasets import get_tags_for @@ -29,7 +29,8 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> None: +@pytest.mark.asyncio +async def test_dataset_tag(key: ApiKey, expdb_test: AsyncConnection, py_api: TestClient) -> None: dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test" response = py_api.post( f"/datasets/tag?api_key={key}", @@ -38,7 +39,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> assert response.status_code == HTTPStatus.OK assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": [tag]}} - tags = get_tags_for(id_=dataset_id, connection=expdb_test) + tags = await get_tags_for(id_=dataset_id, connection=expdb_test) assert tag in tags diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index 4ba5ad83..7ee4a666 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -2,7 +2,7 @@ import pytest from fastapi import HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from database.users import User @@ -76,12 +76,13 @@ def test_get_dataset(py_api: TestClient) -> None: SOME_USER, ], ) -def test_private_dataset_no_access( +@pytest.mark.asyncio +async def test_private_dataset_no_access( user: User | None, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: with pytest.raises(HTTPException) as e: - get_dataset( + await get_dataset( dataset_id=130, user=user, user_db=None, @@ -94,8 +95,11 @@ def test_private_dataset_no_access( @pytest.mark.parametrize( "user", [DATASET_130_OWNER, ADMIN_USER, pytest.param(SOME_USER, marks=pytest.mark.xfail)] ) -def test_private_dataset_access(user: User, expdb_test: Connection, user_test: Connection) -> None: - dataset = get_dataset( +@pytest.mark.asyncio +async def test_private_dataset_access( + user: User, expdb_test: AsyncConnection, user_test: AsyncConnection +) -> None: + dataset = await get_dataset( dataset_id=130, user=user, user_db=user_test, diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index d5188d0e..354b968f 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -1,10 +1,11 @@ from http import HTTPStatus +from unittest.mock import AsyncMock import deepdiff.diff import pytest from fastapi import HTTPException from pytest_mock import MockerFixture -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from routers.openml.flows import flow_exists @@ -18,14 +19,15 @@ ("c", "d"), ], ) -def test_flow_exists_calls_db_correctly( +@pytest.mark.asyncio +async def test_flow_exists_calls_db_correctly( name: str, external_version: str, - expdb_test: Connection, + expdb_test: AsyncConnection, mocker: MockerFixture, ) -> None: - mocked_db = mocker.patch("database.flows.get_by_name") - flow_exists(name, external_version, expdb_test) + mocked_db = mocker.patch("database.flows.get_by_name", new_callable=AsyncMock) + await flow_exists(name, external_version, expdb_test) mocked_db.assert_called_once_with( name=name, external_version=external_version, @@ -37,24 +39,29 @@ def test_flow_exists_calls_db_correctly( "flow_id", [1, 2], ) -def test_flow_exists_processes_found( +@pytest.mark.asyncio +async def test_flow_exists_processes_found( flow_id: int, mocker: MockerFixture, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: fake_flow = mocker.MagicMock(id=flow_id) mocker.patch( "database.flows.get_by_name", + new_callable=AsyncMock, return_value=fake_flow, ) - response = flow_exists("name", "external_version", expdb_test) + response = await flow_exists("name", "external_version", expdb_test) assert response == {"flow_id": fake_flow.id} -def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None: - mocker.patch("database.flows.get_by_name", return_value=None) +@pytest.mark.asyncio +async def test_flow_exists_handles_flow_not_found( + mocker: MockerFixture, expdb_test: AsyncConnection +) -> None: + mocker.patch("database.flows.get_by_name", new_callable=AsyncMock, return_value=None) with pytest.raises(HTTPException) as error: - flow_exists("foo", "bar", expdb_test) + await flow_exists("foo", "bar", expdb_test) assert error.value.status_code == HTTPStatus.NOT_FOUND assert error.value.detail == "Flow not found." diff --git a/tests/routers/openml/qualities_test.py b/tests/routers/openml/qualities_test.py index eed569e9..56dc5f10 100644 --- a/tests/routers/openml/qualities_test.py +++ b/tests/routers/openml/qualities_test.py @@ -3,12 +3,13 @@ import deepdiff import httpx import pytest -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient -def _remove_quality_from_database(quality_name: str, expdb_test: Connection) -> None: - expdb_test.execute( +async def _remove_quality_from_database(quality_name: str, expdb_test: AsyncConnection) -> None: + await expdb_test.execute( text( """ DELETE FROM data_quality @@ -17,7 +18,7 @@ def _remove_quality_from_database(quality_name: str, expdb_test: Connection) -> ), parameters={"deleted_quality": quality_name}, ) - expdb_test.execute( + await expdb_test.execute( text( """ DELETE FROM quality @@ -36,7 +37,8 @@ def test_list_qualities_identical(py_api: TestClient, php_api: httpx.Client) -> # To keep the test idempotent, we cannot test if reaction to database changes is identical -def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None: +@pytest.mark.asyncio +async def test_list_qualities(py_api: TestClient, expdb_test: AsyncConnection) -> None: response = py_api.get("/datasets/qualities/list") assert response.status_code == HTTPStatus.OK expected = { @@ -155,7 +157,7 @@ def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None: assert expected == response.json() deleted = expected["data_qualities_list"]["quality"].pop() - _remove_quality_from_database(quality_name=deleted, expdb_test=expdb_test) + await _remove_quality_from_database(quality_name=deleted, expdb_test=expdb_test) response = py_api.get("/datasets/qualities/list") assert response.status_code == HTTPStatus.OK diff --git a/tests/routers/openml/study_test.py b/tests/routers/openml/study_test.py index a9a8ed4a..19486de6 100644 --- a/tests/routers/openml/study_test.py +++ b/tests/routers/openml/study_test.py @@ -2,7 +2,9 @@ from http import HTTPStatus import httpx -from sqlalchemy import Connection, text +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from schemas.study import StudyType @@ -502,25 +504,26 @@ def test_create_task_study(py_api: TestClient) -> None: assert new_study == expected -def _attach_tasks_to_study( +async def _attach_tasks_to_study( study_id: int, task_ids: list[int], api_key: str, py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> httpx.Response: # Adding requires the study to be in preparation, # but the current snapshot has no in-preparation studies. - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) return py_api.post( f"/studies/attach?api_key={api_key}", json={"study_id": study_id, "entity_ids": task_ids}, ) -def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 7")) - response = _attach_tasks_to_study( +@pytest.mark.asyncio +async def test_attach_task_to_study(py_api: TestClient, expdb_test: AsyncConnection) -> None: + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 7")) + response = await _attach_tasks_to_study( study_id=7, task_ids=[50], api_key=ApiKey.OWNER_USER, @@ -531,9 +534,12 @@ def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> Non assert response.json() == {"study_id": 7, "main_entity_type": StudyType.TASK} -def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connection) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 7")) - response = _attach_tasks_to_study( +@pytest.mark.asyncio +async def test_attach_task_to_study_needs_owner( + py_api: TestClient, expdb_test: AsyncConnection +) -> None: + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 7")) + response = await _attach_tasks_to_study( study_id=1, task_ids=[2, 3, 4], api_key=ApiKey.OWNER_USER, @@ -543,12 +549,13 @@ def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connec assert response.status_code == HTTPStatus.FORBIDDEN, response.content -def test_attach_task_to_study_already_linked_raises( +@pytest.mark.asyncio +async def test_attach_task_to_study_already_linked_raises( py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) - response = _attach_tasks_to_study( + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + response = await _attach_tasks_to_study( study_id=1, task_ids=[1, 3, 4], api_key=ApiKey.ADMIN, @@ -559,12 +566,13 @@ def test_attach_task_to_study_already_linked_raises( assert response.json() == {"detail": "Task 1 is already attached to study 1."} -def test_attach_task_to_study_but_task_not_exist_raises( +@pytest.mark.asyncio +async def test_attach_task_to_study_but_task_not_exist_raises( py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) - response = _attach_tasks_to_study( + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + response = await _attach_tasks_to_study( study_id=1, task_ids=[80123, 78914], api_key=ApiKey.ADMIN, diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py index 45b330ae..7092afb7 100644 --- a/tests/routers/openml/users_test.py +++ b/tests/routers/openml/users_test.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User from routers.dependencies import fetch_user @@ -14,14 +14,18 @@ (ApiKey.SOME_USER, SOME_USER), ], ) -def test_fetch_user(api_key: str, user: User, user_test: Connection) -> None: - db_user = fetch_user(api_key, user_data=user_test) +@pytest.mark.asyncio +async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None: + db_user = await fetch_user(api_key, user_data=user_test) assert db_user is not None assert user.user_id == db_user.user_id - assert set(user.groups) == set(db_user.groups) + user_groups = await user.get_groups() + db_user_groups = await db_user.get_groups() + assert set(user_groups) == set(db_user_groups) -def test_fetch_user_invalid_key_returns_none(user_test: Connection) -> None: - assert fetch_user(api_key=None, user_data=user_test) is None +@pytest.mark.asyncio +async def test_fetch_user_invalid_key_returns_none(user_test: AsyncConnection) -> None: + assert await fetch_user(api_key=None, user_data=user_test) is None invalid_key = "f" * 32 - assert fetch_user(api_key=invalid_key, user_data=user_test) is None + assert await fetch_user(api_key=invalid_key, user_data=user_test) is None