Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from typing import Any, Iterable

from schemas.datasets.openml import Quality
from schemas.datasets.openml import Feature, Quality
from sqlalchemy import Connection, text

from database.meta import get_column_names
Expand Down Expand Up @@ -172,3 +172,32 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> dic
return (
dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None
)


def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Feature]:
    """Load the feature metadata rows for ``dataset_id``.

    The `nominal_values` field is always `None` here; values live in a separate
    table and are filled in by the caller where needed.
    """
    feature_query = text(
        """
            SELECT `index`,`name`,`data_type`,`is_target`,
            `is_row_identifier`,`is_ignore`,`NumberOfMissingValues` as `number_of_missing_values`
            FROM data_feature
            WHERE `did` = :dataset_id
            """,
    )
    result = connection.execute(feature_query, parameters={"dataset_id": dataset_id})
    features = []
    for row in result.mappings():
        features.append(Feature(**row, nominal_values=None))
    return features


def get_feature_values(dataset_id: int, feature_index: int, connection: Connection) -> list[str]:
    """Fetch the stored values for one (nominal) feature of a dataset."""
    value_query = text(
        """
            SELECT `value`
            FROM data_feature_value
            WHERE `did` = :dataset_id AND `index` = :feature_index
            """,
    )
    result = connection.execute(
        value_query,
        parameters={"dataset_id": dataset_id, "feature_index": feature_index},
    )
    values: list[str] = []
    for entry in result:
        values.append(entry.value)
    return values
65 changes: 56 additions & 9 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
)
from database.datasets import get_dataset as db_get_dataset
from database.datasets import (
get_feature_values,
get_features_for_dataset,
get_file,
get_latest_dataset_description,
get_latest_processing_update,
Expand All @@ -29,7 +31,7 @@
from database.datasets import tag_dataset as db_tag_dataset
from database.users import User, UserGroup
from fastapi import APIRouter, Body, Depends, HTTPException
from schemas.datasets.openml import DatasetMetadata, DatasetStatus
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
from sqlalchemy import Connection, text

from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
Expand Down Expand Up @@ -263,6 +265,58 @@ def _get_processing_information(dataset_id: int, connection: Connection) -> Proc
return ProcessingInformation(date=date_processed, warning=warning, error=error)


def _get_dataset_raise_otherwise(
    dataset_id: int,
    user: User | None,
    expdb: Connection,
) -> dict[str, Any]:
    """Fetches the dataset from the database if it exists and the user has permissions.

    Raises HTTPException if the dataset does not exist or the user can not access it.
    """
    dataset = db_get_dataset(dataset_id, expdb)
    if not dataset:
        raise HTTPException(
            status_code=http.client.NOT_FOUND,
            detail=_format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset"),
        )
    if not _user_has_access(dataset=dataset, user=user):
        raise HTTPException(
            status_code=http.client.FORBIDDEN,
            detail=_format_error(code=DatasetError.NO_ACCESS, message="No access granted"),
        )
    return dataset


@router.get("/features/{dataset_id}", response_model_exclude_none=True)
def get_dataset_features(
    dataset_id: int,
    user: Annotated[User | None, Depends(fetch_user)] = None,
    expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> list[Feature]:
    """Return feature metadata for the dataset, or raise 412 when none is available."""
    _get_dataset_raise_otherwise(dataset_id, user, expdb)
    features = get_features_for_dataset(dataset_id, expdb)

    if not features:
        # Distinguish "not processed yet", "processed with error", and
        # "processed but no features" via the latest processing record.
        processing_state = get_latest_processing_update(dataset_id, expdb)
        if processing_state is None:
            code = 273
            msg = "Dataset not processed yet. The dataset was not processed yet, features are not yet available. Please wait for a few minutes."  # noqa: E501
        elif processing_state.get("error"):
            code = 274
            msg = "No features found. Additionally, dataset processed with error"
        else:
            code = 272
            msg = "No features found. The dataset did not contain any features, or we could not extract them."  # noqa: E501
        raise HTTPException(
            status_code=http.client.PRECONDITION_FAILED,
            detail={"code": code, "message": msg},
        )

    # Nominal values are stored separately; attach them to nominal features only.
    for feature in features:
        if feature.data_type == FeatureType.NOMINAL:
            feature.nominal_values = get_feature_values(dataset_id, feature.index, expdb)
    return features


@router.get(
path="/{dataset_id}",
description="Get meta-data for dataset with ID `dataset_id`.",
Expand All @@ -273,14 +327,7 @@ def get_dataset(
user_db: Annotated[Connection, Depends(userdb_connection)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> DatasetMetadata:
if not (dataset := db_get_dataset(dataset_id, expdb_db)):
error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)

if not _user_has_access(dataset=dataset, user=user):
error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)

dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db)
if not (dataset_file := get_file(dataset["file_id"], user_db)):
error = _format_error(
code=DatasetError.NO_DATA_FILE,
Expand Down
17 changes: 17 additions & 0 deletions src/schemas/datasets/openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ class Quality(BaseModel):
value: float | None


class FeatureType(StrEnum):
    """Data type of a dataset feature (column), as stored in the database."""

    NUMERIC = "numeric"
    NOMINAL = "nominal"
    STRING = "string"


class Feature(BaseModel):
    """Metadata describing a single feature (column) of a dataset."""

    index: int  # 0-based position of the feature within the dataset
    name: str
    data_type: FeatureType
    is_target: bool
    is_ignore: bool
    is_row_identifier: bool
    number_of_missing_values: int
    # Only populated for nominal features; None otherwise.
    nominal_values: list[str] | None


class DatasetMetadata(BaseModel):
id_: int = Field(json_schema_extra={"example": 1}, alias="id")
visibility: Visibility = Field(json_schema_extra={"example": Visibility.PUBLIC})
Expand Down
86 changes: 86 additions & 0 deletions tests/routers/openml/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest
from starlette.testclient import TestClient

from tests.conftest import ApiKey


@pytest.mark.parametrize(
("dataset_id", "response_code"),
Expand Down Expand Up @@ -58,3 +60,87 @@ def test_private_dataset_owner_access(
def test_private_dataset_admin_access(py_api: TestClient) -> None:
cast(httpx.Response, py_api.get("/v2/datasets/130?api_key=..."))
# test against cached response


def test_dataset_features(py_api: TestClient) -> None:
    # Dataset 4 has both nominal and numerical features, so provides reasonable coverage
    response = py_api.get("/datasets/features/4")
    assert response.status_code == http.client.OK

    # All features of dataset 4 share these flags.
    common = {
        "is_target": False,
        "is_ignore": False,
        "is_row_identifier": False,
        "number_of_missing_values": 0,
    }
    numeric_names = ["left-weight", "left-distance", "right-weight", "right-distance"]
    expected = [
        {"index": position, "name": feature_name, "data_type": "numeric", **common}
        for position, feature_name in enumerate(numeric_names)
    ]
    expected.append(
        {
            "index": 4,
            "name": "class",
            "data_type": "nominal",
            "nominal_values": ["B", "L", "R"],
            "is_target": True,
            "is_ignore": False,
            "is_row_identifier": False,
            "number_of_missing_values": 0,
        },
    )
    assert response.json() == expected


def test_dataset_features_no_access(py_api: TestClient) -> None:
    # Dataset 130 is private; an anonymous request must be rejected.
    forbidden = py_api.get("/datasets/features/130")
    assert forbidden.status_code == http.client.FORBIDDEN


@pytest.mark.parametrize(
    "api_key",
    [ApiKey.ADMIN, ApiKey.OWNER_USER],
)
def test_dataset_features_access_to_private(api_key: ApiKey, py_api: TestClient) -> None:
    # Both the admin and the owner may see features of the private dataset 130.
    private_response = py_api.get(f"/datasets/features/130?api_key={api_key}")
    assert private_response.status_code == http.client.OK


def test_dataset_features_with_processing_error(py_api: TestClient) -> None:
    # When a dataset is processed to extract its feature metadata, errors may occur.
    # In that case, no feature information will ever be available.
    response = py_api.get("/datasets/features/55")
    assert response.status_code == http.client.PRECONDITION_FAILED
    expected_detail = {
        "code": 274,
        "message": "No features found. Additionally, dataset processed with error",
    }
    assert response.json()["detail"] == expected_detail


def test_dataset_features_dataset_does_not_exist(py_api: TestClient) -> None:
    # No dataset with id 1000 exists in the test database.
    response = py_api.get("/datasets/features/1000")
    assert response.status_code == http.client.NOT_FOUND
34 changes: 34 additions & 0 deletions tests/routers/openml/migration/datasets_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,37 @@ def test_dataset_tag_response_is_identical(
original = original.json()
new = new.json()
assert original == new


@pytest.mark.php()
@pytest.mark.parametrize(
    "data_id",
    list(range(1, 130)),
)
def test_datasets_feature_is_identical(
    data_id: int,
    py_api: TestClient,
    php_api: httpx.Client,
) -> None:
    """Compare the new feature endpoint against the old PHP API for each dataset."""
    response = py_api.get(f"/datasets/features/{data_id}")
    original = php_api.get(f"/data/features/{data_id}")
    assert response.status_code == original.status_code

    if response.status_code != http.client.OK:
        error = response.json()["detail"]
        error["code"] = str(error["code"])
        assert error == original.json()["error"]
        return

    # Translate the new response into the old API's representation before comparing.
    converted = []
    for feature in response.json():
        legacy: dict = {}
        for key, value in feature.items():
            if key == "nominal_values":
                # The old API uses `nominal_value` instead of `nominal_values`,
                # and returns a str if there is only a single element.
                legacy["nominal_value"] = value[0] if len(value) == 1 else value
            elif isinstance(value, bool):
                # The old API formats bool as string in lower-case
                legacy[key] = str(value).lower()
            else:
                legacy[key] = str(value)
        converted.append(legacy)
    assert converted == original.json()["data_features"]["feature"]