-
-
Notifications
You must be signed in to change notification settings - Fork 44
[FEAT] Implement POST /run upload and server-side evaluation pipeline in Python #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5f02cf3
363b9f2
0bd41fd
91e727b
9dd9847
e630804
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from collections.abc import Sequence | ||
|
|
||
|
|
||
| def accuracy(y_true: Sequence[str | int], y_pred: Sequence[str | int]) -> float: | ||
| """Fraction of predictions that exactly match the ground truth.""" | ||
| if len(y_true) != len(y_pred): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
| correct = sum(t == p for t, p in zip(y_true, y_pred, strict=True)) | ||
| return correct / len(y_true) | ||
|
|
||
|
|
||
| def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float: | ||
| """Root Mean Squared Error.""" | ||
| if len(y_true) != len(y_pred): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
| mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) | ||
| return math.sqrt(mse) | ||
|
|
||
|
|
||
| def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float: | ||
| """Mean Absolute Error.""" | ||
| if len(y_true) != len(y_pred): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
| return sum(abs(t - p) for t, p in zip(y_true, y_pred, strict=True)) / len(y_true) | ||
|
|
||
|
|
||
| def auc(y_true: Sequence[int], y_score: Sequence[float]) -> float: | ||
| """Binary ROC AUC via an O(n log n) rank-based Mann-Whitney U statistic. | ||
|
|
||
| Mathematically equivalent to the area under the ROC curve. | ||
|
|
||
| y_true: sequence of 0/1 ground-truth labels. | ||
| y_score: sequence of predicted probabilities for the positive class (label=1). | ||
|
|
||
| Raises: | ||
| ValueError: if y_true contains values outside {0, 1} or lengths differ. | ||
| """ | ||
| if len(y_true) != len(y_score): | ||
| msg = f"Length mismatch: {len(y_true)} vs {len(y_score)}" | ||
| raise ValueError(msg) | ||
| if not y_true: | ||
| return 0.0 | ||
|
|
||
| unique = set(y_true) | ||
| invalid = unique - {0, 1} | ||
| if invalid: | ||
| msg = f"y_true must contain only 0/1 labels; found {invalid}" | ||
| raise ValueError(msg) | ||
|
|
||
| n_pos = sum(y_true) | ||
| n_neg = len(y_true) - n_pos | ||
| if n_pos == 0 or n_neg == 0: | ||
| return 0.0 | ||
|
|
||
| # O(n log n): rank all scores, then use the rank-sum formula | ||
| pairs = sorted(zip(y_score, y_true, strict=False), key=lambda x: x[0]) | ||
|
|
||
| n = len(pairs) | ||
| ranks: list[float] = [0.0] * n | ||
| i = 0 | ||
| while i < n: | ||
| j = i | ||
| while j < n - 1 and pairs[j][0] == pairs[j + 1][0]: | ||
| j += 1 | ||
| mid_rank = (i + j) / 2 + 1 # 1-indexed | ||
| for k in range(i, j + 1): | ||
| ranks[k] = mid_rank | ||
| i = j + 1 | ||
|
|
||
| rank_sum_pos = sum(ranks[k] for k in range(n) if pairs[k][1] == 1) | ||
| u_pos = rank_sum_pos - n_pos * (n_pos + 1) / 2 | ||
| return u_pos / (n_pos * n_neg) | ||
|
|
||
|
|
||
#: Task type IDs from the OpenML schema
# These match the numeric IDs compute_metrics() dispatches on below.
TASK_TYPE_SUPERVISED_CLASSIFICATION = 1
TASK_TYPE_SUPERVISED_REGRESSION = 2
|
|
||
|
|
||
def compute_metrics(
    task_type_id: int,
    y_true: Sequence[str | int | float],
    y_pred: Sequence[str | int | float],
    y_score: Sequence[float] | None = None,
) -> dict[str, float]:
    """Compute all applicable metrics for the given task type.

    Returns a dict of {measure_name: value} using the same names found in
    the OpenML `math_function` table (e.g. 'predictive_accuracy',
    'area_under_roc_curve'). Unknown task types yield an empty dict.
    """
    results: dict[str, float] = {}

    if task_type_id == TASK_TYPE_SUPERVISED_REGRESSION:
        true_vals = [float(v) for v in y_true]
        pred_vals = [float(v) for v in y_pred]
        results["root_mean_squared_error"] = rmse(true_vals, pred_vals)
        results["mean_absolute_error"] = mean_absolute_error(true_vals, pred_vals)
        return results

    if task_type_id == TASK_TYPE_SUPERVISED_CLASSIFICATION:
        labels_true = [str(v) for v in y_true]
        labels_pred = [str(v) for v in y_pred]
        results["predictive_accuracy"] = accuracy(labels_true, labels_pred)

        # AUC only when binary and scores are provided
        distinct = set(labels_true)
        if y_score is not None and len(distinct) == 2:  # noqa: PLR2004
            # Map the positive class (lexicographically larger, matching OpenML)
            positive = max(distinct)
            indicator = [1 if label == positive else 0 for label in labels_true]
            results["area_under_roc_curve"] = auc(indicator, list(y_score))

    return results
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import random | ||
| import re | ||
|
|
||
| SplitEntry = dict[str, int | str] | ||
|
|
||
|
|
||
def generate_splits(
    n_samples: int,
    n_folds: int,
    n_repeats: int,
    *,
    seed: int = 0,
) -> list[SplitEntry]:
    """Generate cross-validation splits deterministically.

    Each repeat reshuffles the row order (driven by a single seeded RNG),
    then assigns rows to test folds round-robin by shuffled position.

    Returns a flat list of dicts with keys:
        repeat, fold, rowid, type ('TRAIN' or 'TEST')
    """
    if n_folds <= 0:
        msg = f"n_folds must be a positive integer, got {n_folds}"
        raise ValueError(msg)
    if n_repeats <= 0:
        msg = f"n_repeats must be a positive integer, got {n_repeats}"
        raise ValueError(msg)
    if n_samples <= 0:
        return []

    rng = random.Random(seed)  # noqa: S311
    result: list[SplitEntry] = []

    for rep in range(n_repeats):
        order = list(range(n_samples))
        rng.shuffle(order)

        for fold_no in range(n_folds):
            result.extend(
                {
                    "repeat": rep,
                    "fold": fold_no,
                    "rowid": sample,
                    "type": "TEST" if pos % n_folds == fold_no else "TRAIN",
                }
                for pos, sample in enumerate(order)
            )

    return result
|
|
||
|
|
||
# Case-insensitive matcher for the ARFF "@data" section marker; used with
# .match() on stripped lines, so it only needs to match at line start.
_ARFF_DATA_SECTION = re.compile(r"@[Dd][Aa][Tt][Aa]")


def parse_arff_splits(arff_content: str) -> list[SplitEntry]:
    """Parse an OpenML splits ARFF file into the same list-of-dict format.

    Expected ARFF columns (in order): type, rowid, repeat, fold
    (This is the column order used by OpenML's split ARFF files.)

    Blank lines, '%' comments, and malformed data rows are skipped silently.
    """
    entries: list[SplitEntry] = []
    seen_data_marker = False

    for raw_line in arff_content.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("%"):
            continue
        if _ARFF_DATA_SECTION.match(line):
            seen_data_marker = True
            continue
        if not seen_data_marker:
            # Still in the header (@RELATION/@ATTRIBUTE section).
            continue

        fields = [field.strip() for field in line.split(",")]
        if len(fields) < 4:  # noqa: PLR2004
            continue
        try:
            entry: SplitEntry = {
                "repeat": int(fields[2]),
                "fold": int(fields[3]),
                "rowid": int(fields[1]),
                "type": fields[0].strip("'\""),
            }
        except ValueError:
            # Non-numeric rowid/repeat/fold — skip the malformed row.
            continue
        entries.append(entry)

    return entries
|
|
||
|
|
||
def build_fold_index(
    splits: list[SplitEntry],
    repeat: int = 0,
) -> dict[int, tuple[list[int], list[int]]]:
    """Build a dict of fold -> (train_indices, test_indices) for a given repeat.

    Entries whose type is not TRAIN/TEST (case-insensitive) are ignored.
    """
    index: dict[int, tuple[list[int], list[int]]] = {}
    for item in splits:
        if item["repeat"] != repeat:
            continue
        fold_key = int(item["fold"])
        row = int(item["rowid"])
        kind = str(item["type"]).upper()
        if kind not in {"TRAIN", "TEST"}:
            continue
        train_rows, test_rows = index.setdefault(fold_key, ([], []))
        (train_rows if kind == "TRAIN" else test_rows).append(row)
    return index
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import datetime | ||
| from collections.abc import Sequence | ||
| from typing import cast | ||
|
|
||
| from sqlalchemy import Connection, Row, text | ||
|
|
||
|
|
||
def enqueue(run_id: int, expdb: Connection) -> None:
    """Insert a new pending processing entry for the given run.

    Args:
        run_id: ID of the run to queue for server-side evaluation.
        expdb: Open connection to the experiment database.
    """
    # NOTE(review): datetime.now() is naive (no timezone). get_pending()
    # orders by this column, so mixed-timezone app servers could reorder the
    # queue — confirm whether the DB convention is server-local or UTC.
    expdb.execute(
        text(
            """
            INSERT INTO processing_run(`run_id`, `status`, `date`)
            VALUES (:run_id, 'pending', :date)
            """,
        ),
        parameters={"run_id": run_id, "date": datetime.datetime.now()},
    )
|
|
||
|
|
||
def get_pending(expdb: Connection) -> Sequence[Row]:
    """Atomically claim all pending processing_run rows for this worker.

    The previous implementation flipped every 'pending' row to 'processing'
    and then re-read ALL rows with status 'processing' — including rows
    claimed earlier by a *different* worker — so two workers could evaluate
    the same run. Here we lock and read the pending rows first
    (SELECT ... FOR UPDATE), flip exactly those rows, and return only the
    rows this call claimed.

    NOTE(review): this relies on `expdb` running inside a transaction so the
    row locks are held between the SELECT and the UPDATEs (SQLAlchemy
    connections begin one implicitly) — confirm autocommit is not enabled.

    Returns:
        The claimed rows (run_id, status, date), oldest first. The `status`
        column in the returned rows reflects the pre-claim value 'pending'.
    """
    # Lock the pending rows so concurrent workers block on (or skip) them
    # rather than double-claiming the same runs.
    claimed = cast(
        "Sequence[Row]",
        expdb.execute(
            text(
                """
                SELECT `run_id`, `status`, `date`
                FROM processing_run
                WHERE `status` = 'pending'
                ORDER BY `date` ASC
                FOR UPDATE
                """,
            ),
        ).all(),
    )
    # Flip only the rows we just locked to 'processing'.
    for row in claimed:
        expdb.execute(
            text(
                """
                UPDATE processing_run
                SET `status` = 'processing'
                WHERE `run_id` = :run_id
                """,
            ),
            parameters={"run_id": row.run_id},
        )
    return claimed
|
|
||
|
|
||
def mark_done(run_id: int, expdb: Connection) -> None:
    """Mark a processing_run entry as successfully completed.

    Args:
        run_id: ID of the run whose processing entry to update.
        expdb: Open connection to the experiment database.
    """
    # No status guard in the WHERE clause: the row is set to 'done'
    # unconditionally for this run_id.
    expdb.execute(
        text(
            """
            UPDATE processing_run
            SET `status` = 'done'
            WHERE `run_id` = :run_id
            """,
        ),
        parameters={"run_id": run_id},
    )
|
|
||
|
|
||
def mark_error(run_id: int, error_message: str, expdb: Connection) -> None:
    """Mark a processing_run entry as failed and store the error message.

    Args:
        run_id: ID of the run whose processing entry to update.
        error_message: Human-readable failure description stored in `error`.
        expdb: Open connection to the experiment database.
    """
    # Parameterized query: error_message is bound, not interpolated, so
    # arbitrary text (e.g. tracebacks with quotes) is safe to store.
    expdb.execute(
        text(
            """
            UPDATE processing_run
            SET `status` = 'error', `error` = :error_message
            WHERE `run_id` = :run_id
            """,
        ),
        parameters={"run_id": run_id, "error_message": error_message},
    )
|
Comment on lines
+23
to
+80
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pending work items are not atomically claimed before execution.
🤖 Prompt for AI Agents |
||
Uh oh!
There was an error while loading. Please reload this page.