Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"uvicorn",
"sqlalchemy",
"mysqlclient",
"aiomysql",
"python_dotenv",
"xmltodict",
]
Expand All @@ -28,6 +29,7 @@ dev = [
"pre-commit",
"pytest",
"pytest-mock",
"pytest-asyncio",
"httpx",
"hypothesis",
"deepdiff",
Expand Down Expand Up @@ -80,6 +82,7 @@ plugins = [
pythonpath = [
"src"
]
asyncio_mode = "auto"
markers = [
"slow: test or sets of tests which take more than a few seconds to run.",
# While the `mut`ation marker below is not strictly necessary as every change is
Expand Down
15 changes: 10 additions & 5 deletions src/core/access.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from typing import Any

from sqlalchemy.engine import Row

from database.users import User, UserGroup
from schemas.datasets.openml import Visibility


def _user_has_access(
dataset: Row,
async def _user_has_access(
dataset: Row[Any],
user: User | None = None,
) -> bool:
"""Determine if `user` has the right to view `dataset`."""
is_public = dataset.visibility == Visibility.PUBLIC
return is_public or (
user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups)
)
if is_public:
return True
if user is None:
return False
user_groups = await user.get_groups()
return user.user_id == dataset.uploader or UserGroup.ADMIN in user_groups
61 changes: 35 additions & 26 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import datetime

from sqlalchemy import Connection, text
from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection

from schemas.datasets.openml import Feature


def get(id_: int, connection: Connection) -> Row | None:
row = connection.execute(
async def get(id_: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -22,8 +23,8 @@ def get(id_: int, connection: Connection) -> Row | None:
return row.one_or_none()


def get_file(*, file_id: int, connection: Connection) -> Row | None:
row = connection.execute(
async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -36,8 +37,8 @@ def get_file(*, file_id: int, connection: Connection) -> Row | None:
return row.one_or_none()


def get_tags_for(id_: int, connection: Connection) -> list[str]:
rows = connection.execute(
async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -47,11 +48,12 @@ def get_tags_for(id_: int, connection: Connection) -> list[str]:
),
parameters={"dataset_id": id_},
)
rows = row.all()
return [row.tag for row in rows]


def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
connection.execute(
async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
INSERT INTO dataset_tag(`id`, `tag`, `uploader`)
Expand All @@ -66,12 +68,12 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
)


def get_description(
async def get_description(
id_: int,
connection: Connection,
connection: AsyncConnection,
) -> Row | None:
"""Get the most recent description for the dataset."""
row = connection.execute(
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -85,9 +87,9 @@ def get_description(
return row.first()


def get_status(id_: int, connection: Connection) -> Row | None:
async def get_status(id_: int, connection: AsyncConnection) -> Row | None:
"""Get most recent status for the dataset."""
row = connection.execute(
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -101,8 +103,8 @@ def get_status(id_: int, connection: Connection) -> Row | None:
return row.first()


def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None:
row = await connection.execute(
text(
"""
SELECT *
Expand All @@ -116,8 +118,8 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row
return row.first()


def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
rows = connection.execute(
async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]:
row = await connection.execute(
text(
"""
SELECT `index`,`name`,`data_type`,`is_target`,
Expand All @@ -128,11 +130,17 @@ def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
),
parameters={"dataset_id": dataset_id},
)
return [Feature(**row, nominal_values=None) for row in rows.mappings()]
rows = row.mappings().all()
return [Feature(**row, nominal_values=None) for row in rows]


def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]:
rows = connection.execute(
async def get_feature_values(
dataset_id: int,
*,
feature_index: int,
connection: AsyncConnection,
) -> list[str]:
row = await connection.execute(
text(
"""
SELECT `value`
Expand All @@ -142,17 +150,18 @@ def get_feature_values(dataset_id: int, *, feature_index: int, connection: Conne
),
parameters={"dataset_id": dataset_id, "feature_index": feature_index},
)
rows = row.all()
return [row.value for row in rows]


def update_status(
async def update_status(
dataset_id: int,
status: str,
*,
user_id: int,
connection: Connection,
connection: AsyncConnection,
) -> None:
connection.execute(
await connection.execute(
text(
"""
INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`)
Expand All @@ -168,8 +177,8 @@ def update_status(
)


def remove_deactivated_status(dataset_id: int, connection: Connection) -> None:
connection.execute(
async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
DELETE FROM dataset_status
Expand Down
29 changes: 16 additions & 13 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure


def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
return cast(
"Sequence[Row]",
connection.execute(
text(
"""
async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
rows = await connection.execute(
text(
"""
SELECT *
FROM math_function
WHERE `functionType` = :function_type
""",
),
parameters={"function_type": function_type},
).all(),
),
parameters={"function_type": function_type},
)
return cast(
"Sequence[Row]",
rows.all(),
)


def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
rows = connection.execute(
async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]:
row = await connection.execute(
text(
"""
SELECT `id` as 'id_', `ttid` as 'task_type_id', `name`, `type` as 'type_',
Expand All @@ -33,11 +35,12 @@ def get_estimation_procedures(connection: Connection) -> list[EstimationProcedur
""",
),
)
rows = row.mappings().all()
typed_rows = [
{
k: v if k != "stratified_sampling" or v is None else _str_to_bool(v)
for k, v in row.items()
}
for row in rows.mappings()
for row in rows
]
return [EstimationProcedure(**typed_row) for typed_row in typed_rows]
56 changes: 31 additions & 25 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection


def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
return cast(
"Sequence[Row]",
expdb.execute(
text(
"""
async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
),
parameters={"flow_id": for_flow},
),
parameters={"flow_id": for_flow},
)
return cast(
"Sequence[Row]",
rows.all(),
)


def get_tags(flow_id: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
text(
"""
SELECT tag
Expand All @@ -31,28 +33,30 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]:
),
parameters={"flow_id": flow_id},
)
tag_rows = rows.all()
return [tag.tag for tag in tag_rows]


def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
"Sequence[Row]",
expdb.execute(
text(
"""
async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
"""
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)
return cast(
"Sequence[Row]",
rows.all(),
)


def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None:
async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
"""Gets flow by name and external version."""
return expdb.execute(
row = await expdb.execute(
text(
"""
SELECT *, uploadDate as upload_date
Expand All @@ -61,11 +65,12 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No
""",
),
parameters={"name": name, "external_version": external_version},
).one_or_none()
)
return row.one_or_none()


def get(id_: int, expdb: Connection) -> Row | None:
return expdb.execute(
async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
SELECT *, uploadDate as upload_date
Expand All @@ -74,4 +79,5 @@ def get(id_: int, expdb: Connection) -> Row | None:
""",
),
parameters={"flow_id": id_},
).one_or_none()
)
return row.one_or_none()
Loading