From 5f02cf3269dbaceac105f0f608105bf20a4f9f9a Mon Sep 17 00:00:00 2001 From: vinayak sharma Date: Fri, 27 Feb 2026 14:56:20 +0530 Subject: [PATCH 1/5] main --- src/core/evaluation.py | 118 ++++++++++++++++ src/core/splits.py | 102 ++++++++++++++ src/database/processing.py | 66 +++++++++ src/database/runs.py | 112 ++++++++++++++++ src/main.py | 2 + src/routers/openml/runs.py | 216 ++++++++++++++++++++++++++++++ src/schemas/runs.py | 50 +++++++ src/worker/__init__.py | 0 src/worker/evaluator.py | 215 +++++++++++++++++++++++++++++ tests/core/__init__.py | 0 tests/core/evaluation_test.py | 188 ++++++++++++++++++++++++++ tests/routers/openml/runs_test.py | 207 ++++++++++++++++++++++++++++ 12 files changed, 1276 insertions(+) create mode 100644 src/core/evaluation.py create mode 100644 src/core/splits.py create mode 100644 src/database/processing.py create mode 100644 src/database/runs.py create mode 100644 src/routers/openml/runs.py create mode 100644 src/schemas/runs.py create mode 100644 src/worker/__init__.py create mode 100644 src/worker/evaluator.py create mode 100644 tests/core/__init__.py create mode 100644 tests/core/evaluation_test.py create mode 100644 tests/routers/openml/runs_test.py diff --git a/src/core/evaluation.py b/src/core/evaluation.py new file mode 100644 index 00000000..7339ad77 --- /dev/null +++ b/src/core/evaluation.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import math + +# --------------------------------------------------------------------------- +# Individual metrics +# --------------------------------------------------------------------------- + + +def accuracy(y_true: list[str | int], y_pred: list[str | int]) -> float: + """Fraction of predictions that exactly match the ground truth.""" + if len(y_true) != len(y_pred): + msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" + raise ValueError(msg) + if not y_true: + return 0.0 + correct = sum(t == p for t, p in zip(y_true, y_pred, strict=True)) + return correct / len(y_true) + + +def rmse(y_true: list[float], y_pred: list[float]) -> float: + """Root Mean Squared Error.""" + if len(y_true) != len(y_pred): + msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" + raise ValueError(msg) + if not y_true: + return 0.0 + mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) + return math.sqrt(mse) + + +def mean_absolute_error(y_true: list[float], y_pred: list[float]) -> float: + """Mean Absolute Error.""" + if len(y_true) != len(y_pred): + msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" + raise ValueError(msg) + if not y_true: + return 0.0 + return sum(abs(t - p) for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) + + +def auc(y_true: list[int], y_score: list[float]) -> float: + """Binary ROC AUC via the Wilcoxon-Mann-Whitney U statistic. + + Mathematically equivalent to the area under the ROC curve. + Counts concordant pairs: for each (positive, negative) pair, score 1 if + y_score[pos] > y_score[neg], 0.5 if tied, 0 otherwise, then normalise. + + y_true: list of 0/1 ground-truth labels. + y_score: list of predicted probabilities for the positive class (label=1). 
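+
+    Illustrative example (hand-checked): with y_true = [1, 0, 1, 0] and
+    y_score = [0.9, 0.2, 0.6, 0.4], all four (positive, negative) pairs are
+    concordant, so the result is 4 / (2 * 2) = 1.0.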
+ """ + if len(y_true) != len(y_score): + msg = f"Length mismatch: {len(y_true)} vs {len(y_score)}" + raise ValueError(msg) + if not y_true: + return 0.0 + + n_pos = sum(y_true) + n_neg = len(y_true) - n_pos + if n_pos == 0 or n_neg == 0: + return 0.0 + + pos_scores = [s for t, s in zip(y_true, y_score, strict=True) if t == 1] + neg_scores = [s for t, s in zip(y_true, y_score, strict=True) if t == 0] + + concordant = 0.0 + for ps in pos_scores: + for ns in neg_scores: + if ps > ns: + concordant += 1.0 + elif ps == ns: + concordant += 0.5 + + return concordant / (n_pos * n_neg) + + +# --------------------------------------------------------------------------- +# Dispatcher +# --------------------------------------------------------------------------- + +#: Task type IDs from the OpenML schema +TASK_TYPE_SUPERVISED_CLASSIFICATION = 1 +TASK_TYPE_SUPERVISED_REGRESSION = 2 + + +def compute_metrics( + task_type_id: int, + y_true: list[str | int | float], + y_pred: list[str | int | float], + y_score: list[float] | None = None, +) -> dict[str, float]: + """Compute all applicable metrics for the given task type. + + Returns a dict of {measure_name: value} using the same names found in + the OpenML `math_function` table (e.g. 'predictive_accuracy', 'area_under_roc_curve'). + """ + results: dict[str, float] = {} + + if task_type_id == TASK_TYPE_SUPERVISED_CLASSIFICATION: + str_true = [str(v) for v in y_true] + str_pred = [str(v) for v in y_pred] + results["predictive_accuracy"] = accuracy(str_true, str_pred) + + # AUC only when binary and scores are provided + unique_labels = set(str_true) + if y_score is not None and len(unique_labels) == 2: # noqa: PLR2004 + # Map the positive class (lexicographically larger, matching OpenML convention) + pos_label = max(unique_labels) + int_true = [1 if str(v) == pos_label else 0 for v in y_true] + results["area_under_roc_curve"] = auc(int_true, y_score) + + elif task_type_id == TASK_TYPE_SUPERVISED_REGRESSION: + float_true = [float(v) for v in y_true] + float_pred = [float(v) for v in y_pred] + results["root_mean_squared_error"] = rmse(float_true, float_pred) + results["mean_absolute_error"] = mean_absolute_error(float_true, float_pred) + + return results diff --git a/src/core/splits.py b/src/core/splits.py new file mode 100644 index 00000000..a79c8e62 --- /dev/null +++ b/src/core/splits.py @@ -0,0 +1,102 @@ + +from __future__ import annotations + +import random +import re + +SplitEntry = dict[str, int | str] + + +def generate_splits( + n_samples: int, + n_folds: int, + n_repeats: int, + *, + seed: int = 0, +) -> list[SplitEntry]: + """Generate cross-validation splits deterministically. + + Returns a flat list of dicts with keys: + repeat, fold, rowid, type ('TRAIN' or 'TEST') + """ + entries: list[SplitEntry] = [] + rng = random.Random(seed) # noqa: S311 + + for repeat in range(n_repeats): + indices = list(range(n_samples)) + rng.shuffle(indices) + + for fold in range(n_folds): + for row_pos, rowid in enumerate(indices): + split_type = "TEST" if row_pos % n_folds == fold else "TRAIN" + entries.append( + { + "repeat": repeat, + "fold": fold, + "rowid": rowid, + "type": split_type, + }, + ) + + return entries + + +_ARFF_DATA_SECTION = re.compile(r"@[Dd][Aa][Tt][Aa]") + + +def parse_arff_splits(arff_content: str) -> list[SplitEntry]: + """Parse an OpenML splits ARFF file into the same list-of-dict format. + + Expected ARFF columns (in order): type, rowid, repeat, fold + (This is the column order used by OpenML's split ARFF files.) 
+ """ + in_data = False + entries: list[SplitEntry] = [] + + for line in arff_content.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("%"): + continue + if _ARFF_DATA_SECTION.match(stripped): + in_data = True + continue + if not in_data: + continue + + parts = [p.strip() for p in stripped.split(",")] + if len(parts) < 4: # noqa: PLR2004 + continue + split_type, rowid_s, repeat_s, fold_s = parts[:4] + try: + entries.append( + { + "repeat": int(repeat_s), + "fold": int(fold_s), + "rowid": int(rowid_s), + "type": split_type.strip("'\""), + }, + ) + except ValueError: + continue + + return entries + + +def build_fold_index( + splits: list[SplitEntry], + repeat: int = 0, +) -> dict[int, tuple[list[int], list[int]]]: + """Build a dict of fold → (train_indices, test_indices) for a given repeat.""" + folds: dict[int, tuple[list[int], list[int]]] = {} + for entry in splits: + if entry["repeat"] != repeat: + continue + fold = int(entry["fold"]) + rowid = int(entry["rowid"]) + if fold not in folds: + folds[fold] = ([], []) + if entry["type"] == "TRAIN": + folds[fold][0].append(rowid) + else: + folds[fold][1].append(rowid) + return folds diff --git a/src/database/processing.py b/src/database/processing.py new file mode 100644 index 00000000..5168ea90 --- /dev/null +++ b/src/database/processing.py @@ -0,0 +1,66 @@ + +from __future__ import annotations + +import datetime +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import Connection, Row, text + + +def enqueue(run_id: int, expdb: Connection) -> None: + """Insert a new pending processing entry for the given run.""" + expdb.execute( + text( + """ + INSERT INTO processing_run(`run_id`, `status`, `date`) + VALUES (:run_id, 'pending', :date) + """, + ), + parameters={"run_id": run_id, "date": datetime.datetime.now()}, + ) + + +def get_pending(expdb: Connection) -> Sequence[Row]: + """Return all processing_run rows whose status is 'pending'.""" + return cast( + "Sequence[Row]", + expdb.execute( + text( + """ + SELECT `run_id`, `status`, `date` + FROM processing_run + WHERE `status` = 'pending' + ORDER BY `date` ASC + """, + ), + ).all(), + ) + + +def mark_done(run_id: int, expdb: Connection) -> None: + """Mark a processing_run entry as successfully completed.""" + expdb.execute( + text( + """ + UPDATE processing_run + SET `status` = 'done' + WHERE `run_id` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + + +def mark_error(run_id: int, error_message: str, expdb: Connection) -> None: + """Mark a processing_run entry as failed and store the error message.""" + expdb.execute( + text( + """ + UPDATE processing_run + SET `status` = 'error', `error` = :error_message + WHERE `run_id` = :run_id + """, + ), + parameters={"run_id": run_id, "error_message": error_message}, + ) diff --git a/src/database/runs.py b/src/database/runs.py new file mode 100644 index 00000000..6fd91d46 --- /dev/null +++ b/src/database/runs.py @@ -0,0 +1,112 @@ + +from __future__ import annotations + +import datetime +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import Connection, Row, text + + +def get(run_id: int, expdb: Connection) -> Row | None: + """Fetch a single run row by its primary key.""" + return expdb.execute( + text( + """ + SELECT `rid`, `task_id`, `implementation_id` AS `flow_id`, + `uploader`, `upload_time`, `setup_string` + FROM run + WHERE `rid` = :run_id + """, + ), + parameters={"run_id": run_id}, + ).one_or_none() + + +def create( + *, + task_id: int, + 
flow_id: int, + uploader_id: int, + setup_string: str | None, + expdb: Connection, +) -> int: + """Insert a new run row and return the generated run_id.""" + expdb.execute( + text( + """ + INSERT INTO run( + `task_id`, `implementation_id`, `uploader`, + `upload_time`, `setup_string` + ) + VALUES (:task_id, :flow_id, :uploader_id, :upload_time, :setup_string) + """, + ), + parameters={ + "task_id": task_id, + "flow_id": flow_id, + "uploader_id": uploader_id, + "upload_time": datetime.datetime.now(), + "setup_string": setup_string, + }, + ) + row = expdb.execute(text("SELECT LAST_INSERT_ID()")).one() + return int(row[0]) + + +def get_tags(run_id: int, expdb: Connection) -> list[str]: + """Return all tags for a given run.""" + rows = expdb.execute( + text( + """ + SELECT `tag` + FROM run_tag + WHERE `id` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return [row.tag for row in rows] + + +def get_evaluations(run_id: int, expdb: Connection) -> Sequence[Row]: + """Return all evaluation measure rows for a given run.""" + return cast( + "Sequence[Row]", + expdb.execute( + text( + """ + SELECT `function`, `value`, `array_data` + FROM run_measure + WHERE `run_id` = :run_id + """, + ), + parameters={"run_id": run_id}, + ).all(), + ) + + +def store_evaluation( + *, + run_id: int, + function: str, + value: float | None, + array_data: str | None = None, + expdb: Connection, +) -> None: + """Insert or update a single evaluation measure for a run.""" + expdb.execute( + text( + """ + INSERT INTO run_measure(`run_id`, `function`, `value`, `array_data`) + VALUES (:run_id, :function, :value, :array_data) + ON DUPLICATE KEY UPDATE `value` = :value, `array_data` = :array_data + """, + ), + parameters={ + "run_id": run_id, + "function": function, + "value": value, + "array_data": array_data, + }, + ) diff --git a/src/main.py b/src/main.py index 560b4c50..82911071 100644 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,7 @@ from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router from routers.openml.study import router as study_router +from routers.openml.runs import router as runs_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router @@ -55,6 +56,7 @@ def create_api() -> FastAPI: app.include_router(task_router) app.include_router(flows_router) app.include_router(study_router) + app.include_router(runs_router) return app diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py new file mode 100644 index 00000000..93e4e066 --- /dev/null +++ b/src/routers/openml/runs.py @@ -0,0 +1,216 @@ + +from __future__ import annotations + +import json +import logging +from http import HTTPStatus +from pathlib import Path +from typing import TYPE_CHECKING, Annotated + +import xmltodict +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile + +import database.flows +import database.processing +import database.runs +import database.tasks +from database.users import User +from routers.dependencies import expdb_connection, fetch_user +from schemas.runs import RunDetail, RunEvaluationResult, RunUploadResponse + +if TYPE_CHECKING: + from sqlalchemy import Connection + +router = APIRouter(prefix="/runs", tags=["runs"]) +log = logging.getLogger(__name__) + + + + + +def _parse_run_xml(xml_bytes: bytes) -> dict: + """Parse the run description XML uploaded by the client. + + Expected root element: + Required children: oml:task_id, oml:implementation_id (flow_id). 
+ Optional: oml:setup_string, oml:output_data, oml:parameter_setting. + """ + try: + raw = xmltodict.parse(xml_bytes.decode("utf-8")) + except Exception as exc: + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail={"code": "530", "message": f"Invalid run description XML: {exc}"}, + ) from exc + + # Strip the namespace prefix so keys are consistent + run_str = json.dumps(raw).replace("oml:", "") + data: dict = json.loads(run_str) + + return data.get("run", {}) + + + + + +def _require_auth(user: User | None) -> User: + if user is None: + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "103", "message": "Authentication failed"}, + ) + return user + + +def _require_task(task_id: int, expdb: Connection) -> None: + if not database.tasks.get(task_id, expdb): + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={"code": "201", "message": f"Unknown task: {task_id}"}, + ) + + +def _require_flow(flow_id: int, expdb: Connection) -> None: + if not database.flows.get(flow_id, expdb): + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={"code": "180", "message": f"Unknown flow: {flow_id}"}, + ) + + + + + +@router.post( + "/", + summary="Upload a run (predictions + description XML)", + response_model=RunUploadResponse, + status_code=HTTPStatus.CREATED, +) +async def upload_run( + description: Annotated[UploadFile, File(description="Run description XML file")], + predictions: Annotated[UploadFile, File(description="Predictions ARFF file")], + user: Annotated[User | None, Depends(fetch_user)] = None, + expdb: Annotated[Connection, Depends(expdb_connection)] = None, +) -> RunUploadResponse: + """Upload a new run. + + Accepts two multipart files: + - **description**: XML file conforming to the OpenML run description schema + - **predictions**: ARFF file with per-row predictions + (columns: row_id, fold, repeat, prediction [, confidence.*]) + + On success returns the new `run_id`. The run is immediately enqueued for + server-side evaluation; metrics will be available after the worker processes it. 
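+
+    Illustrative client call (a sketch only; `client` stands for any
+    requests-style HTTP client such as `httpx.Client` or Starlette's
+    `TestClient`, `xml_bytes`/`arff_bytes` hold the raw file contents, and
+    authentication is supplied by the deployment's `fetch_user` dependency):
+
+        files = {
+            "description": ("description.xml", xml_bytes, "application/xml"),
+            "predictions": ("predictions.arff", arff_bytes, "text/plain"),
+        }
+        run_id = client.post("/runs/", files=files).json()["run_id"]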
+ """ + authenticated_user = _require_auth(user) + + xml_bytes = await description.read() + run_xml = _parse_run_xml(xml_bytes) + + try: + task_id = int(run_xml["task_id"]) + except (KeyError, ValueError) as exc: + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail={"code": "531", "message": "Missing or invalid task_id in run description"}, + ) from exc + + try: + flow_id = int(run_xml["implementation_id"]) + except (KeyError, ValueError) as exc: + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail={ + "code": "532", + "message": "Missing or invalid implementation_id (flow_id) in run description", + }, + ) from exc + + setup_string: str | None = run_xml.get("setup_string") + + _require_task(task_id, expdb) + _require_flow(flow_id, expdb) + + # Store the run row + run_id = database.runs.create( + task_id=task_id, + flow_id=flow_id, + uploader_id=authenticated_user.user_id, + setup_string=setup_string, + expdb=expdb, + ) + + # Persist the predictions file to disk so the worker can read it later + from config import load_configuration # noqa: PLC0415 + + upload_dir: str = load_configuration().get("upload_dir", "/tmp/openml_runs") # noqa: S108 + run_dir = Path(upload_dir) / str(run_id) + run_dir.mkdir(parents=True, exist_ok=True) + predictions_bytes = await predictions.read() + predictions_path = run_dir / "predictions.arff" + predictions_path.write_bytes(predictions_bytes) + + # Enqueue for server-side evaluation + database.processing.enqueue(run_id, expdb) + + log.info( + "Run %d uploaded by user %d (task=%d, flow=%d).", + run_id, + authenticated_user.user_id, + task_id, + flow_id, + ) + return RunUploadResponse(run_id=run_id) + + + + + +@router.get( + "/{run_id}", + summary="Get run metadata and evaluation results", +) +def get_run( + run_id: int, + user: Annotated[User | None, Depends(fetch_user)] = None, # noqa: ARG001 + expdb: Annotated[Connection, Depends(expdb_connection)] = None, +) -> RunDetail: + """Return metadata and evaluation results for a single run.""" + run = database.runs.get(run_id, expdb) + if run is None: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={"code": "220", "message": f"Unknown run: {run_id}"}, + ) + + tags = database.runs.get_tags(run_id, expdb) + eval_rows = database.runs.get_evaluations(run_id, expdb) + + evaluations = [] + for row in eval_rows: + per_fold: list[float] | None = None + if row.array_data: + try: + per_fold = [float(v) for v in json.loads(row.array_data)] + except (json.JSONDecodeError, ValueError): + per_fold = None + + evaluations.append( + RunEvaluationResult( + function=row.function, + value=row.value, + per_fold=per_fold, + ), + ) + + return RunDetail( + id_=run.rid, + task_id=run.task_id, + flow_id=run.flow_id, + uploader=run.uploader, + upload_time=run.upload_time, + setup_string=run.setup_string, + tags=tags, + evaluations=evaluations, + ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py new file mode 100644 index 00000000..91284012 --- /dev/null +++ b/src/schemas/runs.py @@ -0,0 +1,50 @@ + +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel, Field + + +class RunUploadResponse(BaseModel): + """Response returned by POST /runs after a successful upload.""" + + run_id: int = Field( + serialization_alias="run_id", + json_schema_extra={"example": 42}, + ) + + +class RunEvaluationResult(BaseModel): + """A single per-fold or global evaluation measure for a run.""" + + function: str = Field( + 
json_schema_extra={"example": "predictive_accuracy"}, + description="Name of the evaluation measure (math_function.name in the DB).", + ) + value: float | None = Field( + json_schema_extra={"example": 0.9312}, + ) + # Per-fold values are stored as a JSON array string in the DB. + per_fold: list[float] | None = Field( + default=None, + json_schema_extra={"example": [0.92, 0.94, 0.93]}, + ) + + +class RunDetail(BaseModel): + """Full metadata for a single run, returned by GET /runs/{run_id}.""" + + id_: int = Field(serialization_alias="run_id", json_schema_extra={"example": 42}) + task_id: int = Field(json_schema_extra={"example": 59}) + flow_id: int = Field(json_schema_extra={"example": 1}) + uploader: int = Field(json_schema_extra={"example": 16}) + upload_time: datetime = Field( + json_schema_extra={"example": "2024-01-15T10:30:00"}, + ) + setup_string: str | None = Field( + default=None, + json_schema_extra={"example": "weka.classifiers.trees.J48 -C 0.25 -M 2"}, + ) + tags: list[str] = Field(default_factory=list) + evaluations: list[RunEvaluationResult] = Field(default_factory=list) diff --git a/src/worker/__init__.py b/src/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/worker/evaluator.py b/src/worker/evaluator.py new file mode 100644 index 00000000..a449e020 --- /dev/null +++ b/src/worker/evaluator.py @@ -0,0 +1,215 @@ + +from __future__ import annotations + +import logging +import urllib.request +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import database.processing +import database.runs +import database.tasks +from config import load_configuration, load_routing_configuration +from core.evaluation import compute_metrics +from core.formatting import _format_dataset_url +from core.splits import build_fold_index, parse_arff_splits +from database.datasets import get as get_dataset + +if TYPE_CHECKING: + from sqlalchemy import Connection + +log = logging.getLogger(__name__) + + +def _parse_predictions_arff(content: str) -> dict[str, list[Any]]: + """Parse an OpenML predictions ARFF. + + Returns a dict with keys: 'row_id', 'prediction', 'confidence' (optional). + Expected columns: row_id, fold, repeat, prediction [, confidence.*] + """ + result: dict[str, list[Any]] = {"row_id": [], "prediction": [], "confidence": []} + in_data = False + + for line in content.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("%"): + continue + if stripped.upper().startswith("@DATA"): + in_data = True + continue + if not in_data: + continue + + parts = [p.strip().strip("'\"") for p in stripped.split(",")] + if not parts: + continue + try: + row_id = int(parts[0]) + prediction = parts[3] if len(parts) > 3 else parts[-1] # noqa: PLR2004 + confidence = float(parts[4]) if len(parts) > 4 else None # noqa: PLR2004 + except (ValueError, IndexError): + continue + + result["row_id"].append(row_id) + result["prediction"].append(prediction) + result["confidence"].append(confidence) + + return result + + +def _load_ground_truth( + dataset_url: str, + target_attribute: str, + test_row_ids: list[int], +) -> list[str]: + """Download the dataset ARFF and extract the target column for given row IDs. + + Only extracts rows whose 0-based index is in `test_row_ids`. + Returns labels as strings in the order of `test_row_ids`. 
+ """ + try: + with urllib.request.urlopen(dataset_url, timeout=30) as resp: # noqa: S310 + content = resp.read().decode("utf-8", errors="replace") + except Exception: + log.exception("Failed to download dataset from %s", dataset_url) + return [] + + attr_names: list[str] = [] + data_rows: list[list[str]] = [] + in_data = False + + for line in content.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("%"): + continue + if stripped.upper().startswith("@ATTRIBUTE"): + parts = stripped.split(None, 2) + attr_names.append(parts[1].strip("'\"") if len(parts) >= 2 else "") # noqa: PLR2004 + continue + if stripped.upper().startswith("@DATA"): + in_data = True + continue + if in_data: + data_rows.append([v.strip().strip("'\"") for v in stripped.split(",")]) + + if target_attribute not in attr_names: + log.warning("Target attribute '%s' not found in dataset.", target_attribute) + return [] + + target_idx = attr_names.index(target_attribute) + pos_to_label = { + i: row[target_idx] + for i, row in enumerate(data_rows) + if i in set(test_row_ids) and target_idx < len(row) + } + return [pos_to_label.get(rid, "") for rid in test_row_ids] + + +def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR0911, PLR0915 + """Evaluate a single run, store metrics, mark processing entry done/error.""" + run = database.runs.get(run_id, expdb) + if run is None: + log.warning("Run %d not found; skipping.", run_id) + database.processing.mark_error(run_id, "run row not found", expdb) + return + + task_row = database.tasks.get(run.task_id, expdb) + if task_row is None: + database.processing.mark_error(run_id, "task not found", expdb) + return + + task_type_row = database.tasks.get_task_type(task_row.ttid, expdb) + if task_type_row is None: + database.processing.mark_error(run_id, "task type not found", expdb) + return + + task_inputs = { + row.input: int(row.value) if str(row.value).isdigit() else row.value + for row in database.tasks.get_input_for_task(run.task_id, expdb) + } + + dataset_id = task_inputs.get("source_data") + target_attr = str(task_inputs.get("target_feature", "class")) + if not isinstance(dataset_id, int): + database.processing.mark_error(run_id, "no source_data task input", expdb) + return + + dataset_row = get_dataset(dataset_id, expdb) + if dataset_row is None: + database.processing.mark_error(run_id, "dataset not found", expdb) + return + dataset_url = str(_format_dataset_url(dataset_row)) + + cfg = load_routing_configuration() + task_id = run.task_id + splits_url = f"{cfg.get('server_url', '')}api_splits/get/{task_id}/Task_{task_id}_splits.arff" + try: + with urllib.request.urlopen(splits_url, timeout=30) as resp: # noqa: S310 + splits_content = resp.read().decode("utf-8", errors="replace") + except Exception: + log.exception("Could not fetch splits for task %d", task_id) + database.processing.mark_error(run_id, "could not fetch splits", expdb) + return + + fold_index = build_fold_index(parse_arff_splits(splits_content), repeat=0) + + upload_dir: str = load_configuration().get("upload_dir", "/tmp/openml_runs") # noqa: S108 + predictions_path = Path(upload_dir) / str(run_id) / "predictions.arff" + try: + with predictions_path.open(encoding="utf-8") as fh: + predictions_content = fh.read() + except OSError: + log.exception("Could not read predictions file for run %d", run_id) + database.processing.mark_error(run_id, "predictions file not found", expdb) + return + + predictions = _parse_predictions_arff(predictions_content) + pred_map: dict[int, str] = 
dict( + zip(predictions["row_id"], predictions["prediction"], strict=True), + ) + conf_map: dict[int, float | None] = dict( + zip(predictions["row_id"], predictions["confidence"], strict=True), + ) + has_scores = any(v is not None for v in conf_map.values()) + + all_true: list[str] = [] + all_pred: list[str] = [] + all_score: list[float] = [] + for train_ids, test_ids in fold_index.values(): # noqa: B007 + all_true.extend(_load_ground_truth(dataset_url, target_attr, test_ids)) + all_pred.extend(pred_map.get(rid, "") for rid in test_ids) + if has_scores: + for rid in test_ids: + raw = conf_map.get(rid) + all_score.append(float(raw) if raw is not None else 0.0) + + metrics = compute_metrics( + task_type_id=task_row.ttid, + y_true=all_true, + y_pred=all_pred, + y_score=all_score if has_scores else None, + ) + for measure_name, value in metrics.items(): + database.runs.store_evaluation( + run_id=run_id, function=measure_name, value=value, expdb=expdb, + ) + + database.processing.mark_done(run_id, expdb) + log.info("Run %d evaluated: %s", run_id, metrics) + + +def process_pending_runs(expdb: Connection) -> None: + """Consume all pending processing_run entries and evaluate each one. + + Designed to be called from a cron job or a management CLI command. + Each run is evaluated independently; an error in one does not halt the rest. + """ + pending = database.processing.get_pending(expdb) + log.info("Processing %d pending run(s).", len(pending)) + for entry in pending: + run_id = int(entry.run_id) + try: + _evaluate_run(run_id, expdb) + except Exception: + log.exception("Unexpected error evaluating run %d", run_id) + database.processing.mark_error(run_id, "unexpected error", expdb) diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/evaluation_test.py b/tests/core/evaluation_test.py new file mode 100644 index 00000000..f13fa630 --- /dev/null +++ b/tests/core/evaluation_test.py @@ -0,0 +1,188 @@ + +from __future__ import annotations + +import math + +import pytest + +from core.evaluation import ( + TASK_TYPE_SUPERVISED_CLASSIFICATION, + TASK_TYPE_SUPERVISED_REGRESSION, + accuracy, + auc, + compute_metrics, + mean_absolute_error, + rmse, +) + + +# --------------------------------------------------------------------------- +# accuracy +# --------------------------------------------------------------------------- + + +def test_accuracy_perfect() -> None: + assert accuracy(["A", "B", "C"], ["A", "B", "C"]) == pytest.approx(1.0) + + +def test_accuracy_half() -> None: + assert accuracy(["A", "A", "B", "B"], ["A", "B", "B", "A"]) == pytest.approx(0.5) + + +def test_accuracy_none_correct() -> None: + assert accuracy(["A", "A"], ["B", "B"]) == pytest.approx(0.0) + + +def test_accuracy_empty() -> None: + assert accuracy([], []) == pytest.approx(0.0) + + +def test_accuracy_length_mismatch() -> None: + with pytest.raises(ValueError, match="Length mismatch"): + accuracy(["A"], ["A", "B"]) + + +def test_accuracy_integer_labels() -> None: + assert accuracy([1, 2, 3], [1, 2, 3]) == pytest.approx(1.0) + + +# --------------------------------------------------------------------------- +# rmse +# --------------------------------------------------------------------------- + + +def test_rmse_zero() -> None: + assert rmse([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]) == pytest.approx(0.0) + + +def test_rmse_known() -> None: + # errors: 1, 1 → RMSE = sqrt((1+1)/2) = 1.0 + assert rmse([0.0, 0.0], [1.0, -1.0]) == pytest.approx(1.0) + + +def test_rmse_empty() 
-> None: + assert rmse([], []) == pytest.approx(0.0) + + +def test_rmse_length_mismatch() -> None: + with pytest.raises(ValueError, match="Length mismatch"): + rmse([1.0], [1.0, 2.0]) + + +# --------------------------------------------------------------------------- +# mean_absolute_error +# --------------------------------------------------------------------------- + + +def test_mae_zero() -> None: + assert mean_absolute_error([1.0, 2.0], [1.0, 2.0]) == pytest.approx(0.0) + + +def test_mae_known() -> None: + assert mean_absolute_error([0.0, 0.0], [1.0, 3.0]) == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- +# auc +# --------------------------------------------------------------------------- + + +def test_auc_perfect() -> None: + y_true = [1, 1, 0, 0] + y_score = [0.9, 0.8, 0.2, 0.1] + assert auc(y_true, y_score) == pytest.approx(1.0) + + +def test_auc_random_classifier() -> None: + # A random classifier scored with alternating 0/1 at equal probability + # gives AUC ≈ 0.5 (not exactly, depends on tie-breaking) + y_true = [1, 0, 1, 0] + y_score = [0.5, 0.5, 0.5, 0.5] + result = auc(y_true, y_score) + assert 0.0 <= result <= 1.0 + + +def test_auc_empty() -> None: + assert auc([], []) == pytest.approx(0.0) + + +def test_auc_all_one_class() -> None: + # Only positives → undefined, returns 0.0 by convention + assert auc([1, 1, 1], [0.9, 0.8, 0.7]) == pytest.approx(0.0) + + +def test_auc_length_mismatch() -> None: + with pytest.raises(ValueError, match="Length mismatch"): + auc([1], [0.5, 0.6]) + + +# --------------------------------------------------------------------------- +# compute_metrics dispatcher +# --------------------------------------------------------------------------- + + +def test_compute_metrics_classification_accuracy() -> None: + metrics = compute_metrics( + TASK_TYPE_SUPERVISED_CLASSIFICATION, + y_true=["A", "A", "B"], + y_pred=["A", "B", "B"], + ) + assert "predictive_accuracy" in metrics + assert metrics["predictive_accuracy"] == pytest.approx(2 / 3) + + +def test_compute_metrics_classification_includes_auc_for_binary() -> None: + metrics = compute_metrics( + TASK_TYPE_SUPERVISED_CLASSIFICATION, + y_true=["pos", "pos", "neg", "neg"], + y_pred=["pos", "neg", "neg", "pos"], + y_score=[0.9, 0.7, 0.3, 0.4], + ) + assert "area_under_roc_curve" in metrics + assert 0.0 <= metrics["area_under_roc_curve"] <= 1.0 + + +def test_compute_metrics_classification_no_auc_without_scores() -> None: + metrics = compute_metrics( + TASK_TYPE_SUPERVISED_CLASSIFICATION, + y_true=["A", "B"], + y_pred=["A", "B"], + ) + assert "area_under_roc_curve" not in metrics + + +def test_compute_metrics_classification_no_auc_for_multiclass() -> None: + metrics = compute_metrics( + TASK_TYPE_SUPERVISED_CLASSIFICATION, + y_true=["A", "B", "C"], + y_pred=["A", "B", "C"], + y_score=[0.8, 0.9, 0.7], + ) + assert "area_under_roc_curve" not in metrics + + +def test_compute_metrics_regression() -> None: + metrics = compute_metrics( + TASK_TYPE_SUPERVISED_REGRESSION, + y_true=[1.0, 2.0, 3.0], + y_pred=[1.0, 2.0, 3.0], + ) + assert "root_mean_squared_error" in metrics + assert "mean_absolute_error" in metrics + assert metrics["root_mean_squared_error"] == pytest.approx(0.0) + assert metrics["mean_absolute_error"] == pytest.approx(0.0) + + +def test_compute_metrics_regression_known_values() -> None: + metrics = compute_metrics( + TASK_TYPE_SUPERVISED_REGRESSION, + y_true=[0.0, 0.0], + y_pred=[1.0, -1.0], + ) + assert metrics["root_mean_squared_error"] == 
pytest.approx(math.sqrt(1.0)) + assert metrics["mean_absolute_error"] == pytest.approx(1.0) + + +def test_compute_metrics_unknown_task_type_returns_empty() -> None: + metrics = compute_metrics(99, y_true=["A"], y_pred=["A"]) + assert metrics == {} diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py new file mode 100644 index 00000000..f5651ba0 --- /dev/null +++ b/tests/routers/openml/runs_test.py @@ -0,0 +1,207 @@ + +from __future__ import annotations + +import io +from http import HTTPStatus + +import pytest +from pytest_mock import MockerFixture +from starlette.testclient import TestClient + + +# --------------------------------------------------------------------------- +# Minimal ARFF predictions content used in tests +# --------------------------------------------------------------------------- + +MINIMAL_PREDICTIONS_ARFF = b"""@relation predictions + +@attribute row_id NUMERIC +@attribute fold NUMERIC +@attribute repeat NUMERIC +@attribute prediction {A,B} + +@data +0,0,0,A +1,0,0,B +2,0,0,A +""" + +MINIMAL_RUN_XML = b""" + + 1 + 1 + weka.classifiers.trees.J48 -C 0.25 + +""" + +INVALID_XML = b"not xml at all <<<" + + +# --------------------------------------------------------------------------- +# GET /runs/{run_id} +# --------------------------------------------------------------------------- + + +def test_get_run_not_found(py_api: TestClient) -> None: + """GET an unknown run_id should return 404.""" + response = py_api.get("/runs/999999999") + assert response.status_code == HTTPStatus.NOT_FOUND + detail = response.json()["detail"] + assert detail["code"] == "220" + + +def test_get_run_returns_structure(mocker: MockerFixture, py_api: TestClient) -> None: + """GET /runs/{id} returns RunDetail structure when run exists.""" + from datetime import datetime + + mock_run = mocker.MagicMock() + mock_run.rid = 42 + mock_run.task_id = 1 + mock_run.flow_id = 2 + mock_run.uploader = 16 + mock_run.upload_time = datetime(2024, 1, 15, 10, 30, 0) + mock_run.setup_string = "weka.J48 -C 0.25" + + mocker.patch("database.runs.get", return_value=mock_run) + mocker.patch("database.runs.get_tags", return_value=["study_1"]) + mocker.patch( + "database.runs.get_evaluations", + return_value=[ + mocker.MagicMock(function="predictive_accuracy", value=0.93, array_data=None), + ], + ) + + response = py_api.get("/runs/42") + assert response.status_code == HTTPStatus.OK + body = response.json() + + assert body["run_id"] == 42 + assert body["task_id"] == 1 + assert body["flow_id"] == 2 + assert body["tags"] == ["study_1"] + assert len(body["evaluations"]) == 1 + assert body["evaluations"][0]["function"] == "predictive_accuracy" + assert body["evaluations"][0]["value"] == pytest.approx(0.93) + + +# --------------------------------------------------------------------------- +# POST /runs +# --------------------------------------------------------------------------- + + +def test_upload_run_requires_auth(py_api: TestClient) -> None: + """Unauthenticated POST /runs should return 412 with code 103.""" + response = py_api.post( + "/runs/", + files={ + "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"), + "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"), + }, + ) + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "103" + + +def test_upload_run_invalid_xml(mocker: MockerFixture, py_api: TestClient) -> None: + """Malformed description XML should return 422.""" + # 
Simulate an authenticated user + mocker.patch("routers.dependencies.fetch_user", return_value=mocker.MagicMock(user_id=1)) + + response = py_api.post( + "/runs/", + files={ + "description": ("description.xml", io.BytesIO(INVALID_XML), "application/xml"), + "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"), + }, + ) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +def test_upload_run_unknown_task(mocker: MockerFixture, py_api: TestClient) -> None: + """A run referencing a non-existent task_id should return 404 with code 201.""" + fake_user = mocker.MagicMock(user_id=16) + mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) + mocker.patch("database.tasks.get", return_value=None) + + response = py_api.post( + "/runs/", + files={ + "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"), + "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"), + }, + ) + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.json()["detail"]["code"] == "201" + + +def test_upload_run_unknown_flow(mocker: MockerFixture, py_api: TestClient) -> None: + """A run referencing a non-existent flow_id should return 404 with code 180.""" + fake_user = mocker.MagicMock(user_id=16) + mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) + mocker.patch("database.tasks.get", return_value=mocker.MagicMock()) + mocker.patch("database.flows.get", return_value=None) + + response = py_api.post( + "/runs/", + files={ + "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"), + "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"), + }, + ) + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.json()["detail"]["code"] == "180" + + +def test_upload_run_success(mocker: MockerFixture, tmp_path, py_api: TestClient) -> None: + """A fully valid POST /runs should return 201 with a run_id.""" + fake_user = mocker.MagicMock(user_id=16) + mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) + mocker.patch("database.tasks.get", return_value=mocker.MagicMock()) + mocker.patch("database.flows.get", return_value=mocker.MagicMock()) + mocker.patch("database.runs.create", return_value=99) + mocker.patch("database.processing.enqueue") + mocker.patch( + "routers.openml.runs.load_configuration", + return_value={"upload_dir": str(tmp_path)}, + ) + + response = py_api.post( + "/runs/", + files={ + "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"), + "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"), + }, + ) + assert response.status_code == HTTPStatus.CREATED + body = response.json() + assert body["run_id"] == 99 + + # Verify predictions file was persisted + predictions_path = tmp_path / "99" / "predictions.arff" + assert predictions_path.exists() + assert predictions_path.read_bytes() == MINIMAL_PREDICTIONS_ARFF + + +def test_upload_run_enqueues_processing(mocker: MockerFixture, tmp_path, py_api: TestClient) -> None: + """Successful upload must enqueue a processing_run entry.""" + fake_user = mocker.MagicMock(user_id=16) + mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) + mocker.patch("database.tasks.get", return_value=mocker.MagicMock()) + mocker.patch("database.flows.get", return_value=mocker.MagicMock()) + mocker.patch("database.runs.create", return_value=7) + enqueue_mock = 
mocker.patch("database.processing.enqueue") + mocker.patch( + "routers.openml.runs.load_configuration", + return_value={"upload_dir": str(tmp_path)}, + ) + + py_api.post( + "/runs/", + files={ + "description": ("description.xml", io.BytesIO(MINIMAL_RUN_XML), "application/xml"), + "predictions": ("predictions.arff", io.BytesIO(MINIMAL_PREDICTIONS_ARFF), "text/plain"), + }, + ) + enqueue_mock.assert_called_once() + call_kwargs = enqueue_mock.call_args + assert call_kwargs[0][0] == 7 # run_id positional arg From 363b9f26188830bdb4096aa264d42a6c9cd5daa9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 09:33:44 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/core/splits.py | 1 - src/database/processing.py | 1 - src/database/runs.py | 1 - src/main.py | 2 +- src/routers/openml/runs.py | 13 ------------- src/schemas/runs.py | 1 - src/worker/evaluator.py | 6 ++++-- tests/core/evaluation_test.py | 2 -- tests/routers/openml/runs_test.py | 6 +++--- 9 files changed, 8 insertions(+), 25 deletions(-) diff --git a/src/core/splits.py b/src/core/splits.py index a79c8e62..d225cc19 100644 --- a/src/core/splits.py +++ b/src/core/splits.py @@ -1,4 +1,3 @@ - from __future__ import annotations import random diff --git a/src/database/processing.py b/src/database/processing.py index 5168ea90..a1e8d7ce 100644 --- a/src/database/processing.py +++ b/src/database/processing.py @@ -1,4 +1,3 @@ - from __future__ import annotations import datetime diff --git a/src/database/runs.py b/src/database/runs.py index 6fd91d46..ef415ce3 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -1,4 +1,3 @@ - from __future__ import annotations import datetime diff --git a/src/main.py b/src/main.py index 82911071..2fe219ae 100644 --- a/src/main.py +++ b/src/main.py @@ -11,8 +11,8 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router -from routers.openml.study import router as study_router from routers.openml.runs import router as runs_router +from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 93e4e066..b2aebc39 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,4 +1,3 @@ - from __future__ import annotations import json @@ -25,9 +24,6 @@ log = logging.getLogger(__name__) - - - def _parse_run_xml(xml_bytes: bytes) -> dict: """Parse the run description XML uploaded by the client. 
@@ -50,9 +46,6 @@ def _parse_run_xml(xml_bytes: bytes) -> dict: return data.get("run", {}) - - - def _require_auth(user: User | None) -> User: if user is None: raise HTTPException( @@ -78,9 +71,6 @@ def _require_flow(flow_id: int, expdb: Connection) -> None: ) - - - @router.post( "/", summary="Upload a run (predictions + description XML)", @@ -164,9 +154,6 @@ async def upload_run( return RunUploadResponse(run_id=run_id) - - - @router.get( "/{run_id}", summary="Get run metadata and evaluation results", diff --git a/src/schemas/runs.py b/src/schemas/runs.py index 91284012..fa8c7361 100644 --- a/src/schemas/runs.py +++ b/src/schemas/runs.py @@ -1,4 +1,3 @@ - from __future__ import annotations from datetime import datetime diff --git a/src/worker/evaluator.py b/src/worker/evaluator.py index a449e020..6f68def7 100644 --- a/src/worker/evaluator.py +++ b/src/worker/evaluator.py @@ -1,4 +1,3 @@ - from __future__ import annotations import logging @@ -191,7 +190,10 @@ def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR091 ) for measure_name, value in metrics.items(): database.runs.store_evaluation( - run_id=run_id, function=measure_name, value=value, expdb=expdb, + run_id=run_id, + function=measure_name, + value=value, + expdb=expdb, ) database.processing.mark_done(run_id, expdb) diff --git a/tests/core/evaluation_test.py b/tests/core/evaluation_test.py index f13fa630..12e9c5e2 100644 --- a/tests/core/evaluation_test.py +++ b/tests/core/evaluation_test.py @@ -1,4 +1,3 @@ - from __future__ import annotations import math @@ -15,7 +14,6 @@ rmse, ) - # --------------------------------------------------------------------------- # accuracy # --------------------------------------------------------------------------- diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py index f5651ba0..afbc183b 100644 --- a/tests/routers/openml/runs_test.py +++ b/tests/routers/openml/runs_test.py @@ -1,4 +1,3 @@ - from __future__ import annotations import io @@ -8,7 +7,6 @@ from pytest_mock import MockerFixture from starlette.testclient import TestClient - # --------------------------------------------------------------------------- # Minimal ARFF predictions content used in tests # --------------------------------------------------------------------------- @@ -182,7 +180,9 @@ def test_upload_run_success(mocker: MockerFixture, tmp_path, py_api: TestClient) assert predictions_path.read_bytes() == MINIMAL_PREDICTIONS_ARFF -def test_upload_run_enqueues_processing(mocker: MockerFixture, tmp_path, py_api: TestClient) -> None: +def test_upload_run_enqueues_processing( + mocker: MockerFixture, tmp_path, py_api: TestClient +) -> None: """Successful upload must enqueue a processing_run entry.""" fake_user = mocker.MagicMock(user_id=16) mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) From 0bd41fd50f75a6167b22b74599eb80b1b1dd1027 Mon Sep 17 00:00:00 2001 From: vinayak sharma Date: Sat, 28 Feb 2026 00:11:08 +0530 Subject: [PATCH 3/5] fix --- src/core/evaluation.py | 73 +++++++++-------- src/core/splits.py | 16 +++- src/database/processing.py | 19 ++++- src/database/runs.py | 9 +++ src/routers/openml/runs.py | 58 ++++++++------ src/worker/evaluator.py | 126 +++++++++++++++++++++--------- tests/routers/openml/runs_test.py | 60 +++++++------- 7 files changed, 228 insertions(+), 133 deletions(-) diff --git a/src/core/evaluation.py b/src/core/evaluation.py index 7339ad77..56a2bcaa 100644 --- a/src/core/evaluation.py +++ b/src/core/evaluation.py @@ -1,13 +1,10 @@ 
from __future__ import annotations import math +from collections.abc import Sequence -# --------------------------------------------------------------------------- -# Individual metrics -# --------------------------------------------------------------------------- - -def accuracy(y_true: list[str | int], y_pred: list[str | int]) -> float: +def accuracy(y_true: Sequence[str | int], y_pred: Sequence[str | int]) -> float: """Fraction of predictions that exactly match the ground truth.""" if len(y_true) != len(y_pred): msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" @@ -18,7 +15,7 @@ def accuracy(y_true: list[str | int], y_pred: list[str | int]) -> float: return correct / len(y_true) -def rmse(y_true: list[float], y_pred: list[float]) -> float: +def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float: """Root Mean Squared Error.""" if len(y_true) != len(y_pred): msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" @@ -29,7 +26,7 @@ def rmse(y_true: list[float], y_pred: list[float]) -> float: return math.sqrt(mse) -def mean_absolute_error(y_true: list[float], y_pred: list[float]) -> float: +def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float: """Mean Absolute Error.""" if len(y_true) != len(y_pred): msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" @@ -39,15 +36,16 @@ def mean_absolute_error(y_true: list[float], y_pred: list[float]) -> float: return sum(abs(t - p) for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) -def auc(y_true: list[int], y_score: list[float]) -> float: - """Binary ROC AUC via the Wilcoxon-Mann-Whitney U statistic. +def auc(y_true: Sequence[int], y_score: Sequence[float]) -> float: + """Binary ROC AUC via an O(n log n) rank-based Mann-Whitney U statistic. Mathematically equivalent to the area under the ROC curve. - Counts concordant pairs: for each (positive, negative) pair, score 1 if - y_score[pos] > y_score[neg], 0.5 if tied, 0 otherwise, then normalise. - y_true: list of 0/1 ground-truth labels. - y_score: list of predicted probabilities for the positive class (label=1). + y_true: sequence of 0/1 ground-truth labels. + y_score: sequence of predicted probabilities for the positive class (label=1). + + Raises: + ValueError: if y_true contains values outside {0, 1} or lengths differ. 
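+
+    Illustrative tie handling: one positive and one negative sharing the same
+    score both receive mid-rank 1.5, so the statistic evaluates to 0.5, which
+    matches the pair-counting convention of scoring a tie as 0.5.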
""" if len(y_true) != len(y_score): msg = f"Length mismatch: {len(y_true)} vs {len(y_score)}" @@ -55,28 +53,36 @@ def auc(y_true: list[int], y_score: list[float]) -> float: if not y_true: return 0.0 + unique = set(y_true) + invalid = unique - {0, 1} + if invalid: + msg = f"y_true must contain only 0/1 labels; found {invalid}" + raise ValueError(msg) + n_pos = sum(y_true) n_neg = len(y_true) - n_pos if n_pos == 0 or n_neg == 0: return 0.0 - pos_scores = [s for t, s in zip(y_true, y_score, strict=True) if t == 1] - neg_scores = [s for t, s in zip(y_true, y_score, strict=True) if t == 0] - - concordant = 0.0 - for ps in pos_scores: - for ns in neg_scores: - if ps > ns: - concordant += 1.0 - elif ps == ns: - concordant += 0.5 + # O(n log n): rank all scores, then use the rank-sum formula + pairs = sorted(zip(y_score, y_true, strict=False), key=lambda x: x[0]) - return concordant / (n_pos * n_neg) + n = len(pairs) + ranks: list[float] = [0.0] * n + i = 0 + while i < n: + j = i + while j < n - 1 and pairs[j][0] == pairs[j + 1][0]: + j += 1 + mid_rank = (i + j) / 2 + 1 # 1-indexed + for k in range(i, j + 1): + ranks[k] = mid_rank + i = j + 1 + rank_sum_pos = sum(ranks[k] for k in range(n) if pairs[k][1] == 1) + u_pos = rank_sum_pos - n_pos * (n_pos + 1) / 2 + return u_pos / (n_pos * n_neg) -# --------------------------------------------------------------------------- -# Dispatcher -# --------------------------------------------------------------------------- #: Task type IDs from the OpenML schema TASK_TYPE_SUPERVISED_CLASSIFICATION = 1 @@ -85,14 +91,15 @@ def auc(y_true: list[int], y_score: list[float]) -> float: def compute_metrics( task_type_id: int, - y_true: list[str | int | float], - y_pred: list[str | int | float], - y_score: list[float] | None = None, + y_true: Sequence[str | int | float], + y_pred: Sequence[str | int | float], + y_score: Sequence[float] | None = None, ) -> dict[str, float]: """Compute all applicable metrics for the given task type. Returns a dict of {measure_name: value} using the same names found in - the OpenML `math_function` table (e.g. 'predictive_accuracy', 'area_under_roc_curve'). + the OpenML `math_function` table (e.g. 'predictive_accuracy', + 'area_under_roc_curve'). 
""" results: dict[str, float] = {} @@ -104,10 +111,10 @@ def compute_metrics( # AUC only when binary and scores are provided unique_labels = set(str_true) if y_score is not None and len(unique_labels) == 2: # noqa: PLR2004 - # Map the positive class (lexicographically larger, matching OpenML convention) + # Map the positive class (lexicographically larger, matching OpenML) pos_label = max(unique_labels) int_true = [1 if str(v) == pos_label else 0 for v in y_true] - results["area_under_roc_curve"] = auc(int_true, y_score) + results["area_under_roc_curve"] = auc(int_true, list(y_score)) elif task_type_id == TASK_TYPE_SUPERVISED_REGRESSION: float_true = [float(v) for v in y_true] diff --git a/src/core/splits.py b/src/core/splits.py index a79c8e62..8ba8426e 100644 --- a/src/core/splits.py +++ b/src/core/splits.py @@ -19,6 +19,15 @@ def generate_splits( Returns a flat list of dicts with keys: repeat, fold, rowid, type ('TRAIN' or 'TEST') """ + if n_folds <= 0: + msg = f"n_folds must be a positive integer, got {n_folds}" + raise ValueError(msg) + if n_repeats <= 0: + msg = f"n_repeats must be a positive integer, got {n_repeats}" + raise ValueError(msg) + if n_samples <= 0: + return [] + entries: list[SplitEntry] = [] rng = random.Random(seed) # noqa: S311 @@ -86,16 +95,19 @@ def build_fold_index( splits: list[SplitEntry], repeat: int = 0, ) -> dict[int, tuple[list[int], list[int]]]: - """Build a dict of fold → (train_indices, test_indices) for a given repeat.""" + """Build a dict of fold -> (train_indices, test_indices) for a given repeat.""" folds: dict[int, tuple[list[int], list[int]]] = {} for entry in splits: if entry["repeat"] != repeat: continue fold = int(entry["fold"]) rowid = int(entry["rowid"]) + split_type = str(entry["type"]).upper() + if split_type not in {"TRAIN", "TEST"}: + continue if fold not in folds: folds[fold] = ([], []) - if entry["type"] == "TRAIN": + if split_type == "TRAIN": folds[fold][0].append(rowid) else: folds[fold][1].append(rowid) diff --git a/src/database/processing.py b/src/database/processing.py index 5168ea90..f4db3229 100644 --- a/src/database/processing.py +++ b/src/database/processing.py @@ -22,7 +22,22 @@ def enqueue(run_id: int, expdb: Connection) -> None: def get_pending(expdb: Connection) -> Sequence[Row]: - """Return all processing_run rows whose status is 'pending'.""" + """Atomically claim all pending processing_run rows for this worker. + + Uses an UPDATE ... WHERE status='pending' approach so that concurrent + workers don't double-process the same run. Claimed rows are set to + 'processing' and this worker reads them back by that status. 
+ """ + # Atomically mark pending rows as 'processing' so concurrent workers skip them + expdb.execute( + text( + """ + UPDATE processing_run + SET `status` = 'processing' + WHERE `status` = 'pending' + """, + ), + ) return cast( "Sequence[Row]", expdb.execute( @@ -30,7 +45,7 @@ def get_pending(expdb: Connection) -> Sequence[Row]: """ SELECT `run_id`, `status`, `date` FROM processing_run - WHERE `status` = 'pending' + WHERE `status` = 'processing' ORDER BY `date` ASC """, ), diff --git a/src/database/runs.py b/src/database/runs.py index 6fd91d46..15db6483 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -110,3 +110,12 @@ def store_evaluation( "array_data": array_data, }, ) + + +def delete(run_id: int, expdb: Connection) -> None: + """Delete a run row by primary key (used for rollback on enqueue failure).""" + expdb.execute( + text("DELETE FROM run WHERE `rid` = :run_id"), + parameters={"run_id": run_id}, + ) + diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 93e4e066..36e6361d 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -1,11 +1,11 @@ - from __future__ import annotations import json import logging +import shutil from http import HTTPStatus from pathlib import Path -from typing import TYPE_CHECKING, Annotated +from typing import TYPE_CHECKING, Annotated, Any import xmltodict from fastapi import APIRouter, Depends, File, HTTPException, UploadFile @@ -14,6 +14,7 @@ import database.processing import database.runs import database.tasks +from config import load_configuration from database.users import User from routers.dependencies import expdb_connection, fetch_user from schemas.runs import RunDetail, RunEvaluationResult, RunUploadResponse @@ -24,36 +25,39 @@ router = APIRouter(prefix="/runs", tags=["runs"]) log = logging.getLogger(__name__) +_DEFAULT_UPLOAD_DIR = "/tmp/openml_runs" # noqa: S108 +_OML_NAMESPACE = "http://openml.org/openml" - -def _parse_run_xml(xml_bytes: bytes) -> dict: +def _parse_run_xml(xml_bytes: bytes) -> dict[str, Any]: """Parse the run description XML uploaded by the client. + Uses xmltodict namespace stripping so that 'oml:task_id' in the source + becomes simply 'task_id' in the returned dict, without doing a string + replace that could corrupt any value that contains 'oml:'. + Expected root element: Required children: oml:task_id, oml:implementation_id (flow_id). Optional: oml:setup_string, oml:output_data, oml:parameter_setting. 
""" try: - raw = xmltodict.parse(xml_bytes.decode("utf-8")) + raw: dict[str, Any] = xmltodict.parse( + xml_bytes.decode("utf-8"), + process_namespaces=True, + namespaces={_OML_NAMESPACE: None}, + ) except Exception as exc: raise HTTPException( status_code=HTTPStatus.UNPROCESSABLE_ENTITY, detail={"code": "530", "message": f"Invalid run description XML: {exc}"}, ) from exc - # Strip the namespace prefix so keys are consistent - run_str = json.dumps(raw).replace("oml:", "") - data: dict = json.loads(run_str) - - return data.get("run", {}) - - - + return raw.get("run", {}) def _require_auth(user: User | None) -> User: + """Raise 412 if the request is not authenticated.""" if user is None: raise HTTPException( status_code=HTTPStatus.PRECONDITION_FAILED, @@ -63,6 +67,7 @@ def _require_auth(user: User | None) -> User: def _require_task(task_id: int, expdb: Connection) -> None: + """Raise 404 with code 201 if task_id does not exist.""" if not database.tasks.get(task_id, expdb): raise HTTPException( status_code=HTTPStatus.NOT_FOUND, @@ -71,6 +76,7 @@ def _require_task(task_id: int, expdb: Connection) -> None: def _require_flow(flow_id: int, expdb: Connection) -> None: + """Raise 404 with code 180 if flow_id does not exist.""" if not database.flows.get(flow_id, expdb): raise HTTPException( status_code=HTTPStatus.NOT_FOUND, @@ -78,9 +84,6 @@ def _require_flow(flow_id: int, expdb: Connection) -> None: ) - - - @router.post( "/", summary="Upload a run (predictions + description XML)", @@ -142,17 +145,21 @@ async def upload_run( ) # Persist the predictions file to disk so the worker can read it later - from config import load_configuration # noqa: PLC0415 - - upload_dir: str = load_configuration().get("upload_dir", "/tmp/openml_runs") # noqa: S108 + upload_dir: str = load_configuration().get("upload_dir", _DEFAULT_UPLOAD_DIR) run_dir = Path(upload_dir) / str(run_id) run_dir.mkdir(parents=True, exist_ok=True) predictions_bytes = await predictions.read() predictions_path = run_dir / "predictions.arff" predictions_path.write_bytes(predictions_bytes) - # Enqueue for server-side evaluation - database.processing.enqueue(run_id, expdb) + # Enqueue for server-side evaluation; on failure, clean up to avoid orphans + try: + database.processing.enqueue(run_id, expdb) + except Exception: + log.exception("Failed to enqueue run %d; rolling back artifacts.", run_id) + shutil.rmtree(run_dir, ignore_errors=True) + database.runs.delete(run_id, expdb) + raise log.info( "Run %d uploaded by user %d (task=%d, flow=%d).", @@ -164,9 +171,6 @@ async def upload_run( return RunUploadResponse(run_id=run_id) - - - @router.get( "/{run_id}", summary="Get run metadata and evaluation results", @@ -192,8 +196,10 @@ def get_run( per_fold: list[float] | None = None if row.array_data: try: - per_fold = [float(v) for v in json.loads(row.array_data)] - except (json.JSONDecodeError, ValueError): + parsed = json.loads(row.array_data) + if isinstance(parsed, list): + per_fold = [float(v) for v in parsed] + except (json.JSONDecodeError, ValueError, TypeError): per_fold = None evaluations.append( diff --git a/src/worker/evaluator.py b/src/worker/evaluator.py index a449e020..dba2460b 100644 --- a/src/worker/evaluator.py +++ b/src/worker/evaluator.py @@ -1,6 +1,7 @@ - from __future__ import annotations +import csv +import io import logging import urllib.request from pathlib import Path @@ -20,9 +21,12 @@ log = logging.getLogger(__name__) +# Shared default; also set in routers/openml/runs.py — kept in sync via config key +_DEFAULT_UPLOAD_DIR = 
"/tmp/openml_runs" # noqa: S108 + def _parse_predictions_arff(content: str) -> dict[str, list[Any]]: - """Parse an OpenML predictions ARFF. + """Parse an OpenML predictions ARFF using csv.reader to handle quoted values. Returns a dict with keys: 'row_id', 'prediction', 'confidence' (optional). Expected columns: row_id, fold, repeat, prediction [, confidence.*] @@ -40,9 +44,13 @@ def _parse_predictions_arff(content: str) -> dict[str, list[Any]]: if not in_data: continue - parts = [p.strip().strip("'\"") for p in stripped.split(",")] - if not parts: + # Use csv.reader so quoted commas and quoted values are handled correctly + try: + (parts,) = csv.reader(io.StringIO(stripped)) + except (ValueError, StopIteration): continue + parts = [p.strip().strip("'\"") for p in parts] + try: row_id = int(parts[0]) prediction = parts[3] if len(parts) > 3 else parts[-1] # noqa: PLR2004 @@ -57,23 +65,20 @@ def _parse_predictions_arff(content: str) -> dict[str, list[Any]]: return result -def _load_ground_truth( - dataset_url: str, +def _fetch_arff(url: str) -> str: + """Download an ARFF from a URL, returning the decoded text content.""" + with urllib.request.urlopen(url, timeout=30) as resp: # noqa: S310 + return resp.read().decode("utf-8", errors="replace") + + +def _parse_dataset_labels( + content: str, target_attribute: str, - test_row_ids: list[int], -) -> list[str]: - """Download the dataset ARFF and extract the target column for given row IDs. +) -> dict[int, str]: + """Parse a dataset ARFF and return a {row_index: label} map for target_attribute. - Only extracts rows whose 0-based index is in `test_row_ids`. - Returns labels as strings in the order of `test_row_ids`. + Returns an empty dict (with a warning) when the target column is absent. """ - try: - with urllib.request.urlopen(dataset_url, timeout=30) as resp: # noqa: S310 - content = resp.read().decode("utf-8", errors="replace") - except Exception: - log.exception("Failed to download dataset from %s", dataset_url) - return [] - attr_names: list[str] = [] data_rows: list[list[str]] = [] in_data = False @@ -90,22 +95,25 @@ def _load_ground_truth( in_data = True continue if in_data: - data_rows.append([v.strip().strip("'\"") for v in stripped.split(",")]) + try: + (row,) = csv.reader(io.StringIO(stripped)) + data_rows.append([v.strip().strip("'\"") for v in row]) + except (ValueError, StopIteration): + continue if target_attribute not in attr_names: log.warning("Target attribute '%s' not found in dataset.", target_attribute) - return [] + return {} target_idx = attr_names.index(target_attribute) - pos_to_label = { + return { i: row[target_idx] for i, row in enumerate(data_rows) - if i in set(test_row_ids) and target_idx < len(row) + if target_idx < len(row) } - return [pos_to_label.get(rid, "") for rid in test_row_ids] -def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR0911, PLR0915 +def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR0911, PLR0912, PLR0915 """Evaluate a single run, store metrics, mark processing entry done/error.""" run = database.runs.get(run_id, expdb) if run is None: @@ -140,20 +148,33 @@ def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR091 return dataset_url = str(_format_dataset_url(dataset_row)) + # Fetch dataset once (not per fold) and build a complete label map + try: + dataset_content = _fetch_arff(dataset_url) + except Exception: + log.exception("Failed to download dataset from %s", dataset_url) + database.processing.mark_error(run_id, "could not 
fetch dataset", expdb) + return + label_map = _parse_dataset_labels(dataset_content, target_attr) + if not label_map: + database.processing.mark_error(run_id, "target attribute not found in dataset", expdb) + return + cfg = load_routing_configuration() task_id = run.task_id - splits_url = f"{cfg.get('server_url', '')}api_splits/get/{task_id}/Task_{task_id}_splits.arff" + splits_url = ( + f"{cfg.get('server_url', '')}api_splits/get/{task_id}/Task_{task_id}_splits.arff" + ) try: - with urllib.request.urlopen(splits_url, timeout=30) as resp: # noqa: S310 - splits_content = resp.read().decode("utf-8", errors="replace") + splits_content = _fetch_arff(splits_url) except Exception: log.exception("Could not fetch splits for task %d", task_id) database.processing.mark_error(run_id, "could not fetch splits", expdb) return - fold_index = build_fold_index(parse_arff_splits(splits_content), repeat=0) + splits = parse_arff_splits(splits_content) - upload_dir: str = load_configuration().get("upload_dir", "/tmp/openml_runs") # noqa: S108 + upload_dir: str = load_configuration().get("upload_dir", _DEFAULT_UPLOAD_DIR) predictions_path = Path(upload_dir) / str(run_id) / "predictions.arff" try: with predictions_path.open(encoding="utf-8") as fh: @@ -170,18 +191,45 @@ def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR091 conf_map: dict[int, float | None] = dict( zip(predictions["row_id"], predictions["confidence"], strict=True), ) - has_scores = any(v is not None for v in conf_map.values()) - all_true: list[str] = [] - all_pred: list[str] = [] + # Determine available repeats and iterate all of them + all_repeats = sorted({int(e["repeat"]) for e in splits}) + + all_true: list[str | int | float] = [] + all_pred: list[str | int | float] = [] all_score: list[float] = [] - for train_ids, test_ids in fold_index.values(): # noqa: B007 - all_true.extend(_load_ground_truth(dataset_url, target_attr, test_ids)) - all_pred.extend(pred_map.get(rid, "") for rid in test_ids) - if has_scores: - for rid in test_ids: - raw = conf_map.get(rid) - all_score.append(float(raw) if raw is not None else 0.0) + # has_scores starts True; disabled if any fold has a missing score + has_scores = any(v is not None for v in conf_map.values()) + + for repeat in all_repeats: + fold_index = build_fold_index(splits, repeat=repeat) + for _train_ids, test_ids in fold_index.values(): + # Validate ground truth: error out if any row ID is missing + missing = [rid for rid in test_ids if rid not in label_map] + if missing: + database.processing.mark_error( + run_id, + f"ground-truth missing for row_ids {missing[:5]}", + expdb, + ) + return + + y_true_fold = [label_map[rid] for rid in test_ids] + y_pred_fold = [pred_map.get(rid, "") for rid in test_ids] + all_true.extend(y_true_fold) + all_pred.extend(y_pred_fold) + + if has_scores: + fold_scores: list[float] = [] + for rid in test_ids: + raw = conf_map.get(rid) + if raw is None: + # Score missing for this fold — disable AUC for whole run + has_scores = False + fold_scores = [] + break + fold_scores.append(float(raw)) + all_score.extend(fold_scores) metrics = compute_metrics( task_type_id=task_row.ttid, diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py index f5651ba0..be47fd2a 100644 --- a/tests/routers/openml/runs_test.py +++ b/tests/routers/openml/runs_test.py @@ -1,17 +1,22 @@ - from __future__ import annotations import io +import pathlib +from datetime import datetime from http import HTTPStatus +from typing import TYPE_CHECKING import pytest 
-from pytest_mock import MockerFixture -from starlette.testclient import TestClient + +if TYPE_CHECKING: + from pytest_mock import MockerFixture + from starlette.testclient import TestClient -# --------------------------------------------------------------------------- -# Minimal ARFF predictions content used in tests -# --------------------------------------------------------------------------- +RUN_ID = 42 +FLOW_ID = 2 +EXPECTED_RUN_ID = 99 +EXPECTED_RUN_ID_2 = 71 MINIMAL_PREDICTIONS_ARFF = b"""@relation predictions @@ -37,11 +42,6 @@ INVALID_XML = b"not xml at all <<<" -# --------------------------------------------------------------------------- -# GET /runs/{run_id} -# --------------------------------------------------------------------------- - - def test_get_run_not_found(py_api: TestClient) -> None: """GET an unknown run_id should return 404.""" response = py_api.get("/runs/999999999") @@ -52,12 +52,10 @@ def test_get_run_not_found(py_api: TestClient) -> None: def test_get_run_returns_structure(mocker: MockerFixture, py_api: TestClient) -> None: """GET /runs/{id} returns RunDetail structure when run exists.""" - from datetime import datetime - mock_run = mocker.MagicMock() - mock_run.rid = 42 + mock_run.rid = RUN_ID mock_run.task_id = 1 - mock_run.flow_id = 2 + mock_run.flow_id = FLOW_ID mock_run.uploader = 16 mock_run.upload_time = datetime(2024, 1, 15, 10, 30, 0) mock_run.setup_string = "weka.J48 -C 0.25" @@ -71,24 +69,19 @@ def test_get_run_returns_structure(mocker: MockerFixture, py_api: TestClient) -> ], ) - response = py_api.get("/runs/42") + response = py_api.get(f"/runs/{RUN_ID}") assert response.status_code == HTTPStatus.OK body = response.json() - assert body["run_id"] == 42 + assert body["run_id"] == RUN_ID assert body["task_id"] == 1 - assert body["flow_id"] == 2 + assert body["flow_id"] == FLOW_ID assert body["tags"] == ["study_1"] assert len(body["evaluations"]) == 1 assert body["evaluations"][0]["function"] == "predictive_accuracy" assert body["evaluations"][0]["value"] == pytest.approx(0.93) -# --------------------------------------------------------------------------- -# POST /runs -# --------------------------------------------------------------------------- - - def test_upload_run_requires_auth(py_api: TestClient) -> None: """Unauthenticated POST /runs should return 412 with code 103.""" response = py_api.post( @@ -104,7 +97,6 @@ def test_upload_run_requires_auth(py_api: TestClient) -> None: def test_upload_run_invalid_xml(mocker: MockerFixture, py_api: TestClient) -> None: """Malformed description XML should return 422.""" - # Simulate an authenticated user mocker.patch("routers.dependencies.fetch_user", return_value=mocker.MagicMock(user_id=1)) response = py_api.post( @@ -152,13 +144,15 @@ def test_upload_run_unknown_flow(mocker: MockerFixture, py_api: TestClient) -> N assert response.json()["detail"]["code"] == "180" -def test_upload_run_success(mocker: MockerFixture, tmp_path, py_api: TestClient) -> None: +def test_upload_run_success( + mocker: MockerFixture, tmp_path: pathlib.Path, py_api: TestClient, +) -> None: """A fully valid POST /runs should return 201 with a run_id.""" fake_user = mocker.MagicMock(user_id=16) mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) mocker.patch("database.tasks.get", return_value=mocker.MagicMock()) mocker.patch("database.flows.get", return_value=mocker.MagicMock()) - mocker.patch("database.runs.create", return_value=99) + mocker.patch("database.runs.create", return_value=EXPECTED_RUN_ID) 
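+    # database.runs.create is stubbed so the endpoint gets a deterministic
+    # run_id (EXPECTED_RUN_ID) without touching a real expdb.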
mocker.patch("database.processing.enqueue") mocker.patch( "routers.openml.runs.load_configuration", @@ -174,21 +168,25 @@ def test_upload_run_success(mocker: MockerFixture, tmp_path, py_api: TestClient) ) assert response.status_code == HTTPStatus.CREATED body = response.json() - assert body["run_id"] == 99 + assert body["run_id"] == EXPECTED_RUN_ID # Verify predictions file was persisted - predictions_path = tmp_path / "99" / "predictions.arff" + predictions_path = tmp_path / str(EXPECTED_RUN_ID) / "predictions.arff" assert predictions_path.exists() assert predictions_path.read_bytes() == MINIMAL_PREDICTIONS_ARFF -def test_upload_run_enqueues_processing(mocker: MockerFixture, tmp_path, py_api: TestClient) -> None: +def test_upload_run_enqueues_processing( + mocker: MockerFixture, + tmp_path: pathlib.Path, + py_api: TestClient, +) -> None: """Successful upload must enqueue a processing_run entry.""" fake_user = mocker.MagicMock(user_id=16) mocker.patch("routers.dependencies.fetch_user", return_value=fake_user) mocker.patch("database.tasks.get", return_value=mocker.MagicMock()) mocker.patch("database.flows.get", return_value=mocker.MagicMock()) - mocker.patch("database.runs.create", return_value=7) + mocker.patch("database.runs.create", return_value=EXPECTED_RUN_ID_2) enqueue_mock = mocker.patch("database.processing.enqueue") mocker.patch( "routers.openml.runs.load_configuration", @@ -204,4 +202,4 @@ def test_upload_run_enqueues_processing(mocker: MockerFixture, tmp_path, py_api: ) enqueue_mock.assert_called_once() call_kwargs = enqueue_mock.call_args - assert call_kwargs[0][0] == 7 # run_id positional arg + assert call_kwargs[0][0] == EXPECTED_RUN_ID_2 # run_id positional arg From 9dd9847ae5692e1f0a246dda169affdee802faec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:51:26 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/database/runs.py | 1 - src/routers/openml/runs.py | 1 - src/worker/evaluator.py | 10 ++-------- tests/routers/openml/runs_test.py | 4 +++- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/database/runs.py b/src/database/runs.py index 147b0a73..a420472e 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -117,4 +117,3 @@ def delete(run_id: int, expdb: Connection) -> None: text("DELETE FROM run WHERE `rid` = :run_id"), parameters={"run_id": run_id}, ) - diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index a087c2d0..36e6361d 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -30,7 +30,6 @@ _OML_NAMESPACE = "http://openml.org/openml" - def _parse_run_xml(xml_bytes: bytes) -> dict[str, Any]: """Parse the run description XML uploaded by the client. 
diff --git a/src/worker/evaluator.py b/src/worker/evaluator.py index 763623bf..79c66880 100644 --- a/src/worker/evaluator.py +++ b/src/worker/evaluator.py @@ -106,11 +106,7 @@ def _parse_dataset_labels( return {} target_idx = attr_names.index(target_attribute) - return { - i: row[target_idx] - for i, row in enumerate(data_rows) - if target_idx < len(row) - } + return {i: row[target_idx] for i, row in enumerate(data_rows) if target_idx < len(row)} def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR0911, PLR0912, PLR0915 @@ -162,9 +158,7 @@ def _evaluate_run(run_id: int, expdb: Connection) -> None: # noqa: C901, PLR091 cfg = load_routing_configuration() task_id = run.task_id - splits_url = ( - f"{cfg.get('server_url', '')}api_splits/get/{task_id}/Task_{task_id}_splits.arff" - ) + splits_url = f"{cfg.get('server_url', '')}api_splits/get/{task_id}/Task_{task_id}_splits.arff" try: splits_content = _fetch_arff(splits_url) except Exception: diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py index be47fd2a..092c58fd 100644 --- a/tests/routers/openml/runs_test.py +++ b/tests/routers/openml/runs_test.py @@ -145,7 +145,9 @@ def test_upload_run_unknown_flow(mocker: MockerFixture, py_api: TestClient) -> N def test_upload_run_success( - mocker: MockerFixture, tmp_path: pathlib.Path, py_api: TestClient, + mocker: MockerFixture, + tmp_path: pathlib.Path, + py_api: TestClient, ) -> None: """A fully valid POST /runs should return 201 with a run_id.""" fake_user = mocker.MagicMock(user_id=16) From e63080443eeafc9dc7b0b7e36fdffd968a28efa0 Mon Sep 17 00:00:00 2001 From: vinayak sharma Date: Sat, 28 Feb 2026 00:30:50 +0530 Subject: [PATCH 5/5] Fix mypy Returning Any errors via explicit casts --- src/routers/openml/runs.py | 4 ++-- src/worker/evaluator.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 36e6361d..456d1eb6 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -5,7 +5,7 @@ import shutil from http import HTTPStatus from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any +from typing import TYPE_CHECKING, Annotated, Any, cast import xmltodict from fastapi import APIRouter, Depends, File, HTTPException, UploadFile @@ -53,7 +53,7 @@ def _parse_run_xml(xml_bytes: bytes) -> dict[str, Any]: detail={"code": "530", "message": f"Invalid run description XML: {exc}"}, ) from exc - return raw.get("run", {}) + return cast("dict[str, Any]", raw.get("run", {})) def _require_auth(user: User | None) -> User: diff --git a/src/worker/evaluator.py b/src/worker/evaluator.py index 79c66880..6659d3d6 100644 --- a/src/worker/evaluator.py +++ b/src/worker/evaluator.py @@ -5,7 +5,7 @@ import logging import urllib.request from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import database.processing import database.runs @@ -68,7 +68,7 @@ def _parse_predictions_arff(content: str) -> dict[str, list[Any]]: def _fetch_arff(url: str) -> str: """Download an ARFF from a URL, returning the decoded text content.""" with urllib.request.urlopen(url, timeout=30) as resp: # noqa: S310 - return resp.read().decode("utf-8", errors="replace") + return cast("str", resp.read().decode("utf-8", errors="replace")) def _parse_dataset_labels(