diff --git a/src/core/evaluation.py b/src/core/evaluation.py new file mode 100644 index 00000000..56a2bcaa --- /dev/null +++ b/src/core/evaluation.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import math +from collections.abc import Sequence + + +def accuracy(y_true: Sequence[str | int], y_pred: Sequence[str | int]) -> float: + """Fraction of predictions that exactly match the ground truth.""" + if len(y_true) != len(y_pred): + msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" + raise ValueError(msg) + if not y_true: + return 0.0 + correct = sum(t == p for t, p in zip(y_true, y_pred, strict=True)) + return correct / len(y_true) + + +def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float: + """Root Mean Squared Error.""" + if len(y_true) != len(y_pred): + msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" + raise ValueError(msg) + if not y_true: + return 0.0 + mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) + return math.sqrt(mse) + + +def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float: + """Mean Absolute Error.""" + if len(y_true) != len(y_pred): + msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" + raise ValueError(msg) + if not y_true: + return 0.0 + return sum(abs(t - p) for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) + + +def auc(y_true: Sequence[int], y_score: Sequence[float]) -> float: + """Binary ROC AUC via an O(n log n) rank-based Mann-Whitney U statistic. + + Mathematically equivalent to the area under the ROC curve. + + y_true: sequence of 0/1 ground-truth labels. + y_score: sequence of predicted probabilities for the positive class (label=1). + + Raises: + ValueError: if y_true contains values outside {0, 1} or lengths differ. 
+ """ + if len(y_true) != len(y_score): + msg = f"Length mismatch: {len(y_true)} vs {len(y_score)}" + raise ValueError(msg) + if not y_true: + return 0.0 + + unique = set(y_true) + invalid = unique - {0, 1} + if invalid: + msg = f"y_true must contain only 0/1 labels; found {invalid}" + raise ValueError(msg) + + n_pos = sum(y_true) + n_neg = len(y_true) - n_pos + if n_pos == 0 or n_neg == 0: + return 0.0 + + # O(n log n): rank all scores, then use the rank-sum formula + pairs = sorted(zip(y_score, y_true, strict=False), key=lambda x: x[0]) + + n = len(pairs) + ranks: list[float] = [0.0] * n + i = 0 + while i < n: + j = i + while j < n - 1 and pairs[j][0] == pairs[j + 1][0]: + j += 1 + mid_rank = (i + j) / 2 + 1 # 1-indexed + for k in range(i, j + 1): + ranks[k] = mid_rank + i = j + 1 + + rank_sum_pos = sum(ranks[k] for k in range(n) if pairs[k][1] == 1) + u_pos = rank_sum_pos - n_pos * (n_pos + 1) / 2 + return u_pos / (n_pos * n_neg) + + +#: Task type IDs from the OpenML schema +TASK_TYPE_SUPERVISED_CLASSIFICATION = 1 +TASK_TYPE_SUPERVISED_REGRESSION = 2 + + +def compute_metrics( + task_type_id: int, + y_true: Sequence[str | int | float], + y_pred: Sequence[str | int | float], + y_score: Sequence[float] | None = None, +) -> dict[str, float]: + """Compute all applicable metrics for the given task type. + + Returns a dict of {measure_name: value} using the same names found in + the OpenML `math_function` table (e.g. 'predictive_accuracy', + 'area_under_roc_curve'). 
def generate_splits(
    n_samples: int,
    n_folds: int,
    n_repeats: int,
    *,
    seed: int = 0,
) -> list[SplitEntry]:
    """Produce deterministic cross-validation splits as a flat list of rows.

    For every repeat the row ids 0..n_samples-1 are shuffled once with a
    generator seeded by `seed`; a row at shuffled position p is a TEST row of
    fold `p % n_folds` and a TRAIN row of every other fold. Each
    (repeat, fold, rowid) combination yields one entry, so the result holds
    n_repeats * n_folds * n_samples dicts with the keys
    repeat, fold, rowid, type ('TRAIN' or 'TEST').

    Returns an empty list when n_samples is not positive.

    Raises:
        ValueError: if n_folds or n_repeats is not a positive integer.
    """
    if n_folds <= 0:
        msg = f"n_folds must be a positive integer, got {n_folds}"
        raise ValueError(msg)
    if n_repeats <= 0:
        msg = f"n_repeats must be a positive integer, got {n_repeats}"
        raise ValueError(msg)
    if n_samples <= 0:
        return []

    shuffler = random.Random(seed)  # noqa: S311
    collected: list[SplitEntry] = []

    for repeat_no in range(n_repeats):
        # One fresh shuffle per repeat, drawn from the same seeded generator.
        order = list(range(n_samples))
        shuffler.shuffle(order)
        collected.extend(
            {
                "repeat": repeat_no,
                "fold": fold_no,
                "rowid": rowid,
                "type": "TEST" if position % n_folds == fold_no else "TRAIN",
            }
            for fold_no in range(n_folds)
            for position, rowid in enumerate(order)
        )

    return collected
def get_pending(expdb: Connection) -> Sequence[Row]:
    """Claim the currently pending processing_run rows and return only those.

    Steps:
      1. Select the ids of all 'pending' rows with a locking read
         (``FOR UPDATE``).
      2. Flip exactly those ids from 'pending' to 'processing'.
      3. Read the claimed rows back by id, oldest `date` first.

    Unlike reading back "everything whose status is 'processing'", this never
    returns rows that were already in 'processing' before the call (for
    example rows claimed by another worker, or left behind by a crashed one).

    NOTE(review): the locking read only serialises concurrent workers while
    the statements share one open transaction on `expdb`; presumably the
    connection is not in autocommit mode — confirm against how `expdb` is
    created.

    Returns:
        The claimed rows (`run_id`, `status`, `date`); an empty list when
        nothing is pending.
    """
    # Function-scope import: `bindparam` is only needed here and the module's
    # top-level sqlalchemy import does not include it.
    from sqlalchemy import bindparam

    pending_ids = [
        row.run_id
        for row in expdb.execute(
            text(
                """
                SELECT `run_id`
                FROM processing_run
                WHERE `status` = 'pending'
                ORDER BY `date` ASC
                FOR UPDATE
                """,
            ),
        )
    ]
    if not pending_ids:
        return []

    # Restrict the update to the ids selected above; the extra status check
    # keeps rows untouched if they left 'pending' in the meantime.
    expdb.execute(
        text(
            """
            UPDATE processing_run
            SET `status` = 'processing'
            WHERE `status` = 'pending' AND `run_id` IN :run_ids
            """,
        ).bindparams(bindparam("run_ids", expanding=True)),
        parameters={"run_ids": pending_ids},
    )
    return cast(
        "Sequence[Row]",
        expdb.execute(
            text(
                """
                SELECT `run_id`, `status`, `date`
                FROM processing_run
                WHERE `status` = 'processing' AND `run_id` IN :run_ids
                ORDER BY `date` ASC
                """,
            ).bindparams(bindparam("run_ids", expanding=True)),
            parameters={"run_ids": pending_ids},
        ).all(),
    )
def create(
    *,
    task_id: int,
    flow_id: int,
    uploader_id: int,
    setup_string: str | None,
    expdb: Connection,
) -> int:
    """Insert one row into `run` and return the id the database assigned to it.

    The upload time is taken from `datetime.datetime.now()` at call time. The
    new id is read with `SELECT LAST_INSERT_ID()` on the same connection
    immediately after the insert.
    """
    values = {
        "task_id": task_id,
        "flow_id": flow_id,
        "uploader_id": uploader_id,
        "upload_time": datetime.datetime.now(),
        "setup_string": setup_string,
    }
    insert_run = text(
        """
        INSERT INTO run(
            `task_id`, `implementation_id`, `uploader`,
            `upload_time`, `setup_string`
        )
        VALUES (:task_id, :flow_id, :uploader_id, :upload_time, :setup_string)
        """,
    )
    expdb.execute(insert_run, parameters=values)

    (new_run_id,) = expdb.execute(text("SELECT LAST_INSERT_ID()")).one()[:1]
    return int(new_run_id)
primary key (used for rollback on enqueue failure).""" + expdb.execute( + text("DELETE FROM run WHERE `rid` = :run_id"), + parameters={"run_id": run_id}, + ) diff --git a/src/main.py b/src/main.py index 560b4c50..2fe219ae 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.runs import router as runs_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router @@ -55,6 +56,7 @@ def create_api() -> FastAPI: app.include_router(task_router) app.include_router(flows_router) app.include_router(study_router) + app.include_router(runs_router) return app diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py new file mode 100644 index 00000000..456d1eb6 --- /dev/null +++ b/src/routers/openml/runs.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import json +import logging +import shutil +from http import HTTPStatus +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, cast + +import xmltodict +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile + +import database.flows +import database.processing +import database.runs +import database.tasks +from config import load_configuration +from database.users import User +from routers.dependencies import expdb_connection, fetch_user +from schemas.runs import RunDetail, RunEvaluationResult, RunUploadResponse + +if TYPE_CHECKING: + from sqlalchemy import Connection + +router = APIRouter(prefix="/runs", tags=["runs"]) +log = logging.getLogger(__name__) + +_DEFAULT_UPLOAD_DIR = "/tmp/openml_runs" # noqa: S108 + +_OML_NAMESPACE = "http://openml.org/openml" + + +def _parse_run_xml(xml_bytes: bytes) -> dict[str, Any]: + """Parse the run 
@router.post(
    "/",
    summary="Upload a run (predictions + description XML)",
    response_model=RunUploadResponse,
    status_code=HTTPStatus.CREATED,
)
async def upload_run(
    description: Annotated[UploadFile, File(description="Run description XML file")],
    predictions: Annotated[UploadFile, File(description="Predictions ARFF file")],
    user: Annotated[User | None, Depends(fetch_user)] = None,
    expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> RunUploadResponse:
    """Upload a new run.

    Accepts two multipart files:
    - **description**: XML file conforming to the OpenML run description schema
    - **predictions**: ARFF file with per-row predictions
      (columns: row_id, fold, repeat, prediction [, confidence.*])

    On success returns the new `run_id`. The run is immediately enqueued for
    server-side evaluation; metrics will be available after the worker processes it.

    Raises:
        HTTPException: 412 (code 103) when unauthenticated; 422 (codes
            530/531/532) for unparsable XML or a missing/invalid `task_id` /
            `implementation_id`; 404 (codes 201/180) for an unknown task or flow.
        Exception: any failure while storing the predictions file or enqueueing
            is re-raised after the run directory and run row are removed.
    """
    authenticated_user = _require_auth(user)

    run_xml = _parse_run_xml(await description.read())

    # TypeError covers an empty element (value None) and a <run> root that is
    # not a mapping; both are client errors, not server faults.
    try:
        task_id = int(run_xml["task_id"])
    except (KeyError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
            detail={"code": "531", "message": "Missing or invalid task_id in run description"},
        ) from exc

    try:
        flow_id = int(run_xml["implementation_id"])
    except (KeyError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
            detail={
                "code": "532",
                "message": "Missing or invalid implementation_id (flow_id) in run description",
            },
        ) from exc

    setup_string: str | None = run_xml.get("setup_string")

    _require_task(task_id, expdb)
    _require_flow(flow_id, expdb)

    # Read the upload before touching the database so a failed read cannot
    # leave a run row without predictions.
    predictions_bytes = await predictions.read()

    # Store the run row
    run_id = database.runs.create(
        task_id=task_id,
        flow_id=flow_id,
        uploader_id=authenticated_user.user_id,
        setup_string=setup_string,
        expdb=expdb,
    )

    upload_dir: str = load_configuration().get("upload_dir", _DEFAULT_UPLOAD_DIR)
    run_dir = Path(upload_dir) / str(run_id)

    # Persist the predictions for the worker, then enqueue the evaluation.
    # Any failure from here on removes both the directory and the run row so
    # that no orphaned run survives.
    try:
        run_dir.mkdir(parents=True, exist_ok=True)
        (run_dir / "predictions.arff").write_bytes(predictions_bytes)
        database.processing.enqueue(run_id, expdb)
    except Exception:
        log.exception("Failed to store or enqueue run %d; rolling back artifacts.", run_id)
        shutil.rmtree(run_dir, ignore_errors=True)
        database.runs.delete(run_id, expdb)
        raise

    log.info(
        "Run %d uploaded by user %d (task=%d, flow=%d).",
        run_id,
        authenticated_user.user_id,
        task_id,
        flow_id,
    )
    return RunUploadResponse(run_id=run_id)
BaseModel, Field + + +class RunUploadResponse(BaseModel): + """Response returned by POST /runs after a successful upload.""" + + run_id: int = Field( + serialization_alias="run_id", + json_schema_extra={"example": 42}, + ) + + +class RunEvaluationResult(BaseModel): + """A single per-fold or global evaluation measure for a run.""" + + function: str = Field( + json_schema_extra={"example": "predictive_accuracy"}, + description="Name of the evaluation measure (math_function.name in the DB).", + ) + value: float | None = Field( + json_schema_extra={"example": 0.9312}, + ) + # Per-fold values are stored as a JSON array string in the DB. + per_fold: list[float] | None = Field( + default=None, + json_schema_extra={"example": [0.92, 0.94, 0.93]}, + ) + + +class RunDetail(BaseModel): + """Full metadata for a single run, returned by GET /runs/{run_id}.""" + + id_: int = Field(serialization_alias="run_id", json_schema_extra={"example": 42}) + task_id: int = Field(json_schema_extra={"example": 59}) + flow_id: int = Field(json_schema_extra={"example": 1}) + uploader: int = Field(json_schema_extra={"example": 16}) + upload_time: datetime = Field( + json_schema_extra={"example": "2024-01-15T10:30:00"}, + ) + setup_string: str | None = Field( + default=None, + json_schema_extra={"example": "weka.classifiers.trees.J48 -C 0.25 -M 2"}, + ) + tags: list[str] = Field(default_factory=list) + evaluations: list[RunEvaluationResult] = Field(default_factory=list) diff --git a/src/worker/__init__.py b/src/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/worker/evaluator.py b/src/worker/evaluator.py new file mode 100644 index 00000000..6659d3d6 --- /dev/null +++ b/src/worker/evaluator.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import csv +import io +import logging +import urllib.request +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast + +import database.processing +import database.runs +import database.tasks +from config 
import load_configuration, load_routing_configuration +from core.evaluation import compute_metrics +from core.formatting import _format_dataset_url +from core.splits import build_fold_index, parse_arff_splits +from database.datasets import get as get_dataset + +if TYPE_CHECKING: + from sqlalchemy import Connection + +log = logging.getLogger(__name__) + +# Shared default; also set in routers/openml/runs.py — kept in sync via config key +_DEFAULT_UPLOAD_DIR = "/tmp/openml_runs" # noqa: S108 + + +def _parse_predictions_arff(content: str) -> dict[str, list[Any]]: + """Parse an OpenML predictions ARFF using csv.reader to handle quoted values. + + Returns a dict with keys: 'row_id', 'prediction', 'confidence' (optional). + Expected columns: row_id, fold, repeat, prediction [, confidence.*] + """ + result: dict[str, list[Any]] = {"row_id": [], "prediction": [], "confidence": []} + in_data = False + + for line in content.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("%"): + continue + if stripped.upper().startswith("@DATA"): + in_data = True + continue + if not in_data: + continue + + # Use csv.reader so quoted commas and quoted values are handled correctly + try: + (parts,) = csv.reader(io.StringIO(stripped)) + except (ValueError, StopIteration): + continue + parts = [p.strip().strip("'\"") for p in parts] + + try: + row_id = int(parts[0]) + prediction = parts[3] if len(parts) > 3 else parts[-1] # noqa: PLR2004 + confidence = float(parts[4]) if len(parts) > 4 else None # noqa: PLR2004 + except (ValueError, IndexError): + continue + + result["row_id"].append(row_id) + result["prediction"].append(prediction) + result["confidence"].append(confidence) + + return result + + +def _fetch_arff(url: str) -> str: + """Download an ARFF from a URL, returning the decoded text content.""" + with urllib.request.urlopen(url, timeout=30) as resp: # noqa: S310 + return cast("str", resp.read().decode("utf-8", errors="replace")) + + +def 
_parse_dataset_labels( + content: str, + target_attribute: str, +) -> dict[int, str]: + """Parse a dataset ARFF and return a {row_index: label} map for target_attribute. + + Returns an empty dict (with a warning) when the target column is absent. + """ + attr_names: list[str] = [] + data_rows: list[list[str]] = [] + in_data = False + + for line in content.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("%"): + continue + if stripped.upper().startswith("@ATTRIBUTE"): + parts = stripped.split(None, 2) + attr_names.append(parts[1].strip("'\"") if len(parts) >= 2 else "") # noqa: PLR2004 + continue + if stripped.upper().startswith("@DATA"): + in_data = True + continue + if in_data: + try: + (row,) = csv.reader(io.StringIO(stripped)) + data_rows.append([v.strip().strip("'\"") for v in row]) + except (ValueError, StopIteration): + continue + + if target_attribute not in attr_names: + log.warning("Target attribute '%s' not found in dataset.", target_attribute) + return {} + + target_idx = attr_names.index(target_attribute) + return {i: row[target_idx] for i, row in enumerate(data_rows) if target_idx < len(row)} + + +def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR0911, PLR0912, PLR0915 + """Evaluate a single run, store metrics, mark processing entry done/error.""" + run = database.runs.get(run_id, expdb) + if run is None: + log.warning("Run %d not found; skipping.", run_id) + database.processing.mark_error(run_id, "run row not found", expdb) + return + + task_row = database.tasks.get(run.task_id, expdb) + if task_row is None: + database.processing.mark_error(run_id, "task not found", expdb) + return + + task_type_row = database.tasks.get_task_type(task_row.ttid, expdb) + if task_type_row is None: + database.processing.mark_error(run_id, "task type not found", expdb) + return + + task_inputs = { + row.input: int(row.value) if str(row.value).isdigit() else row.value + for row in 
database.tasks.get_input_for_task(run.task_id, expdb) + } + + dataset_id = task_inputs.get("source_data") + target_attr = str(task_inputs.get("target_feature", "class")) + if not isinstance(dataset_id, int): + database.processing.mark_error(run_id, "no source_data task input", expdb) + return + + dataset_row = get_dataset(dataset_id, expdb) + if dataset_row is None: + database.processing.mark_error(run_id, "dataset not found", expdb) + return + dataset_url = str(_format_dataset_url(dataset_row)) + + # Fetch dataset once (not per fold) and build a complete label map + try: + dataset_content = _fetch_arff(dataset_url) + except Exception: + log.exception("Failed to download dataset from %s", dataset_url) + database.processing.mark_error(run_id, "could not fetch dataset", expdb) + return + label_map = _parse_dataset_labels(dataset_content, target_attr) + if not label_map: + database.processing.mark_error(run_id, "target attribute not found in dataset", expdb) + return + + cfg = load_routing_configuration() + task_id = run.task_id + splits_url = f"{cfg.get('server_url', '')}api_splits/get/{task_id}/Task_{task_id}_splits.arff" + try: + splits_content = _fetch_arff(splits_url) + except Exception: + log.exception("Could not fetch splits for task %d", task_id) + database.processing.mark_error(run_id, "could not fetch splits", expdb) + return + + splits = parse_arff_splits(splits_content) + + upload_dir: str = load_configuration().get("upload_dir", _DEFAULT_UPLOAD_DIR) + predictions_path = Path(upload_dir) / str(run_id) / "predictions.arff" + try: + with predictions_path.open(encoding="utf-8") as fh: + predictions_content = fh.read() + except OSError: + log.exception("Could not read predictions file for run %d", run_id) + database.processing.mark_error(run_id, "predictions file not found", expdb) + return + + predictions = _parse_predictions_arff(predictions_content) + pred_map: dict[int, str] = dict( + zip(predictions["row_id"], predictions["prediction"], strict=True), + ) 
def process_pending_runs(expdb: Connection) -> None:
    """Consume all pending processing_run entries and evaluate each one.

    Designed to be called from a cron job or a management CLI command.
    Each run is evaluated independently; an error in one does not halt the
    rest. That includes a failure while recording the error status itself:
    it is logged and the loop moves on to the next run.
    """
    pending = database.processing.get_pending(expdb)
    log.info("Processing %d pending run(s).", len(pending))
    for entry in pending:
        run_id = int(entry.run_id)
        try:
            _evaluate_run(run_id, expdb)
        except Exception:
            log.exception("Unexpected error evaluating run %d", run_id)
            # mark_error talks to the database and can fail too; do not let
            # that abort the remaining runs.
            try:
                database.processing.mark_error(run_id, "unexpected error", expdb)
            except Exception:
                log.exception("Could not record error status for run %d", run_id)
def test_auc_random_classifier() -> None:
    # Every score is identical, so all four samples form one tie group and
    # share the same mid-rank. Each positive/negative pair therefore counts as
    # half a concordant pair and the AUC is exactly 0.5.
    y_true = [1, 0, 1, 0]
    y_score = [0.5, 0.5, 0.5, 0.5]
    assert auc(y_true, y_score) == pytest.approx(0.5)
---------------------------------------------------------------------------
+# compute_metrics dispatcher
+# ---------------------------------------------------------------------------
+
+
+def test_compute_metrics_classification_accuracy() -> None:
+    metrics = compute_metrics(
+        TASK_TYPE_SUPERVISED_CLASSIFICATION,
+        y_true=["A", "A", "B"],
+        y_pred=["A", "B", "B"],  # 2 of 3 correct
+    )
+    assert "predictive_accuracy" in metrics
+    assert metrics["predictive_accuracy"] == pytest.approx(2 / 3)
+
+
+def test_compute_metrics_classification_includes_auc_for_binary() -> None:
+    metrics = compute_metrics(
+        TASK_TYPE_SUPERVISED_CLASSIFICATION,
+        y_true=["pos", "pos", "neg", "neg"],
+        y_pred=["pos", "neg", "neg", "pos"],
+        y_score=[0.9, 0.7, 0.3, 0.4],
+    )
+    assert "area_under_roc_curve" in metrics
+    assert 0.0 <= metrics["area_under_roc_curve"] <= 1.0  # NOTE(review): range-only check
+
+
+def test_compute_metrics_classification_no_auc_without_scores() -> None:
+    metrics = compute_metrics(
+        TASK_TYPE_SUPERVISED_CLASSIFICATION,
+        y_true=["A", "B"],
+        y_pred=["A", "B"],
+    )
+    assert "area_under_roc_curve" not in metrics
+
+
+def test_compute_metrics_classification_no_auc_for_multiclass() -> None:
+    metrics = compute_metrics(
+        TASK_TYPE_SUPERVISED_CLASSIFICATION,
+        y_true=["A", "B", "C"],  # three distinct labels → not a binary task
+        y_pred=["A", "B", "C"],
+        y_score=[0.8, 0.9, 0.7],
+    )
+    assert "area_under_roc_curve" not in metrics
+
+
+def test_compute_metrics_regression() -> None:
+    metrics = compute_metrics(
+        TASK_TYPE_SUPERVISED_REGRESSION,
+        y_true=[1.0, 2.0, 3.0],
+        y_pred=[1.0, 2.0, 3.0],
+    )
+    assert "root_mean_squared_error" in metrics
+    assert "mean_absolute_error" in metrics
+    assert metrics["root_mean_squared_error"] == pytest.approx(0.0)
+    assert metrics["mean_absolute_error"] == pytest.approx(0.0)
+
+
+def test_compute_metrics_regression_known_values() -> None:
+    metrics = compute_metrics(
+        TASK_TYPE_SUPERVISED_REGRESSION,
+        y_true=[0.0, 0.0],
+        y_pred=[1.0, -1.0],  # absolute errors of 1 and 1
+    )
+    assert metrics["root_mean_squared_error"] == pytest.approx(math.sqrt(1.0))
+    assert metrics["mean_absolute_error"] == pytest.approx(1.0)
+
+
+def test_compute_metrics_unknown_task_type_returns_empty() -> None:
+    metrics = compute_metrics(99, y_true=["A"], y_pred=["A"])
+    assert metrics == {}
diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py
new file mode 100644
index 00000000..092c58fd
--- /dev/null
+++ b/tests/routers/openml/runs_test.py
@@ -0,0 +1,207 @@
+from __future__ import annotations
+
+import io
+import pathlib
+from datetime import datetime
+from http import HTTPStatus
+from typing import TYPE_CHECKING
+
+import pytest
+
+if TYPE_CHECKING:
+    from pytest_mock import MockerFixture
+    from starlette.testclient import TestClient
+
+
+RUN_ID = 42
+FLOW_ID = 2
+EXPECTED_RUN_ID = 99
+EXPECTED_RUN_ID_2 = 71
+
+MINIMAL_PREDICTIONS_ARFF = b"""@relation predictions
+
+@attribute row_id NUMERIC
+@attribute fold NUMERIC
+@attribute repeat NUMERIC
+@attribute prediction {A,B}
+
+@data
+0,0,0,A
+1,0,0,B
+2,0,0,A
+"""
+
+MINIMAL_RUN_XML = b"""
+
+ 1
+ 1
+ weka.classifiers.trees.J48 -C 0.25
+
+"""  # NOTE(review): no XML markup is visible here although tests expect valid XML; verify
+
+INVALID_XML = b"not xml at all <<<"
+
+
+def test_get_run_not_found(py_api: TestClient) -> None:
+    """GET an unknown run_id should return 404."""
+    response = py_api.get("/runs/999999999")
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    detail = response.json()["detail"]
+    assert detail["code"] == "220"
+
+
+def test_get_run_returns_structure(mocker: MockerFixture, py_api: TestClient) -> None:
+    """GET /runs/{id} returns RunDetail structure when run exists."""
+    mock_run = mocker.MagicMock()
+    mock_run.rid = RUN_ID
+    mock_run.task_id = 1
+    mock_run.flow_id = FLOW_ID
+    mock_run.uploader = 16
+    mock_run.upload_time = datetime(2024, 1, 15, 10, 30, 0)  # naive (no tzinfo)
+    mock_run.setup_string = "weka.J48 -C 0.25"
+
+    mocker.patch("database.runs.get", return_value=mock_run)
+    mocker.patch("database.runs.get_tags", return_value=["study_1"])
+    mocker.patch(
+        "database.runs.get_evaluations",
+        return_value=[
+            mocker.MagicMock(function="predictive_accuracy", value=0.93, array_data=None),
+        ],
+    )
+
+    response = py_api.get(f"/runs/{RUN_ID}")
+    assert response.status_code == HTTPStatus.OK
+    body = response.json()
+
+    assert body["run_id"] == RUN_ID
+    assert body["task_id"] == 1
+    assert body["flow_id"] == FLOW_ID
+    assert body["tags"] == ["study_1"]
+    assert len(body["evaluations"]) == 1
+    assert body["evaluations"][0]["function"] == "predictive_accuracy"
+    assert body["evaluations"][0]["value"] == pytest.approx(0.93)
+
+
+def test_upload_run_requires_auth(py_api: TestClient) -> None:
+    """Unauthenticated POST /runs should return 412 with code 103."""
+    response = py_api.post(
+        "/runs/",
+        files={
+            "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"),
+            "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"),
+        },
+    )
+    assert response.status_code == HTTPStatus.PRECONDITION_FAILED
+    assert response.json()["detail"]["code"] == "103"
+
+
+def test_upload_run_invalid_xml(mocker: MockerFixture, py_api: TestClient) -> None:
+    """Malformed description XML should return 422."""
+    mocker.patch("routers.dependencies.fetch_user", return_value=mocker.MagicMock(user_id=1))
+
+    response = py_api.post(
+        "/runs/",
+        files={
+            "description": ("description.xml", io.BytesIO(INVALID_XML), "application/xml"),
+            "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"),
+        },
+    )
+    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+
+
+def test_upload_run_unknown_task(mocker: MockerFixture, py_api: TestClient) -> None:
+    """A run referencing a non-existent task_id should return 404 with code 201."""
+    fake_user = mocker.MagicMock(user_id=16)
+    mocker.patch("routers.dependencies.fetch_user", return_value=fake_user)
+    mocker.patch("database.tasks.get", return_value=None)  # task lookup misses
+
+    response = py_api.post(
+        "/runs/",
+        files={
+            "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"),
+            "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"),
+        },
+    )
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert response.json()["detail"]["code"] == "201"
+
+
+def test_upload_run_unknown_flow(mocker: MockerFixture, py_api: TestClient) -> None:
+    """A run referencing a non-existent flow_id should return 404 with code 180."""
+    fake_user = mocker.MagicMock(user_id=16)
+    mocker.patch("routers.dependencies.fetch_user", return_value=fake_user)
+    mocker.patch("database.tasks.get", return_value=mocker.MagicMock())
+    mocker.patch("database.flows.get", return_value=None)  # flow lookup misses
+
+    response = py_api.post(
+        "/runs/",
+        files={
+            "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"),
+            "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"),
+        },
+    )
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert response.json()["detail"]["code"] == "180"
+
+
+def test_upload_run_success(
+    mocker: MockerFixture,
+    tmp_path: pathlib.Path,
+    py_api: TestClient,
+) -> None:
+    """A fully valid POST /runs should return 201 with a run_id."""
+    fake_user = mocker.MagicMock(user_id=16)
+    mocker.patch("routers.dependencies.fetch_user", return_value=fake_user)
+    mocker.patch("database.tasks.get", return_value=mocker.MagicMock())
+    mocker.patch("database.flows.get", return_value=mocker.MagicMock())
+    mocker.patch("database.runs.create", return_value=EXPECTED_RUN_ID)
+    mocker.patch("database.processing.enqueue")
+    mocker.patch(
+        "routers.openml.runs.load_configuration",
+        return_value={"upload_dir": str(tmp_path)},
+    )
+
+    response = py_api.post(
+        "/runs/",
+        files={
+            "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"),
+            "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"),
+        },
+    )
+    assert response.status_code == HTTPStatus.CREATED
+    body = response.json()
+    assert body["run_id"] == EXPECTED_RUN_ID
+
+    # Verify predictions file was persisted
+    predictions_path = tmp_path / str(EXPECTED_RUN_ID) / "predictions.arff"
+    assert predictions_path.exists()
+    assert predictions_path.read_bytes() == MINIMAL_PREDICTIONS_ARFF
+
+
+def test_upload_run_enqueues_processing(
+    mocker: MockerFixture,
+    tmp_path: pathlib.Path,
+    py_api: TestClient,
+) -> None:
+    """Successful upload must enqueue a processing_run entry."""
+    fake_user = mocker.MagicMock(user_id=16)
+    mocker.patch("routers.dependencies.fetch_user", return_value=fake_user)
+    mocker.patch("database.tasks.get", return_value=mocker.MagicMock())
+    mocker.patch("database.flows.get", return_value=mocker.MagicMock())
+    mocker.patch("database.runs.create", return_value=EXPECTED_RUN_ID_2)
+    enqueue_mock = mocker.patch("database.processing.enqueue")
+    mocker.patch(
+        "routers.openml.runs.load_configuration",
+        return_value={"upload_dir": str(tmp_path)},
+    )
+
+    response = py_api.post(
+        "/runs/",
+        files={
+            "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"),
+            "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"),
+        },
+    )
+    assert response.status_code == HTTPStatus.CREATED  # fail fast if the upload itself broke
+    enqueue_mock.assert_called_once()
+    assert enqueue_mock.call_args.args[0] == EXPECTED_RUN_ID_2  # run_id positional arg