Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"pydantic",
"uvicorn",
"sqlalchemy",
"mysqlclient",
"aiomysql",
"python_dotenv",
"xmltodict",
]
Expand All @@ -27,6 +27,7 @@ dev = [
"coverage",
"pre-commit",
"pytest",
"pytest-asyncio",
"pytest-mock",
"httpx",
"hypothesis",
Expand Down Expand Up @@ -77,6 +78,8 @@ plugins = [
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
pythonpath = [
"src"
]
Expand Down
2 changes: 1 addition & 1 deletion src/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ root_path=""
host="openml-test-database"
port="3306"
# SQLAlchemy `dialect` and `driver`: https://docs.sqlalchemy.org/en/20/dialects/index.html
drivername="mysql"
drivername="mysql+aiomysql"

[databases.expdb]
database="openml_expdb"
Expand Down
13 changes: 9 additions & 4 deletions src/core/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from schemas.datasets.openml import Visibility


def _user_has_access(
async def _user_has_access(
dataset: Row,
user: User | None = None,
) -> bool:
"""Determine if `user` has the right to view `dataset`."""
is_public = dataset.visibility == Visibility.PUBLIC
return is_public or (
user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups)
)
if is_public:
return True
if user is None:
return False
if user.user_id == dataset.uploader:
return True
groups = await user.get_groups()
return UserGroup.ADMIN in groups
56 changes: 31 additions & 25 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import datetime

from sqlalchemy import Connection, text
from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection

from schemas.datasets.openml import Feature


def get(id_: int, connection: Connection) -> Row | None:
row = connection.execute(
async def get(id_: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -22,8 +23,8 @@ def get(id_: int, connection: Connection) -> Row | None:
return row.one_or_none()


def get_file(*, file_id: int, connection: Connection) -> Row | None:
row = connection.execute(
async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -36,8 +37,8 @@ def get_file(*, file_id: int, connection: Connection) -> Row | None:
return row.one_or_none()


def get_tags_for(id_: int, connection: Connection) -> list[str]:
rows = connection.execute(
async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:
rows = await connection.execute(
text(
"""
SELECT *
Expand All @@ -50,8 +51,8 @@ def get_tags_for(id_: int, connection: Connection) -> list[str]:
return [row.tag for row in rows]


def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
connection.execute(
async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
INSERT INTO dataset_tag(`id`, `tag`, `uploader`)
Expand All @@ -66,12 +67,12 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
)


def get_description(
async def get_description(
id_: int,
connection: Connection,
connection: AsyncConnection,
) -> Row | None:
"""Get the most recent description for the dataset."""
row = connection.execute(
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -85,9 +86,9 @@ def get_description(
return row.first()


def get_status(id_: int, connection: Connection) -> Row | None:
async def get_status(id_: int, connection: AsyncConnection) -> Row | None:
"""Get most recent status for the dataset."""
row = connection.execute(
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -101,8 +102,8 @@ def get_status(id_: int, connection: Connection) -> Row | None:
return row.first()


def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -116,8 +117,8 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row
return row.one_or_none()


def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
rows = connection.execute(
async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]:
rows = await connection.execute(
text(
"""
SELECT `index`,`name`,`data_type`,`is_target`,
Expand All @@ -131,8 +132,13 @@ def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
return [Feature(**row, nominal_values=None) for row in rows.mappings()]


def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]:
rows = connection.execute(
async def get_feature_values(
dataset_id: int,
*,
feature_index: int,
connection: AsyncConnection,
) -> list[str]:
rows = await connection.execute(
text(
"""
SELECT `value`
Expand All @@ -145,14 +151,14 @@ def get_feature_values(dataset_id: int, *, feature_index: int, connection: Conne
return [row.value for row in rows]


def update_status(
async def update_status(
dataset_id: int,
status: str,
*,
user_id: int,
connection: Connection,
connection: AsyncConnection,
) -> None:
connection.execute(
await connection.execute(
text(
"""
INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`)
Expand All @@ -168,8 +174,8 @@ def update_status(
)


def remove_deactivated_status(dataset_id: int, connection: Connection) -> None:
connection.execute(
async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
DELETE FROM dataset_status
Expand Down
21 changes: 12 additions & 9 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure


def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
return cast(
"Sequence[Row]",
connection.execute(
text(
"""
(
await connection.execute(
text(
"""
SELECT *
FROM math_function
WHERE `functionType` = :function_type
""",
),
parameters={"function_type": function_type},
),
parameters={"function_type": function_type},
)
).all(),
)


def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
rows = connection.execute(
async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]:
rows = await connection.execute(
text(
"""
SELECT `id` as 'id_', `ttid` as 'task_type_id', `name`, `type` as 'type_',
Expand Down
65 changes: 32 additions & 33 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection


def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
return cast(
"Sequence[Row]",
expdb.execute(
text(
"""
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
),
parameters={"flow_id": for_flow},
async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
result = await expdb.execute(
text(
"""
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
),
parameters={"flow_id": for_flow},
)
return cast("Sequence[Row]", result.all())


def get_tags(flow_id: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
tag_rows = await expdb.execute(
text(
"""
SELECT tag
Expand All @@ -34,25 +33,23 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]:
return [tag.tag for tag in tag_rows]


def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
"Sequence[Row]",
expdb.execute(
text(
"""
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
),
parameters={"flow_id": flow_id},
async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
result = await expdb.execute(
text(
"""
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
),
parameters={"flow_id": flow_id},
)
return cast("Sequence[Row]", result.all())


def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None:
async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
"""Gets flow by name and external version."""
return expdb.execute(
result = await expdb.execute(
text(
"""
SELECT *, uploadDate as upload_date
Expand All @@ -61,11 +58,12 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No
""",
),
parameters={"name": name, "external_version": external_version},
).one_or_none()
)
return result.one_or_none()


def get(id_: int, expdb: Connection) -> Row | None:
return expdb.execute(
async def get(id_: int, expdb: AsyncConnection) -> Row | None:
result = await expdb.execute(
text(
"""
SELECT *, uploadDate as upload_date
Expand All @@ -74,4 +72,5 @@ def get(id_: int, expdb: Connection) -> Row | None:
""",
),
parameters={"flow_id": id_},
).one_or_none()
)
return result.one_or_none()
Loading