From 1a123ac2c9c7196405da723ad510e3317a0d308e Mon Sep 17 00:00:00 2001 From: rohansen856 Date: Thu, 19 Feb 2026 04:08:48 +0530 Subject: [PATCH 1/9] chore: added async mysql db driver dependency Signed-off-by: rohansen856 --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d3b013c7..72b84d70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "uvicorn", "sqlalchemy", "mysqlclient", + "aiomysql", "python_dotenv", "xmltodict", ] @@ -28,6 +29,7 @@ dev = [ "pre-commit", "pytest", "pytest-mock", + "pytest-asyncio", "httpx", "hypothesis", "deepdiff", @@ -80,6 +82,7 @@ plugins = [ pythonpath = [ "src" ] +asyncio_mode = "auto" markers = [ "slow: test or sets of tests which take more than a few seconds to run.", # While the `mut`ation marker below is not strictly necessary as every change is From f561220ed9186dedc4c078e009815db9fee3f440 Mon Sep 17 00:00:00 2001 From: rohansen856 Date: Thu, 19 Feb 2026 04:33:02 +0530 Subject: [PATCH 2/9] feat: implemented async queries in datasets and evaluations Signed-off-by: rohansen856 --- src/database/datasets.py | 56 ++++++++++++++++++++----------------- src/database/evaluations.py | 29 ++++++++++--------- 2 files changed, 46 insertions(+), 39 deletions(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index f011a651..42da4025 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -2,14 +2,15 @@ import datetime -from sqlalchemy import Connection, text +from sqlalchemy import text from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection from schemas.datasets.openml import Feature -def get(id_: int, connection: Connection) -> Row | None: - row = connection.execute( +async def get(id_: int, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT * FROM dataset WHERE did = :dataset_id """, ), parameters={"dataset_id": id_}, ) return row.one_or_none()
-def get_file(*, file_id: int, connection: Connection) -> Row | None: - row = connection.execute( +async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT * @@ -36,8 +37,8 @@ def get_file(*, file_id: int, connection: Connection) -> Row | None: return row.one_or_none() -def get_tags_for(id_: int, connection: Connection) -> list[str]: - rows = connection.execute( +async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]: + row = await connection.execute( text( """ SELECT * @@ -47,11 +48,12 @@ def get_tags_for(id_: int, connection: Connection) -> list[str]: ), parameters={"dataset_id": id_}, ) + rows = row.all() return [row.tag for row in rows] -def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None: - connection.execute( +async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) -> None: + await connection.execute( text( """ INSERT INTO dataset_tag(`id`, `tag`, `uploader`) @@ -66,12 +68,12 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None: ) -def get_description( +async def get_description( id_: int, - connection: Connection, + connection: AsyncConnection, ) -> Row | None: """Get the most recent description for the dataset.""" - row = connection.execute( + row = await connection.execute( text( """ SELECT * @@ -85,9 +87,9 @@ def get_description( return row.first() -def get_status(id_: int, connection: Connection) -> Row | None: +async def get_status(id_: int, connection: AsyncConnection) -> Row | None: """Get most recent status for the dataset.""" - row = connection.execute( + row = await connection.execute( text( """ SELECT * @@ -101,8 +103,8 @@ def get_status(id_: int, connection: Connection) -> Row | None: return row.first() -def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None: - row = connection.execute( +async def get_latest_processing_update(dataset_id: int, 
connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT * @@ -116,8 +118,8 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row return row.one_or_none() -def get_features(dataset_id: int, connection: Connection) -> list[Feature]: - rows = connection.execute( +async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]: + row = await connection.execute( text( """ SELECT `index`,`name`,`data_type`,`is_target`, @@ -128,11 +130,12 @@ def get_features(dataset_id: int, connection: Connection) -> list[Feature]: ), parameters={"dataset_id": dataset_id}, ) - return [Feature(**row, nominal_values=None) for row in rows.mappings()] + rows = row.mappings().all() + return [Feature(**row, nominal_values=None) for row in rows] -def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]: - rows = connection.execute( +async def get_feature_values(dataset_id: int, *, feature_index: int, connection: AsyncConnection) -> list[str]: + row = await connection.execute( text( """ SELECT `value` @@ -142,17 +145,18 @@ def get_feature_values(dataset_id: int, *, feature_index: int, connection: Conne ), parameters={"dataset_id": dataset_id, "feature_index": feature_index}, ) + rows = row.all() return [row.value for row in rows] -def update_status( +async def update_status( dataset_id: int, status: str, *, user_id: int, - connection: Connection, + connection: AsyncConnection, ) -> None: - connection.execute( + await connection.execute( text( """ INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`) @@ -168,8 +172,8 @@ def update_status( ) -def remove_deactivated_status(dataset_id: int, connection: Connection) -> None: - connection.execute( +async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None: + await connection.execute( text( """ DELETE FROM dataset_status diff --git a/src/database/evaluations.py 
b/src/database/evaluations.py index 799d4112..74faf59b 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,30 +1,32 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection from core.formatting import _str_to_bool from schemas.datasets.openml import EstimationProcedure -def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - connection.execute( - text( - """ +async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]: + rows = await connection.execute( + text( + """ SELECT * FROM math_function WHERE `functionType` = :function_type """, - ), - parameters={"function_type": function_type}, - ).all(), + ), + parameters={"function_type": function_type}, + ) + return cast( + "Sequence[Row]", + rows.all(), ) -def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]: - rows = connection.execute( +async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]: + row = await connection.execute( text( """ SELECT `id` as 'id_', `ttid` as 'task_type_id', `name`, `type` as 'type_', @@ -33,11 +35,12 @@ def get_estimation_procedures(connection: Connection) -> list[EstimationProcedur """, ), ) + rows = row.mappings().all() typed_rows = [ { k: v if k != "stratified_sampling" or v is None else _str_to_bool(v) for k, v in row.items() } - for row in rows.mappings() + for row in rows ] return [EstimationProcedure(**typed_row) for typed_row in typed_rows] From e3377680f54a3c87727951c8fc1b4f8586775eca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Feb 2026 23:04:24 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
src/database/datasets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index 42da4025..873aeb57 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -134,7 +134,9 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea return [Feature(**row, nominal_values=None) for row in rows] -async def get_feature_values(dataset_id: int, *, feature_index: int, connection: AsyncConnection) -> list[str]: +async def get_feature_values( + dataset_id: int, *, feature_index: int, connection: AsyncConnection +) -> list[str]: row = await connection.execute( text( """ From 10cccb96987821b2452311e4d5c5a1dfb8cf72e5 Mon Sep 17 00:00:00 2001 From: rohansen856 Date: Fri, 20 Feb 2026 14:42:45 +0530 Subject: [PATCH 4/9] feat: implemented async queries in flows, qualities, studies, tasks and users Signed-off-by: rohansen856 --- src/database/flows.py | 56 ++++++++++++---------- src/database/qualities.py | 20 ++++---- src/database/setup.py | 25 ++++++++-- src/database/studies.py | 78 ++++++++++++++++-------------- src/database/tasks.py | 99 ++++++++++++++++++++++----------------- src/database/users.py | 28 ++++++----- 6 files changed, 175 insertions(+), 131 deletions(-) diff --git a/src/database/flows.py b/src/database/flows.py index 3129e91e..3b4ab847 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,27 +1,29 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection -def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ +async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: + rows = await expdb.execute( + text( + """ SELECT child as child_id, identifier FROM implementation_component WHERE parent = :flow_id """, - ), - 
parameters={"flow_id": for_flow}, ), + parameters={"flow_id": for_flow}, + ) + return cast( + "Sequence[Row]", + rows.all(), ) -def get_tags(flow_id: int, expdb: Connection) -> list[str]: - tag_rows = expdb.execute( +async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]: + rows = await expdb.execute( text( """ SELECT tag @@ -31,28 +33,30 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]: ), parameters={"flow_id": flow_id}, ) + tag_rows = rows.all() return [tag.tag for tag in tag_rows] -def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ +async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]: + rows = await expdb.execute( + text( + """ SELECT *, defaultValue as default_value, dataType as data_type FROM input WHERE implementation_id = :flow_id """, - ), - parameters={"flow_id": flow_id}, ), + parameters={"flow_id": flow_id}, + ) + return cast( + "Sequence[Row]", + rows.all(), ) -def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None: +async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None: """Gets flow by name and external version.""" - return expdb.execute( + row = await expdb.execute( text( """ SELECT *, uploadDate as upload_date @@ -61,11 +65,12 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No """, ), parameters={"name": name, "external_version": external_version}, - ).one_or_none() + ) + return row.one_or_none() -def get(id_: int, expdb: Connection) -> Row | None: - return expdb.execute( +async def get(id_: int, expdb: AsyncConnection) -> Row | None: + row = await expdb.execute( text( """ SELECT *, uploadDate as upload_date @@ -74,4 +79,5 @@ def get(id_: int, expdb: Connection) -> Row | None: """, ), parameters={"flow_id": id_}, - ).one_or_none() + ) + return row.one_or_none() diff --git a/src/database/qualities.py 
b/src/database/qualities.py index 81499c1e..08647f41 100644 --- a/src/database/qualities.py +++ b/src/database/qualities.py @@ -1,13 +1,14 @@ from collections import defaultdict from collections.abc import Iterable -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from schemas.datasets.openml import Quality -def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]: - rows = connection.execute( +async def get_for_dataset(dataset_id: int, connection: AsyncConnection) -> list[Quality]: + row = await connection.execute( text( """ SELECT `quality`,`value` @@ -17,13 +18,14 @@ def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]: ), parameters={"dataset_id": dataset_id}, ) + rows = row.all() return [Quality(name=row.quality, value=row.value) for row in rows] -def get_for_datasets( +async def get_for_datasets( dataset_ids: Iterable[int], quality_names: Iterable[str], - connection: Connection, + connection: AsyncConnection, ) -> dict[int, list[Quality]]: """Don't call with user-provided input, as query is not parameterized.""" qualities_filter = ",".join(f"'{q}'" for q in quality_names) @@ -35,7 +37,8 @@ def get_for_datasets( WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter}) """, # noqa: S608 - dids and qualities are not user-provided ) - rows = connection.execute(qualities_query) + row = await connection.execute(qualities_query) + rows = row.all() qualities_by_id = defaultdict(list) for did, quality, value in rows: if value is not None: @@ -43,10 +46,10 @@ def get_for_datasets( return dict(qualities_by_id) -def list_all_qualities(connection: Connection) -> list[str]: +async def list_all_qualities(connection: AsyncConnection) -> list[str]: # The current implementation only fetches *used* qualities, otherwise you should # query: SELECT `name` FROM `quality` WHERE `type`='DataQuality' - qualities_ = connection.execute( + rows = await 
connection.execute( text( """ SELECT DISTINCT(`quality`) @@ -54,4 +57,5 @@ def list_all_qualities(connection: Connection) -> list[str]: """, ), ) + qualities_ = rows.all() return [quality.quality for quality in qualities_] diff --git a/src/database/setup.py b/src/database/setup.py index 3a1be2f6..6f7e3017 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -1,5 +1,5 @@ -from sqlalchemy import Engine, create_engine from sqlalchemy.engine import URL +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from config import load_database_configuration @@ -7,26 +7,41 @@ _expdb_engine = None -def _create_engine(database_name: str) -> Engine: +def _create_engine(database_name: str) -> AsyncEngine: database_configuration = load_database_configuration() echo = database_configuration[database_name].pop("echo", False) + + # Update driver to use aiomysql for async support + database_configuration[database_name]["drivername"] = "mysql+aiomysql" + db_url = URL.create(**database_configuration[database_name]) - return create_engine( + return create_async_engine( db_url, echo=echo, pool_recycle=3600, ) -def user_database() -> Engine: +def user_database() -> AsyncEngine: global _user_engine # noqa: PLW0603 if _user_engine is None: _user_engine = _create_engine("openml") return _user_engine -def expdb_database() -> Engine: +def expdb_database() -> AsyncEngine: global _expdb_engine # noqa: PLW0603 if _expdb_engine is None: _expdb_engine = _create_engine("expdb") return _expdb_engine + + +async def close_databases() -> None: + """Close all database connections.""" + global _user_engine, _expdb_engine # noqa: PLW0603 + if _user_engine is not None: + await _user_engine.dispose() + _user_engine = None + if _expdb_engine is not None: + await _expdb_engine.dispose() + _expdb_engine = None diff --git a/src/database/studies.py b/src/database/studies.py index 35c1b790..837e2bd2 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -3,14 +3,15 @@ from 
datetime import datetime from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User from schemas.study import CreateStudy, StudyType -def get_by_id(id_: int, connection: Connection) -> Row | None: - return connection.execute( +async def get_by_id(id_: int, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT *, main_entity_type as type_ @@ -19,11 +20,12 @@ def get_by_id(id_: int, connection: Connection) -> Row | None: """, ), parameters={"study_id": id_}, - ).one_or_none() + ) + return row.one_or_none() -def get_by_alias(alias: str, connection: Connection) -> Row | None: - return connection.execute( +async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None: + row = await connection.execute( text( """ SELECT *, main_entity_type as type_ @@ -32,34 +34,34 @@ def get_by_alias(alias: str, connection: Connection) -> Row | None: """, ), parameters={"study_id": alias}, - ).one_or_none() + ) + return row.one_or_none() -def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: +async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: """Return data related to the study, content depends on the study type. 
For task studies: (task id, dataset id) For run studies: (run id, task id, setup id, dataset id, flow id) """ if study.type_ == StudyType.TASK: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ + rows = await expdb.execute( + text( + """ SELECT ts.task_id as task_id, ti.value as data_id FROM task_study as ts LEFT JOIN task_inputs ti ON ts.task_id = ti.task_id WHERE ts.study_id = :study_id AND ti.input = 'source_data' """, - ), - parameters={"study_id": study.id}, - ).all(), + ), + parameters={"study_id": study.id}, ) - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ + return cast( + "Sequence[Row]", + rows.all(), + ) + rows = await expdb.execute( + text( + """ SELECT rs.run_id as run_id, run.task_id as task_id, @@ -72,14 +74,17 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: JOIN task_inputs as ti ON ti.task_id = run.task_id WHERE rs.study_id = :study_id AND ti.input = 'source_data' """, - ), - parameters={"study_id": study.id}, - ).all(), + ), + parameters={"study_id": study.id}, + ) + return cast( + "Sequence[Row]", + rows.all(), ) -def create(study: CreateStudy, user: User, expdb: Connection) -> int: - expdb.execute( +async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int: + await expdb.execute( text( """ INSERT INTO study ( @@ -102,12 +107,13 @@ def create(study: CreateStudy, user: User, expdb: Connection) -> int: "benchmark_suite": study.benchmark_suite, }, ) - (study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one() + row = await expdb.execute(text("""SELECT LAST_INSERT_ID();""")) + (study_id,) = row.one() return cast("int", study_id) -def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None: - expdb.execute( +async def attach_task(task_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None: + await expdb.execute( text( """ INSERT INTO task_study (study_id, task_id, uploader) @@ -118,8 +124,8 @@ def attach_task(task_id: int, 
study_id: int, user: User, expdb: Connection) -> N ) -def attach_run(*, run_id: int, study_id: int, user: User, expdb: Connection) -> None: - expdb.execute( +async def attach_run(*, run_id: int, study_id: int, user: User, expdb: AsyncConnection) -> None: + await expdb.execute( text( """ INSERT INTO run_study (study_id, run_id, uploader) @@ -130,16 +136,16 @@ def attach_run(*, run_id: int, study_id: int, user: User, expdb: Connection) -> ) -def attach_tasks( +async def attach_tasks( *, study_id: int, task_ids: list[int], user: User, - connection: Connection, + connection: AsyncConnection, ) -> None: to_link = [(study_id, task_id, user.user_id) for task_id in task_ids] try: - connection.execute( + await connection.execute( text( """ INSERT INTO task_study (study_id, task_id, uploader) @@ -162,10 +168,10 @@ def attach_tasks( raise ValueError(msg) from e -def attach_runs( +async def attach_runs( study_id: int, run_ids: list[int], user: User, - connection: Connection, + connection: AsyncConnection, ) -> None: raise NotImplementedError diff --git a/src/database/tasks.py b/src/database/tasks.py index 97caef3b..e9670d26 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,11 +1,12 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Connection, Row, text +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection -def get(id_: int, expdb: Connection) -> Row | None: - return expdb.execute( +async def get(id_: int, expdb: AsyncConnection) -> Row | None: + row = await expdb.execute( text( """ SELECT * @@ -14,25 +15,27 @@ def get(id_: int, expdb: Connection) -> Row | None: """, ), parameters={"task_id": id_}, - ).one_or_none() + ) + return row.one_or_none() -def get_task_types(expdb: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ +async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: + rows = await expdb.execute( + text( + """ SELECT `ttid`, 
`name`, `description`, `creator` FROM task_type """, - ), - ).all(), + ), + ) + return cast( + "Sequence[Row]", + rows.all(), ) -def get_task_type(task_type_id: int, expdb: Connection) -> Row | None: - return expdb.execute( +async def get_task_type(task_type_id: int, expdb: AsyncConnection) -> Row | None: + row = await expdb.execute( text( """ SELECT * @@ -41,59 +44,66 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None: """, ), parameters={"ttid": task_type_id}, - ).one_or_none() + ) + return row.one_or_none() -def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]: - return cast( - "Sequence[Row]", - expdb.execute( - text( - """ +async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]: + rows = await expdb.execute( + text( + """ SELECT * FROM task_type_inout WHERE `ttid`=:ttid AND `io`='input' """, - ), - parameters={"ttid": task_type_id}, - ).all(), + ), + parameters={"ttid": task_type_id}, ) - - -def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + rows.all(), + ) + + +async def get_input_for_task(id_: int, expdb: AsyncConnection) -> Sequence[Row]: + rows = await expdb.execute( + text( + """ SELECT `input`, `value` FROM task_inputs WHERE task_id = :task_id """, - ), - parameters={"task_id": id_}, - ).all(), + ), + parameters={"task_id": id_}, ) - - -def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]: return cast( "Sequence[Row]", - expdb.execute( - text( - """ + rows.all(), + ) + + +async def get_task_type_inout_with_template( + task_type: int, + expdb: AsyncConnection, +) -> Sequence[Row]: + rows = await expdb.execute( + text( + """ SELECT * FROM task_type_inout WHERE `ttid`=:ttid AND `template_api` IS NOT NULL """, - ), - parameters={"ttid": task_type}, - ).all(), + ), + parameters={"ttid": task_type}, + ) + return cast( + "Sequence[Row]", + rows.all(), ) 
-def get_tags(id_: int, expdb: Connection) -> list[str]: - tag_rows = expdb.execute( +async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: + rows = await expdb.execute( text( """ SELECT `tag` @@ -103,4 +113,5 @@ def get_tags(id_: int, expdb: Connection) -> list[str]: ), parameters={"task_id": id_}, ) + tag_rows = rows.all() return [row.tag for row in tag_rows] diff --git a/src/database/users.py b/src/database/users.py index a045f5da..91d97e7e 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -3,7 +3,8 @@ from typing import Annotated, Self from pydantic import StringConstraints -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection # Enforces str is 32 hexadecimal characters, does not check validity. APIKey = Annotated[str, StringConstraints(pattern=r"^[0-9a-fA-F]{32}$")] @@ -15,8 +16,8 @@ class UserGroup(IntEnum): READ_ONLY = (3,) -def get_user_id_for(*, api_key: APIKey, connection: Connection) -> int | None: - user = connection.execute( +async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> int | None: + row = await connection.execute( text( """ SELECT * @@ -25,12 +26,13 @@ def get_user_id_for(*, api_key: APIKey, connection: Connection) -> int | None: """, ), parameters={"api_key": api_key}, - ).one_or_none() + ) + user = row.one_or_none() return user.id if user else None -def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGroup]: - row = connection.execute( +async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[UserGroup]: + row = await connection.execute( text( """ SELECT group_id @@ -40,24 +42,24 @@ def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGro ), parameters={"user_id": user_id}, ) - return [UserGroup(group) for (group,) in row] + rows = row.all() + return [UserGroup(group) for (group,) in rows] @dataclasses.dataclass class User: user_id: int - 
_database: Connection + _database: AsyncConnection _groups: list[UserGroup] | None = None @classmethod - def fetch(cls, api_key: APIKey, user_db: Connection) -> Self | None: - if user_id := get_user_id_for(api_key=api_key, connection=user_db): + async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: + if user_id := await get_user_id_for(api_key=api_key, connection=user_db): return cls(user_id, _database=user_db) return None - @property - def groups(self) -> list[UserGroup]: + async def get_groups(self) -> list[UserGroup]: if self._groups is None: - groups = get_user_groups_for(user_id=self.user_id, connection=self._database) + groups = await get_user_groups_for(user_id=self.user_id, connection=self._database) self._groups = [UserGroup(group_id) for group_id in groups] return self._groups From 736191132632475f3f6dd26dbaa2d1e83292dd9e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 09:14:40 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/database/datasets.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index 873aeb57..e914c203 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -135,7 +135,10 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea async def get_feature_values( - dataset_id: int, *, feature_index: int, connection: AsyncConnection + dataset_id: int, + *, + feature_index: int, + connection: AsyncConnection, ) -> list[str]: row = await connection.execute( text( From faec1725e08e861ec915263340c7f4ef57be8b7d Mon Sep 17 00:00:00 2001 From: rohansen856 Date: Thu, 26 Feb 2026 03:44:28 +0530 Subject: [PATCH 6/9] chore: migrate to async database driver (aiomysql) Signed-off-by: rohansen856 --- src/core/access.py | 15 ++- src/database/datasets.py | 
5 +- src/main.py | 12 ++- src/routers/dependencies.py | 21 +++-- src/routers/mldcat_ap/dataset.py | 32 ++++--- src/routers/openml/datasets.py | 101 +++++++++++++-------- src/routers/openml/estimation_procedure.py | 8 +- src/routers/openml/evaluations.py | 8 +- src/routers/openml/flows.py | 28 ++++-- src/routers/openml/qualities.py | 18 ++-- src/routers/openml/study.py | 57 +++++++----- src/routers/openml/tasks.py | 42 +++++---- src/routers/openml/tasktype.py | 20 ++-- tests/conftest.py | 54 ++++++----- tests/database/flows_test.py | 10 +- tests/routers/openml/dataset_tag_test.py | 6 +- tests/routers/openml/datasets_test.py | 14 +-- tests/routers/openml/flows_test.py | 20 ++-- tests/routers/openml/qualities_test.py | 13 +-- tests/routers/openml/study_test.py | 37 ++++---- tests/routers/openml/users_test.py | 14 +-- 21 files changed, 313 insertions(+), 222 deletions(-) diff --git a/src/core/access.py b/src/core/access.py index c44d97e6..01638f5c 100644 --- a/src/core/access.py +++ b/src/core/access.py @@ -1,15 +1,20 @@ +from typing import Any + from sqlalchemy.engine import Row from database.users import User, UserGroup from schemas.datasets.openml import Visibility -def _user_has_access( - dataset: Row, +async def _user_has_access( + dataset: Row[Any], user: User | None = None, ) -> bool: """Determine if `user` has the right to view `dataset`.""" is_public = dataset.visibility == Visibility.PUBLIC - return is_public or ( - user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups) - ) + if is_public: + return True + if user is None: + return False + user_groups = await user.get_groups() + return user.user_id == dataset.uploader or UserGroup.ADMIN in user_groups diff --git a/src/database/datasets.py b/src/database/datasets.py index 873aeb57..e914c203 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -135,7 +135,10 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea async def 
get_feature_values( - dataset_id: int, *, feature_index: int, connection: AsyncConnection + dataset_id: int, + *, + feature_index: int, + connection: AsyncConnection, ) -> list[str]: row = await connection.execute( text( diff --git a/src/main.py b/src/main.py index d8e61b34..fe99022d 100644 --- a/src/main.py +++ b/src/main.py @@ -1,9 +1,12 @@ import argparse +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI from config import load_configuration +from database.setup import close_databases from routers.mldcat_ap.dataset import router as mldcat_ap_router from routers.openml.datasets import router as datasets_router from routers.openml.estimation_procedure import router as estimationprocedure_router @@ -15,6 +18,13 @@ from routers.openml.tasktype import router as ttype_router +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 + """Manage application lifespan - startup and shutdown events.""" + yield + await close_databases() + + def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() uvicorn_options = parser.add_argument_group( @@ -43,7 +53,7 @@ def _parse_args() -> argparse.Namespace: def create_api() -> FastAPI: fastapi_kwargs = load_configuration()["fastapi"] - app = FastAPI(**fastapi_kwargs) + app = FastAPI(**fastapi_kwargs, lifespan=lifespan) app.include_router(datasets_router) app.include_router(qualities_router) diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index 2ddccf83..d9bfe76a 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,32 +1,33 @@ +from collections.abc import AsyncGenerator from typing import Annotated from fastapi import Depends from pydantic import BaseModel -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from database.setup import expdb_database, user_database from database.users import APIKey, User -def 
expdb_connection() -> Connection: +async def expdb_connection() -> AsyncGenerator[AsyncConnection, None]: engine = expdb_database() - with engine.connect() as connection: + async with engine.connect() as connection: yield connection - connection.commit() + await connection.commit() -def userdb_connection() -> Connection: +async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]: engine = user_database() - with engine.connect() as connection: + async with engine.connect() as connection: yield connection - connection.commit() + await connection.commit() -def fetch_user( +async def fetch_user( api_key: APIKey | None = None, - user_data: Annotated[Connection, Depends(userdb_connection)] = None, + user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, ) -> User | None: - return User.fetch(api_key, user_data) if api_key else None + return await User.fetch(api_key, user_data) if api_key and user_data else None class Pagination(BaseModel): diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py index db34e5ce..61c16b42 100644 --- a/src/routers/mldcat_ap/dataset.py +++ b/src/routers/mldcat_ap/dataset.py @@ -7,7 +7,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import config from database.users import User @@ -37,19 +37,21 @@ path="/distribution/{distribution_id}", description="Get meta-data for distribution with ID `distribution_id`.", ) -def get_mldcat_ap_distribution( +async def get_mldcat_ap_distribution( distribution_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[Connection, Depends(userdb_connection)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + user_db: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> 
JsonLDGraph: - oml_dataset = get_dataset( + assert user_db is not None # noqa: S101 + assert expdb is not None # noqa: S101 + oml_dataset = await get_dataset( dataset_id=distribution_id, user=user, user_db=user_db, expdb_db=expdb, ) - openml_features = get_dataset_features(distribution_id, user, expdb) + openml_features = await get_dataset_features(distribution_id, user, expdb) features = [ Feature( id_=f"{_server_url}/feature/{distribution_id}/{feature.index}", @@ -58,7 +60,7 @@ def get_mldcat_ap_distribution( ) for feature in openml_features ] - oml_qualities = get_qualities(distribution_id, user, expdb) + oml_qualities = await get_qualities(distribution_id, user, expdb) qualities = [ Quality( id_=f"{_server_url}/quality/{quality.name}/{distribution_id}", @@ -138,13 +140,14 @@ def get_dataservice(service_id: int) -> JsonLDGraph: path="/quality/{quality_name}/{distribution_id}", description="Get meta-data for a specific quality and distribution.", ) -def get_distribution_quality( +async def get_distribution_quality( quality_name: str, distribution_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> JsonLDGraph: - qualities = get_qualities(distribution_id, user, expdb) + assert expdb is not None # noqa: S101 + qualities = await get_qualities(distribution_id, user, expdb) quality = next(q for q in qualities if q.name == quality_name) example_quality = Quality( id_=f"{_server_url}/quality/{quality_name}/{distribution_id}", @@ -164,13 +167,14 @@ def get_distribution_quality( path="/feature/{distribution_id}/{feature_no}", description="Get meta-data for the n-th feature of a distribution.", ) -def get_distribution_feature( +async def get_distribution_feature( distribution_id: int, feature_no: int, - user: Annotated[Connection, Depends(fetch_user)] = None, - expdb: Annotated[Connection, 
Depends(expdb_connection)] = None, + user: Annotated[User | None, Depends(fetch_user)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> JsonLDGraph: - features = get_dataset_features( + assert expdb is not None # noqa: S101 + features = await get_dataset_features( dataset_id=distribution_id, user=user, expdb=expdb, diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index dda25117..39eedc72 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -5,8 +5,9 @@ from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends, HTTPException -from sqlalchemy import Connection, text +from sqlalchemy import text from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -29,20 +30,21 @@ @router.post( path="/tag", ) -def tag_dataset( +async def tag_dataset( data_id: Annotated[int, Body()], tag: Annotated[str, SystemString64], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, + expdb_db: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> dict[str, dict[str, Any]]: - tags = database.datasets.get_tags_for(data_id, expdb_db) + assert expdb_db is not None # noqa: S101 + tags = await database.datasets.get_tags_for(data_id, expdb_db) if tag.casefold() in [t.casefold() for t in tags]: raise create_tag_exists_error(data_id, tag) if user is None: raise create_authentication_failed_error() - database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) + await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) return { "data_tag": {"id": str(data_id), "tag": [*tags, tag]}, } @@ -75,7 +77,7 @@ class DatasetStatusFilter(StrEnum): @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") 
@router.get(path="/list") -def list_datasets( # noqa: PLR0913 +async def list_datasets( # noqa: PLR0913 pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[str | None, CasualString128] = None, tag: Annotated[str | None, SystemString64] = None, @@ -100,8 +102,9 @@ def list_datasets( # noqa: PLR0913 number_missing_values: Annotated[str | None, IntegerRange] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, + expdb_db: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: + assert expdb_db is not None # noqa: S101 current_status = text( """ SELECT ds1.`did`, ds1.`status` @@ -126,7 +129,7 @@ def list_datasets( # noqa: PLR0913 where_status = ",".join(f"'{status}'" for status in statuses) if user is None: visible_to_user = "`visibility`='public'" - elif UserGroup.ADMIN in user.groups: + elif UserGroup.ADMIN in await user.get_groups(): visible_to_user = "TRUE" else: visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" @@ -190,7 +193,7 @@ def quality_clause(quality: str, range_: str | None) -> str: # subquery also has no user input. So I think this should be safe. 
) columns = ["did", "name", "version", "format", "file_id", "status"] - rows = expdb_db.execute( + result = await expdb_db.execute( matching_filter, parameters={ "tag": tag, @@ -199,6 +202,7 @@ def quality_clause(quality: str, range_: str | None) -> str: "uploader": uploader, }, ) + rows = result.all() datasets: dict[int, dict[str, Any]] = { row.did: dict(zip(columns, row, strict=True)) for row in rows } @@ -230,7 +234,7 @@ def quality_clause(quality: str, range_: str | None) -> str: "NumberOfNumericFeatures", "NumberOfSymbolicFeatures", ] - qualities_by_dataset = database.qualities.get_for_datasets( + qualities_by_dataset = await database.qualities.get_for_datasets( dataset_ids=datasets.keys(), quality_names=qualities_to_show, connection=expdb_db, @@ -246,10 +250,16 @@ class ProcessingInformation(NamedTuple): error: str | None -def _get_processing_information(dataset_id: int, connection: Connection) -> ProcessingInformation: +async def _get_processing_information( + dataset_id: int, + connection: AsyncConnection, +) -> ProcessingInformation: """Return processing information, if any. Otherwise, all fields `None`.""" if not ( - data_processed := database.datasets.get_latest_processing_update(dataset_id, connection) + data_processed := await database.datasets.get_latest_processing_update( + dataset_id, + connection, + ) ): return ProcessingInformation(date=None, warning=None, error=None) @@ -259,20 +269,20 @@ def _get_processing_information(dataset_id: int, connection: Connection) -> Proc return ProcessingInformation(date=date_processed, warning=warning, error=error) -def _get_dataset_raise_otherwise( +async def _get_dataset_raise_otherwise( dataset_id: int, user: User | None, - expdb: Connection, -) -> Row: + expdb: AsyncConnection, +) -> Row[Any]: """Fetches the dataset from the database if it exists and the user has permissions. Raises HTTPException if the dataset does not exist or the user can not access it. 
""" - if not (dataset := database.datasets.get(dataset_id, expdb)): + if not (dataset := await database.datasets.get(dataset_id, expdb)): error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset") raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=error) - if not _user_has_access(dataset=dataset, user=user): + if not await _user_has_access(dataset=dataset, user=user): error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted") raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=error) @@ -280,22 +290,23 @@ def _get_dataset_raise_otherwise( @router.get("/features/{dataset_id}", response_model_exclude_none=True) -def get_dataset_features( +async def get_dataset_features( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> list[Feature]: - _get_dataset_raise_otherwise(dataset_id, user, expdb) - features = database.datasets.get_features(dataset_id, expdb) + assert expdb is not None # noqa: S101 + await _get_dataset_raise_otherwise(dataset_id, user, expdb) + features = await database.datasets.get_features(dataset_id, expdb) for feature in [f for f in features if f.data_type == FeatureType.NOMINAL]: - feature.nominal_values = database.datasets.get_feature_values( + feature.nominal_values = await database.datasets.get_feature_values( dataset_id, feature_index=feature.index, connection=expdb, ) if not features: - processing_state = database.datasets.get_latest_processing_update(dataset_id, expdb) + processing_state = await database.datasets.get_latest_processing_update(dataset_id, expdb) if processing_state is None: code, msg = ( 273, @@ -318,11 +329,11 @@ def get_dataset_features( @router.post( path="/status/update", ) -def update_dataset_status( +async def update_dataset_status( dataset_id: Annotated[int, Body()], status: 
Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()], user: Annotated[User | None, Depends(fetch_user)], - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, str | int]: if user is None: raise HTTPException( @@ -330,21 +341,21 @@ def update_dataset_status( detail="Updating dataset status required authorization", ) - dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb) + dataset = await _get_dataset_raise_otherwise(dataset_id, user, expdb) - can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user.groups + can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in await user.get_groups() if status == DatasetStatus.DEACTIVATED and not can_deactivate: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail={"code": 693, "message": "Dataset is not owned by you"}, ) - if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups: + if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in await user.get_groups(): raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail={"code": 696, "message": "Only administrators can activate datasets."}, ) - current_status = database.datasets.get_status(dataset_id, expdb) + current_status = await database.datasets.get_status(dataset_id, expdb) if current_status and current_status.status == status: raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, @@ -358,9 +369,14 @@ def update_dataset_status( # - active => deactivated (add a row) # - deactivated => active (delete a row) if current_status is None or status == DatasetStatus.DEACTIVATED: - database.datasets.update_status(dataset_id, status, user_id=user.user_id, connection=expdb) + await database.datasets.update_status( + dataset_id, + status, + user_id=user.user_id, + connection=expdb, + ) elif current_status.status == DatasetStatus.DEACTIVATED: - 
database.datasets.remove_deactivated_status(dataset_id, expdb) + await database.datasets.remove_deactivated_status(dataset_id, expdb) else: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, @@ -374,15 +390,20 @@ def update_dataset_status( path="/{dataset_id}", description="Get meta-data for dataset with ID `dataset_id`.", ) -def get_dataset( +async def get_dataset( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[Connection, Depends(userdb_connection)] = None, - expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, + user_db: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, + expdb_db: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> DatasetMetadata: - dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db) + assert user_db is not None # noqa: S101 + assert expdb_db is not None # noqa: S101 + dataset = await _get_dataset_raise_otherwise(dataset_id, user, expdb_db) if not ( - dataset_file := database.datasets.get_file(file_id=dataset.file_id, connection=user_db) + dataset_file := await database.datasets.get_file( + file_id=dataset.file_id, + connection=user_db, + ) ): error = _format_error( code=DatasetError.NO_DATA_FILE, @@ -390,10 +411,10 @@ def get_dataset( ) raise HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=error) - tags = database.datasets.get_tags_for(dataset_id, expdb_db) - description = database.datasets.get_description(dataset_id, expdb_db) - processing_result = _get_processing_information(dataset_id, expdb_db) - status = database.datasets.get_status(dataset_id, expdb_db) + tags = await database.datasets.get_tags_for(dataset_id, expdb_db) + description = await database.datasets.get_description(dataset_id, expdb_db) + processing_result = await _get_processing_information(dataset_id, expdb_db) + status = await database.datasets.get_status(dataset_id, expdb_db) status_ = DatasetStatus(status.status) if 
status else DatasetStatus.IN_PREPARATION diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py index 1ebaf929..5df5c798 100644 --- a/src/routers/openml/estimation_procedure.py +++ b/src/routers/openml/estimation_procedure.py @@ -2,7 +2,7 @@ from typing import Annotated from fastapi import APIRouter, Depends -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection @@ -12,7 +12,7 @@ @router.get("/list", response_model_exclude_none=True) -def get_estimation_procedures( - expdb: Annotated[Connection, Depends(expdb_connection)], +async def get_estimation_procedures( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> Iterable[EstimationProcedure]: - return database.evaluations.get_estimation_procedures(expdb) + return await database.evaluations.get_estimation_procedures(expdb) diff --git a/src/routers/openml/evaluations.py b/src/routers/openml/evaluations.py index 49178b7a..f6650b36 100644 --- a/src/routers/openml/evaluations.py +++ b/src/routers/openml/evaluations.py @@ -1,7 +1,7 @@ from typing import Annotated from fastapi import APIRouter, Depends -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.evaluations from routers.dependencies import expdb_connection @@ -10,8 +10,10 @@ @router.get("/list") -def get_evaluation_measures(expdb: Annotated[Connection, Depends(expdb_connection)]) -> list[str]: - functions = database.evaluations.get_math_functions( +async def get_evaluation_measures( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> list[str]: + functions = await database.evaluations.get_math_functions( function_type="EvaluationFunction", connection=expdb, ) diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index cb6df5d9..5ed71d83 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py 
@@ -2,7 +2,7 @@ from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from core.conversions import _str_to_num @@ -13,13 +13,17 @@ @router.get("/exists/{name}/{external_version}") -def flow_exists( +async def flow_exists( name: str, external_version: str, - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["flow_id"], int]: """Check if a Flow with the name and version exists, if so, return the flow id.""" - flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb) + flow = await database.flows.get_by_name( + name=name, + external_version=external_version, + expdb=expdb, + ) if flow is None: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, @@ -29,12 +33,16 @@ def flow_exists( @router.get("/{flow_id}") -def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow: - flow = database.flows.get(flow_id, expdb) +async def get_flow( + flow_id: int, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, +) -> Flow: + assert expdb is not None # noqa: S101 + flow = await database.flows.get(flow_id, expdb) if not flow: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Flow not found") - parameter_rows = database.flows.get_parameters(flow_id, expdb) + parameter_rows = await database.flows.get_parameters(flow_id, expdb) parameters = [ Parameter( name=parameter.name, @@ -48,12 +56,12 @@ def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection for parameter in parameter_rows ] - tags = database.flows.get_tags(flow_id, expdb) - subflow_rows = database.flows.get_subflows(flow_id, expdb) + tags = await database.flows.get_tags(flow_id, expdb) + subflow_rows = await database.flows.get_subflows(flow_id, expdb) subflows 
= [ Subflow( identifier=subflow.identifier, - flow=get_flow(flow_id=subflow.child_id, expdb=expdb), + flow=await get_flow(flow_id=subflow.child_id, expdb=expdb), ) for subflow in subflow_rows ] diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py index 54181f8f..b71ca52f 100644 --- a/src/routers/openml/qualities.py +++ b/src/routers/openml/qualities.py @@ -2,7 +2,7 @@ from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.datasets import database.qualities @@ -16,10 +16,10 @@ @router.get("/qualities/list") -def list_qualities( - expdb: Annotated[Connection, Depends(expdb_connection)], +async def list_qualities( + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["data_qualities_list"], dict[Literal["quality"], list[str]]]: - qualities = database.qualities.list_all_qualities(connection=expdb) + qualities = await database.qualities.list_all_qualities(connection=expdb) return { "data_qualities_list": { "quality": qualities, @@ -28,18 +28,18 @@ def list_qualities( @router.get("/qualities/{dataset_id}") -def get_qualities( +async def get_qualities( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)], - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> list[Quality]: - dataset = database.datasets.get(dataset_id, expdb) - if not dataset or not _user_has_access(dataset, user): + dataset = await database.datasets.get(dataset_id, expdb) + if not dataset or not await _user_has_access(dataset, user): raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"}, ) from None - return database.qualities.get_for_dataset(dataset_id, expdb) + return await database.qualities.get_for_dataset(dataset_id, expdb) # The PHP API 
provided (sometime) helpful error messages # if not qualities: # check if dataset exists: error 360 diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 6fe1dcc6..7175da43 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -1,9 +1,10 @@ from http import HTTPStatus -from typing import Annotated, Literal +from typing import Annotated, Any, Literal from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel -from sqlalchemy import Connection, Row +from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection import database.studies from core.formatting import _str_to_bool @@ -15,11 +16,15 @@ router = APIRouter(prefix="/studies", tags=["studies"]) -def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: Connection) -> Row: +async def _get_study_raise_otherwise( + id_or_alias: int | str, + user: User | None, + expdb: AsyncConnection, +) -> Row[Any]: if isinstance(id_or_alias, int) or id_or_alias.isdigit(): - study = database.studies.get_by_id(int(id_or_alias), expdb) + study = await database.studies.get_by_id(int(id_or_alias), expdb) else: - study = database.studies.get_by_alias(id_or_alias, expdb) + study = await database.studies.get_by_alias(id_or_alias, expdb) if study is None: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Study not found.") @@ -29,7 +34,7 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: status_code=HTTPStatus.UNAUTHORIZED, detail="Must authenticate for private study.", ) - if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: + if study.creator != user.user_id and UserGroup.ADMIN not in await user.get_groups(): raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Study is private.") if _str_to_bool(study.legacy): raise HTTPException( @@ -45,17 +50,18 @@ class AttachDetachResponse(BaseModel): @router.post("/attach") -def attach_to_study( +async 
def attach_to_study( study_id: Annotated[int, Body()], entity_ids: Annotated[list[int], Body()], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: + assert expdb is not None # noqa: S101 if user is None: raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User not found.") - study = _get_study_raise_otherwise(study_id, user, expdb) + study = await _get_study_raise_otherwise(study_id, user, expdb) # PHP lets *anyone* edit *any* study. We're not going to do that. - if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: + if study.creator != user.user_id and UserGroup.ADMIN not in await user.get_groups(): raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="Study can only be edited by its creator.", @@ -75,9 +81,9 @@ def attach_to_study( } try: if study.type_ == StudyType.TASK: - database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs) + await database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs) else: - database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) + await database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) except ValueError as e: raise HTTPException( status_code=HTTPStatus.CONFLICT, @@ -87,11 +93,12 @@ def attach_to_study( @router.post("/") -def create_study( +async def create_study( study: CreateStudy, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> dict[Literal["study_id"], int]: + assert expdb is not None # noqa: S101 if user is None: raise HTTPException( status_code=HTTPStatus.UNAUTHORIZED, @@ -107,30 +114,36 @@ def create_study( status_code=HTTPStatus.BAD_REQUEST, detail="Cannot create a task study with runs.", ) - if study.alias 
and database.studies.get_by_alias(study.alias, expdb): + if study.alias and await database.studies.get_by_alias(study.alias, expdb): raise HTTPException( status_code=HTTPStatus.CONFLICT, detail="Study alias already exists.", ) - study_id = database.studies.create(study, user, expdb) + study_id = await database.studies.create(study, user, expdb) if study.main_entity_type == StudyType.TASK: for task_id in study.tasks: - database.studies.attach_task(task_id, study_id, user, expdb) + await database.studies.attach_task(task_id, study_id, user, expdb) if study.main_entity_type == StudyType.RUN: for run_id in study.runs: - database.studies.attach_run(run_id=run_id, study_id=study_id, user=user, expdb=expdb) + await database.studies.attach_run( + run_id=run_id, + study_id=study_id, + user=user, + expdb=expdb, + ) # Make sure that invalid fields raise an error (e.g., "task_ids") return {"study_id": study_id} @router.get("/{alias_or_id}") -def get_study( +async def get_study( alias_or_id: int | str, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> Study: - study = _get_study_raise_otherwise(alias_or_id, user, expdb) - study_data = database.studies.get_study_data(study, expdb) + assert expdb is not None # noqa: S101 + study = await _get_study_raise_otherwise(alias_or_id, user, expdb) + study_data = await database.studies.get_study_data(study, expdb) return Study( _legacy=_str_to_bool(study.legacy), id_=study.id, diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 8397f1da..83ebf719 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -5,7 +5,8 @@ import xmltodict from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection, RowMapping, text +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncConnection import 
config import database.datasets @@ -27,11 +28,11 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: return cast("dict[str, JSON]", json.loads(json_str)) -def fill_template( +async def fill_template( template: str, task: RowMapping, task_inputs: dict[str, str | int], - connection: Connection, + connection: AsyncConnection, ) -> dict[str, JSON]: """Fill in the XML template as used for task descriptions and return the result, converted to JSON. @@ -83,7 +84,7 @@ def fill_template( json_template = convert_template_xml_to_json(template) return cast( "dict[str, JSON]", - _fill_json_template( + await _fill_json_template( json_template, task, task_inputs, @@ -93,21 +94,22 @@ def fill_template( ) -def _fill_json_template( +async def _fill_json_template( template: JSON, task: RowMapping, task_inputs: dict[str, str | int], fetched_data: dict[str, str], - connection: Connection, + connection: AsyncConnection, ) -> JSON: if isinstance(template, dict): return { - k: _fill_json_template(v, task, task_inputs, fetched_data, connection) + k: await _fill_json_template(v, task, task_inputs, fetched_data, connection) for k, v in template.items() } if isinstance(template, list): return [ - _fill_json_template(v, task, task_inputs, fetched_data, connection) for v in template + await _fill_json_template(v, task, task_inputs, fetched_data, connection) + for v in template ] if not isinstance(template, str): msg = f"Unexpected type for `template`: {template=}, {type(template)=}" @@ -125,7 +127,7 @@ def _fill_json_template( (field,) = match.groups() if field not in fetched_data: table, _ = field.split(".") - rows = connection.execute( + result = await connection.execute( text( f""" SELECT * @@ -137,7 +139,8 @@ def _fill_json_template( # quotes which is not legal. 
parameters={"id_": int(task_inputs[table])}, ) - for column, value in next(rows.mappings()).items(): + rows = result.mappings() + for column, value in next(rows).items(): fetched_data[f"{table}.{column}"] = value if match.string == template: return fetched_data[field] @@ -150,13 +153,14 @@ def _fill_json_template( @router.get("/{task_id}") -def get_task( +async def get_task( task_id: int, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> Task: - if not (task := database.tasks.get(task_id, expdb)): + assert expdb is not None # noqa: S101 + if not (task := await database.tasks.get(task_id, expdb)): raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Task not found") - if not (task_type := database.tasks.get_task_type(task.ttid, expdb)): + if not (task_type := await database.tasks.get_task_type(task.ttid, expdb)): raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Task type not found", @@ -164,12 +168,12 @@ def get_task( task_inputs = { row.input: int(row.value) if row.value.isdigit() else row.value - for row in database.tasks.get_input_for_task(task_id, expdb) + for row in await database.tasks.get_input_for_task(task_id, expdb) } - ttios = database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb) + ttios = await database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb) templates = [(tt_io.name, tt_io.io, tt_io.requirement, tt_io.template_api) for tt_io in ttios] inputs = [ - fill_template(template, task, task_inputs, expdb) | {"name": name} + await fill_template(template, task, task_inputs, expdb) | {"name": name} for name, io, required, template in templates if io == "input" ] @@ -178,10 +182,10 @@ def get_task( for name, io, required, template in templates if io == "output" ] - tags = database.tasks.get_tags(task_id, expdb) + tags = await database.tasks.get_tags(task_id, expdb) name = f"Task {task_id} 
({task_type.name})" dataset_id = task_inputs.get("source_data") - if isinstance(dataset_id, int) and (dataset := database.datasets.get(dataset_id, expdb)): + if isinstance(dataset_id, int) and (dataset := await database.datasets.get(dataset_id, expdb)): name = f"Task {task_id}: {dataset.name} ({task_type.name})" return Task( diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5213f177..63a1e879 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -3,7 +3,8 @@ from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import Connection, Row +from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection from database.tasks import get_input_for_task_type, get_task_types from database.tasks import get_task_type as db_get_task_type @@ -12,7 +13,7 @@ router = APIRouter(prefix="/tasktype", tags=["tasks"]) -def _normalize_task_type(task_type: Row) -> dict[str, str | None | list[Any]]: +def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any]]: # Task types may contain multi-line fields which have either \r\n or \n line endings ttype: dict[str, str | None | list[Any]] = { k: str(v).replace("\r\n", "\n").strip() if v is not None else v @@ -26,24 +27,25 @@ def _normalize_task_type(task_type: Row) -> dict[str, str | None | list[Any]]: @router.get(path="/list") -def list_task_types( - expdb: Annotated[Connection, Depends(expdb_connection)] = None, +async def list_task_types( + expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, ) -> dict[ Literal["task_types"], dict[Literal["task_type"], list[dict[str, str | None | list[Any]]]], ]: + assert expdb is not None # noqa: S101 task_types: list[dict[str, str | None | list[Any]]] = [ - _normalize_task_type(ttype) for ttype in get_task_types(expdb) + _normalize_task_type(ttype) for ttype in await get_task_types(expdb) ] return {"task_types": 
{"task_type": task_types}} @router.get(path="/{task_type_id}") -def get_task_type( +async def get_task_type( task_type_id: int, - expdb: Annotated[Connection, Depends(expdb_connection)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["task_type"], dict[str, str | None | list[str] | list[dict[str, str]]]]: - task_type_record = db_get_task_type(task_type_id, expdb) + task_type_record = await db_get_task_type(task_type_id, expdb) if task_type_record is None: raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, @@ -60,7 +62,7 @@ def get_task_type( creator.strip(' "') for creator in cast("str", contributors).split(",") ] task_type["creation_date"] = task_type.pop("creationDate") - task_type_inputs = get_input_for_task_type(task_type_id, expdb) + task_type_inputs = await get_input_for_task_type(task_type_id, expdb) input_types = [] for task_type_input in task_type_inputs: input_ = {} diff --git a/tests/conftest.py b/tests/conftest.py index eecc1288..7edfeba4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import contextlib import json -from collections.abc import Iterator +from collections.abc import AsyncIterator, Iterator from pathlib import Path from typing import Any, NamedTuple @@ -10,7 +10,8 @@ from _pytest.config import Config from _pytest.nodes import Item from fastapi.testclient import TestClient -from sqlalchemy import Connection, Engine, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from database.setup import expdb_database, user_database from main import create_api @@ -19,24 +20,24 @@ PHP_API_URL = "http://openml-php-rest-api:80/api/v1/json" -@contextlib.contextmanager -def automatic_rollback(engine: Engine) -> Iterator[Connection]: - with engine.connect() as connection: - transaction = connection.begin() +@contextlib.asynccontextmanager +async def automatic_rollback(engine: AsyncEngine) -> AsyncIterator[AsyncConnection]: + async with 
engine.connect() as connection: + transaction = await connection.begin() yield connection if transaction.is_active: - transaction.rollback() + await transaction.rollback() @pytest.fixture -def expdb_test() -> Connection: - with automatic_rollback(expdb_database()) as connection: +async def expdb_test() -> AsyncIterator[AsyncConnection]: + async with automatic_rollback(expdb_database()) as connection: yield connection @pytest.fixture -def user_test() -> Connection: - with automatic_rollback(user_database()) as connection: +async def user_test() -> AsyncIterator[AsyncConnection]: + async with automatic_rollback(user_database()) as connection: yield connection @@ -47,11 +48,19 @@ def php_api() -> httpx.Client: @pytest.fixture -def py_api(expdb_test: Connection, user_test: Connection) -> TestClient: +def py_api(expdb_test: AsyncConnection, user_test: AsyncConnection) -> TestClient: app = create_api() + # We use the lambda definitions because fixtures may not be called directly. - app.dependency_overrides[expdb_connection] = lambda: expdb_test - app.dependency_overrides[userdb_connection] = lambda: user_test + # The lambda returns an async generator for FastAPI to handle properly + async def override_expdb() -> AsyncIterator[AsyncConnection]: + yield expdb_test + + async def override_userdb() -> AsyncIterator[AsyncConnection]: + yield user_test + + app.dependency_overrides[expdb_connection] = override_expdb + app.dependency_overrides[userdb_connection] = override_userdb return TestClient(app) @@ -76,8 +85,8 @@ class Flow(NamedTuple): @pytest.fixture -def flow(expdb_test: Connection) -> Flow: - expdb_test.execute( +async def flow(expdb_test: AsyncConnection) -> Flow: + await expdb_test.execute( text( """ INSERT INTO implementation(fullname,name,version,external_version,uploadDate) @@ -85,19 +94,20 @@ def flow(expdb_test: Connection) -> Flow: """, ), ) - (flow_id,) = expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")).one() + result = await 
expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")) + (flow_id,) = result.one() return Flow(id=flow_id, name="name", external_version="external_version") @pytest.fixture -def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]: - expdb_test.commit() +async def persisted_flow(flow: Flow, expdb_test: AsyncConnection) -> AsyncIterator[Flow]: + await expdb_test.commit() yield flow # We want to ensure the commit below does not accidentally persist new # data to the database. - expdb_test.rollback() + await expdb_test.rollback() - expdb_test.execute( + await expdb_test.execute( text( """ DELETE FROM implementation @@ -106,7 +116,7 @@ def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]: ), parameters={"flow_id": flow.id}, ) - expdb_test.commit() + await expdb_test.commit() def pytest_collection_modifyitems(config: Config, items: list[Item]) -> None: # noqa: ARG001 diff --git a/tests/database/flows_test.py b/tests/database/flows_test.py index 7c952eaa..a8b98d84 100644 --- a/tests/database/flows_test.py +++ b/tests/database/flows_test.py @@ -1,18 +1,18 @@ -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from tests.conftest import Flow -def test_database_flow_exists(flow: Flow, expdb_test: Connection) -> None: - retrieved_flow = database.flows.get_by_name(flow.name, flow.external_version, expdb_test) +async def test_database_flow_exists(flow: Flow, expdb_test: AsyncConnection) -> None: + retrieved_flow = await database.flows.get_by_name(flow.name, flow.external_version, expdb_test) assert retrieved_flow is not None assert retrieved_flow.id == flow.id # when using actual ORM, can instead ensure _all_ fields match. 
-def test_database_flow_exists_returns_none_if_no_match(expdb_test: Connection) -> None: - retrieved_flow = database.flows.get_by_name( +async def test_database_flow_exists_returns_none_if_no_match(expdb_test: AsyncConnection) -> None: + retrieved_flow = await database.flows.get_by_name( name="foo", external_version="bar", expdb=expdb_test, diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index 5449862a..9e3fd447 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,7 +1,7 @@ from http import HTTPStatus import pytest -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from database.datasets import get_tags_for @@ -29,7 +29,7 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> None: +async def test_dataset_tag(key: ApiKey, expdb_test: AsyncConnection, py_api: TestClient) -> None: dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test" response = py_api.post( f"/datasets/tag?api_key={key}", @@ -38,7 +38,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> assert response.status_code == HTTPStatus.OK assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": [tag]}} - tags = get_tags_for(id_=dataset_id, connection=expdb_test) + tags = await get_tags_for(id_=dataset_id, connection=expdb_test) assert tag in tags diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index b463d3d7..51d2613f 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -2,7 +2,7 @@ import pytest from fastapi import HTTPException -from sqlalchemy import Connection +from 
sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from database.users import User @@ -76,12 +76,12 @@ def test_get_dataset(py_api: TestClient) -> None: SOME_USER, ], ) -def test_private_dataset_no_access( +async def test_private_dataset_no_access( user: User | None, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: with pytest.raises(HTTPException) as e: - get_dataset( + await get_dataset( dataset_id=130, user=user, user_db=None, @@ -94,8 +94,10 @@ def test_private_dataset_no_access( @pytest.mark.parametrize( "user", [OWNER_USER, ADMIN_USER, pytest.param(SOME_USER, marks=pytest.mark.xfail)] ) -def test_private_dataset_access(user: User, expdb_test: Connection, user_test: Connection) -> None: - dataset = get_dataset( +async def test_private_dataset_access( + user: User, expdb_test: AsyncConnection, user_test: AsyncConnection +) -> None: + dataset = await get_dataset( dataset_id=130, user=user, user_db=user_test, diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index d5188d0e..d2c76513 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -4,7 +4,7 @@ import pytest from fastapi import HTTPException from pytest_mock import MockerFixture -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from routers.openml.flows import flow_exists @@ -18,14 +18,14 @@ ("c", "d"), ], ) -def test_flow_exists_calls_db_correctly( +async def test_flow_exists_calls_db_correctly( name: str, external_version: str, - expdb_test: Connection, + expdb_test: AsyncConnection, mocker: MockerFixture, ) -> None: mocked_db = mocker.patch("database.flows.get_by_name") - flow_exists(name, external_version, expdb_test) + await flow_exists(name, external_version, expdb_test) mocked_db.assert_called_once_with( name=name, external_version=external_version, @@ -37,24 +37,26 @@ def 
test_flow_exists_calls_db_correctly( "flow_id", [1, 2], ) -def test_flow_exists_processes_found( +async def test_flow_exists_processes_found( flow_id: int, mocker: MockerFixture, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: fake_flow = mocker.MagicMock(id=flow_id) mocker.patch( "database.flows.get_by_name", return_value=fake_flow, ) - response = flow_exists("name", "external_version", expdb_test) + response = await flow_exists("name", "external_version", expdb_test) assert response == {"flow_id": fake_flow.id} -def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None: +async def test_flow_exists_handles_flow_not_found( + mocker: MockerFixture, expdb_test: AsyncConnection +) -> None: mocker.patch("database.flows.get_by_name", return_value=None) with pytest.raises(HTTPException) as error: - flow_exists("foo", "bar", expdb_test) + await flow_exists("foo", "bar", expdb_test) assert error.value.status_code == HTTPStatus.NOT_FOUND assert error.value.detail == "Flow not found." 
diff --git a/tests/routers/openml/qualities_test.py b/tests/routers/openml/qualities_test.py index eed569e9..bb31caa8 100644 --- a/tests/routers/openml/qualities_test.py +++ b/tests/routers/openml/qualities_test.py @@ -3,12 +3,13 @@ import deepdiff import httpx import pytest -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient -def _remove_quality_from_database(quality_name: str, expdb_test: Connection) -> None: - expdb_test.execute( +async def _remove_quality_from_database(quality_name: str, expdb_test: AsyncConnection) -> None: + await expdb_test.execute( text( """ DELETE FROM data_quality @@ -17,7 +18,7 @@ def _remove_quality_from_database(quality_name: str, expdb_test: Connection) -> ), parameters={"deleted_quality": quality_name}, ) - expdb_test.execute( + await expdb_test.execute( text( """ DELETE FROM quality @@ -36,7 +37,7 @@ def test_list_qualities_identical(py_api: TestClient, php_api: httpx.Client) -> # To keep the test idempotent, we cannot test if reaction to database changes is identical -def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None: +async def test_list_qualities(py_api: TestClient, expdb_test: AsyncConnection) -> None: response = py_api.get("/datasets/qualities/list") assert response.status_code == HTTPStatus.OK expected = { @@ -155,7 +156,7 @@ def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None: assert expected == response.json() deleted = expected["data_qualities_list"]["quality"].pop() - _remove_quality_from_database(quality_name=deleted, expdb_test=expdb_test) + await _remove_quality_from_database(quality_name=deleted, expdb_test=expdb_test) response = py_api.get("/datasets/qualities/list") assert response.status_code == HTTPStatus.OK diff --git a/tests/routers/openml/study_test.py b/tests/routers/openml/study_test.py index f32b6b70..2e4d13ef 100644 --- 
a/tests/routers/openml/study_test.py +++ b/tests/routers/openml/study_test.py @@ -2,7 +2,8 @@ from http import HTTPStatus import httpx -from sqlalchemy import Connection, text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from schemas.study import StudyType @@ -501,24 +502,24 @@ def test_create_task_study(py_api: TestClient) -> None: assert new_study == expected -def _attach_tasks_to_study( +async def _attach_tasks_to_study( study_id: int, task_ids: list[int], api_key: str, py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> httpx.Response: # Adding requires the study to be in preparation, # but the current snapshot has no in-preparation studies. - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) return py_api.post( f"/studies/attach?api_key={api_key}", json={"study_id": study_id, "entity_ids": task_ids}, ) -def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> None: - response = _attach_tasks_to_study( +async def test_attach_task_to_study(py_api: TestClient, expdb_test: AsyncConnection) -> None: + response = await _attach_tasks_to_study( study_id=1, task_ids=[2, 3, 4], api_key="AD000000000000000000000000000000", @@ -529,9 +530,11 @@ def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> Non assert response.json() == {"study_id": 1, "main_entity_type": StudyType.TASK} -def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connection) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) - response = _attach_tasks_to_study( +async def test_attach_task_to_study_needs_owner( + py_api: TestClient, expdb_test: AsyncConnection +) -> None: + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + 
response = await _attach_tasks_to_study( study_id=1, task_ids=[2, 3, 4], api_key="00000000000000000000000000000000", @@ -541,12 +544,12 @@ def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connec assert response.status_code == HTTPStatus.FORBIDDEN -def test_attach_task_to_study_already_linked_raises( +async def test_attach_task_to_study_already_linked_raises( py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) - response = _attach_tasks_to_study( + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + response = await _attach_tasks_to_study( study_id=1, task_ids=[1, 3, 4], api_key="AD000000000000000000000000000000", @@ -557,12 +560,12 @@ def test_attach_task_to_study_already_linked_raises( assert response.json() == {"detail": "Task 1 is already attached to study 1."} -def test_attach_task_to_study_but_task_not_exist_raises( +async def test_attach_task_to_study_but_task_not_exist_raises( py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: - expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) - response = _attach_tasks_to_study( + await expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1")) + response = await _attach_tasks_to_study( study_id=1, task_ids=[80123, 78914], api_key="AD000000000000000000000000000000", diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py index 7ce97680..ce6fe0c5 100644 --- a/tests/routers/openml/users_test.py +++ b/tests/routers/openml/users_test.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User from routers.dependencies import fetch_user @@ -14,14 +14,14 @@ (ApiKey.SOME_USER, SOME_USER), ], ) -def test_fetch_user(api_key: str, user: 
User, user_test: Connection) -> None: - db_user = fetch_user(api_key, user_data=user_test) +async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None: + db_user = await fetch_user(api_key, user_data=user_test) assert db_user is not None assert user.user_id == db_user.user_id - assert user.groups == db_user.groups + assert user._groups == db_user._groups # noqa: SLF001 -def test_fetch_user_invalid_key_returns_none(user_test: Connection) -> None: - assert fetch_user(api_key=None, user_data=user_test) is None +async def test_fetch_user_invalid_key_returns_none(user_test: AsyncConnection) -> None: + assert await fetch_user(api_key=None, user_data=user_test) is None invalid_key = "f" * 32 - assert fetch_user(api_key=invalid_key, user_data=user_test) is None + assert await fetch_user(api_key=invalid_key, user_data=user_test) is None From cf5fc3f25856399250dd212119543d12376b55d5 Mon Sep 17 00:00:00 2001 From: rohansen856 Date: Mon, 2 Mar 2026 15:04:26 +0530 Subject: [PATCH 7/9] chore: fixed acc to github reviews Signed-off-by: rohansen856 --- src/database/setup.py | 8 ++++---- src/database/studies.py | 23 ++++++++++++++++++++++- src/database/users.py | 10 +++++----- src/routers/dependencies.py | 6 ++---- src/routers/mldcat_ap/dataset.py | 15 ++++++++++----- src/routers/openml/datasets.py | 10 +++++----- src/routers/openml/flows.py | 16 ++++++++-------- src/routers/openml/study.py | 6 +++--- src/routers/openml/tasks.py | 11 +++++++---- src/routers/openml/tasktype.py | 3 +-- tests/conftest.py | 9 +++++---- tests/routers/openml/flows_test.py | 12 ++++++++++-- tests/routers/openml/users_test.py | 2 +- 13 files changed, 83 insertions(+), 48 deletions(-) diff --git a/src/database/setup.py b/src/database/setup.py index 6f7e3017..6322415e 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -9,12 +9,12 @@ def _create_engine(database_name: str) -> AsyncEngine: database_configuration = load_database_configuration() - echo = 
database_configuration[database_name].pop("echo", False) + db_config = dict(database_configuration[database_name]) + echo = db_config.pop("echo", False) - # Update driver to use aiomysql for async support - database_configuration[database_name]["drivername"] = "mysql+aiomysql" + db_config["drivername"] = "mysql+aiomysql" - db_url = URL.create(**database_configuration[database_name]) + db_url = URL.create(**db_config) return create_async_engine( db_url, echo=echo, diff --git a/src/database/studies.py b/src/database/studies.py index 837e2bd2..0c11a531 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -174,4 +174,25 @@ async def attach_runs( user: User, connection: AsyncConnection, ) -> None: - raise NotImplementedError + to_link = [(study_id, run_id, user.user_id) for run_id in run_ids] + try: + await connection.execute( + text( + """ + INSERT INTO run_study (study_id, run_id, uploader) + VALUES (:study_id, :run_id, :user_id) + """, + ), + parameters=[{"study_id": s, "run_id": r, "user_id": u} for s, r, u in to_link], + ) + except Exception as e: + (msg,) = e.args + if match := re.search(r"Duplicate entry '(\d+)-(\d+)' for key 'run_study.PRIMARY'", msg): + msg = f"Run {match.group(2)} is already attached to study {match.group(1)}." + elif "a foreign key constraint fails" in msg: + msg = "One or more of the runs do not exist." + elif "Out of range value for column 'run_id'" in msg: + msg = "One specified id is not in the valid range of run ids." 
+ else: + raise + raise ValueError(msg) from e diff --git a/src/database/users.py b/src/database/users.py index 91d97e7e..78a759b8 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -31,7 +31,7 @@ async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> in return user.id if user else None -async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[UserGroup]: +async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[int]: row = await connection.execute( text( """ @@ -43,7 +43,7 @@ async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> l parameters={"user_id": user_id}, ) rows = row.all() - return [UserGroup(group) for (group,) in rows] + return [group for (group,) in rows] @dataclasses.dataclass @@ -54,12 +54,12 @@ class User: @classmethod async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: - if user_id := await get_user_id_for(api_key=api_key, connection=user_db): + if (user_id := await get_user_id_for(api_key=api_key, connection=user_db)) is not None: return cls(user_id, _database=user_db) return None async def get_groups(self) -> list[UserGroup]: if self._groups is None: - groups = await get_user_groups_for(user_id=self.user_id, connection=self._database) - self._groups = [UserGroup(group_id) for group_id in groups] + group_ids = await get_user_groups_for(user_id=self.user_id, connection=self._database) + self._groups = [UserGroup(group_id) for group_id in group_ids] return self._groups diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index d9bfe76a..f73ac995 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -11,16 +11,14 @@ async def expdb_connection() -> AsyncGenerator[AsyncConnection, None]: engine = expdb_database() - async with engine.connect() as connection: + async with engine.connect() as connection, connection.begin(): yield connection - await connection.commit() 
async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]: engine = user_database() - async with engine.connect() as connection: + async with engine.connect() as connection, connection.begin(): yield connection - await connection.commit() async def fetch_user( diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py index 61c16b42..940f36b6 100644 --- a/src/routers/mldcat_ap/dataset.py +++ b/src/routers/mldcat_ap/dataset.py @@ -40,8 +40,8 @@ async def get_mldcat_ap_distribution( distribution_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> JsonLDGraph: assert user_db is not None # noqa: S101 assert expdb is not None # noqa: S101 @@ -144,11 +144,16 @@ async def get_distribution_quality( quality_name: str, distribution_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> JsonLDGraph: assert expdb is not None # noqa: S101 qualities = await get_qualities(distribution_id, user, expdb) - quality = next(q for q in qualities if q.name == quality_name) + quality = next((q for q in qualities if q.name == quality_name), None) + if quality is None: + raise HTTPException( + status_code=404, + detail=f"Quality '{quality_name}' not found for distribution {distribution_id}.", + ) example_quality = Quality( id_=f"{_server_url}/quality/{quality_name}/{distribution_id}", quality_type=f"{_server_url}/quality/{quality_name}", @@ -171,7 +176,7 @@ async def get_distribution_feature( distribution_id: int, feature_no: int, user: Annotated[User | None, 
Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> JsonLDGraph: assert expdb is not None # noqa: S101 features = await get_dataset_features( diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 39eedc72..a578102d 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -34,7 +34,7 @@ async def tag_dataset( data_id: Annotated[int, Body()], tag: Annotated[str, SystemString64], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> dict[str, dict[str, Any]]: assert expdb_db is not None # noqa: S101 tags = await database.datasets.get_tags_for(data_id, expdb_db) @@ -102,7 +102,7 @@ async def list_datasets( # noqa: PLR0913 number_missing_values: Annotated[str | None, IntegerRange] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: assert expdb_db is not None # noqa: S101 current_status = text( @@ -293,7 +293,7 @@ async def _get_dataset_raise_otherwise( async def get_dataset_features( dataset_id: int, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[Feature]: assert expdb is not None # noqa: S101 await _get_dataset_raise_otherwise(dataset_id, user, expdb) @@ -393,8 +393,8 @@ async def update_dataset_status( async def get_dataset( dataset_id: int, user: Annotated[User | None, 
Depends(fetch_user)] = None, - user_db: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, - expdb_db: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> DatasetMetadata: assert user_db is not None # noqa: S101 assert expdb_db is not None # noqa: S101 diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 5ed71d83..c6177243 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -35,9 +35,8 @@ async def flow_exists( @router.get("/{flow_id}") async def get_flow( flow_id: int, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> Flow: - assert expdb is not None # noqa: S101 flow = await database.flows.get(flow_id, expdb) if not flow: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Flow not found") @@ -58,13 +57,14 @@ async def get_flow( tags = await database.flows.get_tags(flow_id, expdb) subflow_rows = await database.flows.get_subflows(flow_id, expdb) - subflows = [ - Subflow( - identifier=subflow.identifier, - flow=await get_flow(flow_id=subflow.child_id, expdb=expdb), + subflows = [] + for subflow in subflow_rows: + subflows.append( # noqa: PERF401 + Subflow( + identifier=subflow.identifier, + flow=await get_flow(flow_id=subflow.child_id, expdb=expdb), + ), ) - for subflow in subflow_rows - ] return Flow( id_=flow.id, diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 7175da43..6c7b0019 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -54,7 +54,7 @@ async def attach_to_study( study_id: Annotated[int, Body()], entity_ids: Annotated[list[int], Body()], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection | None, 
Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: assert expdb is not None # noqa: S101 if user is None: @@ -96,7 +96,7 @@ async def attach_to_study( async def create_study( study: CreateStudy, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> dict[Literal["study_id"], int]: assert expdb is not None # noqa: S101 if user is None: @@ -139,7 +139,7 @@ async def create_study( async def get_study( alias_or_id: int | str, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Study: assert expdb is not None # noqa: S101 study = await _get_study_raise_otherwise(alias_or_id, user, expdb) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 83ebf719..505bd69b 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -94,7 +94,7 @@ async def fill_template( ) -async def _fill_json_template( +async def _fill_json_template( # noqa: C901 template: JSON, task: RowMapping, task_inputs: dict[str, str | int], @@ -140,7 +140,11 @@ async def _fill_json_template( parameters={"id_": int(task_inputs[table])}, ) rows = result.mappings() - for column, value in next(rows).items(): + row_data = next(rows, None) + if row_data is None: + msg = f"No data found for table {table} with id {task_inputs[table]}" + raise ValueError(msg) + for column, value in row_data.items(): fetched_data[f"{table}.{column}"] = value if match.string == template: return fetched_data[field] @@ -155,9 +159,8 @@ async def _fill_json_template( @router.get("/{task_id}") async def get_task( task_id: int, - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + 
expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> Task: - assert expdb is not None # noqa: S101 if not (task := await database.tasks.get(task_id, expdb)): raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Task not found") if not (task_type := await database.tasks.get_task_type(task.ttid, expdb)): diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 63a1e879..17e52018 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -28,12 +28,11 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any @router.get(path="/list") async def list_task_types( - expdb: Annotated[AsyncConnection | None, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[ Literal["task_types"], dict[Literal["task_type"], list[dict[str, str | None | list[Any]]]], ]: - assert expdb is not None # noqa: S101 task_types: list[dict[str, str | None | list[Any]]] = [ _normalize_task_type(ttype) for ttype in await get_task_types(expdb) ] diff --git a/tests/conftest.py b/tests/conftest.py index 7edfeba4..2f842f27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,11 +48,11 @@ def php_api() -> httpx.Client: @pytest.fixture -def py_api(expdb_test: AsyncConnection, user_test: AsyncConnection) -> TestClient: +def py_api(expdb_test: AsyncConnection, user_test: AsyncConnection) -> Iterator[TestClient]: app = create_api() - # We use the lambda definitions because fixtures may not be called directly. - # The lambda returns an async generator for FastAPI to handle properly + # We use async generator functions because fixtures may not be called directly. 
+ # The async generator returns the test connections for FastAPI to handle properly async def override_expdb() -> AsyncIterator[AsyncConnection]: yield expdb_test @@ -61,7 +61,8 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]: app.dependency_overrides[expdb_connection] = override_expdb app.dependency_overrides[userdb_connection] = override_userdb - return TestClient(app) + with TestClient(app) as client: + yield client @pytest.fixture diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index d2c76513..e2174c45 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -24,7 +24,10 @@ async def test_flow_exists_calls_db_correctly( expdb_test: AsyncConnection, mocker: MockerFixture, ) -> None: - mocked_db = mocker.patch("database.flows.get_by_name") + mocked_db = mocker.patch( + "database.flows.get_by_name", + new_callable=mocker.AsyncMock, + ) await flow_exists(name, external_version, expdb_test) mocked_db.assert_called_once_with( name=name, @@ -45,6 +48,7 @@ async def test_flow_exists_processes_found( fake_flow = mocker.MagicMock(id=flow_id) mocker.patch( "database.flows.get_by_name", + new_callable=mocker.AsyncMock, return_value=fake_flow, ) response = await flow_exists("name", "external_version", expdb_test) @@ -54,7 +58,11 @@ async def test_flow_exists_processes_found( async def test_flow_exists_handles_flow_not_found( mocker: MockerFixture, expdb_test: AsyncConnection ) -> None: - mocker.patch("database.flows.get_by_name", return_value=None) + mocker.patch( + "database.flows.get_by_name", + new_callable=mocker.AsyncMock, + return_value=None, + ) with pytest.raises(HTTPException) as error: await flow_exists("foo", "bar", expdb_test) assert error.value.status_code == HTTPStatus.NOT_FOUND diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py index ce6fe0c5..edd23dac 100644 --- a/tests/routers/openml/users_test.py +++ 
b/tests/routers/openml/users_test.py @@ -18,7 +18,7 @@ async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) db_user = await fetch_user(api_key, user_data=user_test) assert db_user is not None assert user.user_id == db_user.user_id - assert user._groups == db_user._groups # noqa: SLF001 + assert await user.get_groups() == await db_user.get_groups() async def test_fetch_user_invalid_key_returns_none(user_test: AsyncConnection) -> None: From 04b4b229171e45af92ed147415e758471e8ca5bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 09:55:05 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/main.py | 2 +- tests/routers/openml/datasets_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main.py b/src/main.py index 8efa21ff..3e50a38b 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,7 @@ import argparse +import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -import logging import uvicorn from fastapi import FastAPI diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index 0ebf4ba7..ff239135 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -2,8 +2,8 @@ import pytest from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection from starlette.testclient import TestClient from database.users import User From d4f33a6eb94e25ed640569cd5c7fc8d2e61db163 Mon Sep 17 00:00:00 2001 From: rohansen856 Date: Mon, 2 Mar 2026 15:36:57 +0530 Subject: [PATCH 9/9] chore: fixed ci errors Signed-off-by: rohansen856 --- tests/routers/openml/datasets_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index ff239135..446c431f 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -275,15 +275,15 @@ def test_dataset_status_unauthorized( assert response.status_code == HTTPStatus.FORBIDDEN -def test_dataset_no_500_with_multiple_processing_entries( +async def test_dataset_no_500_with_multiple_processing_entries( py_api: TestClient, - expdb_test: Connection, + expdb_test: AsyncConnection, ) -> None: """Regression test for issue #145: multiple processing entries caused 500.""" - expdb_test.execute( + await expdb_test.execute( text("INSERT INTO evaluation_engine(id, name, description) VALUES (99, 'test_engine', '')"), ) - expdb_test.execute( + await expdb_test.execute( text( "INSERT INTO data_processed(did, evaluation_engine_id, user_id, processing_date) " "VALUES (1, 99, 2, '2020-01-01 00:00:00')",