Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions src/core/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

import math
from collections.abc import Sequence


def accuracy(y_true: Sequence[str | int], y_pred: Sequence[str | int]) -> float:
"""Fraction of predictions that exactly match the ground truth."""
if len(y_true) != len(y_pred):
msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}"
raise ValueError(msg)
if not y_true:
return 0.0
correct = sum(t == p for t, p in zip(y_true, y_pred, strict=True))
return correct / len(y_true)


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
"""Root Mean Squared Error."""
if len(y_true) != len(y_pred):
msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}"
raise ValueError(msg)
if not y_true:
return 0.0
mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred, strict=True)) / len(y_true)
return math.sqrt(mse)


def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
"""Mean Absolute Error."""
if len(y_true) != len(y_pred):
msg = f"Length mismatch: {len(y_true)} vs {len(y_pred)}"
raise ValueError(msg)
if not y_true:
return 0.0
return sum(abs(t - p) for t, p in zip(y_true, y_pred, strict=True)) / len(y_true)


def auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Binary ROC AUC via an O(n log n) rank-based Mann-Whitney U statistic.

    Mathematically equivalent to the area under the ROC curve; tied scores
    are handled by assigning each tie group its average (mid) rank.

    y_true: sequence of 0/1 ground-truth labels.
    y_score: sequence of predicted probabilities for the positive class (label=1).

    Returns 0.0 for empty input or when only one class is present.

    Raises:
        ValueError: if y_true contains values outside {0, 1} or lengths differ.
    """
    if len(y_true) != len(y_score):
        raise ValueError(f"Length mismatch: {len(y_true)} vs {len(y_score)}")
    if not y_true:
        return 0.0

    bad_labels = set(y_true) - {0, 1}
    if bad_labels:
        raise ValueError(f"y_true must contain only 0/1 labels; found {bad_labels}")

    positives = sum(y_true)
    negatives = len(y_true) - positives
    if positives == 0 or negatives == 0:
        return 0.0

    # Stable sort of sample indices by score; equal scores keep input order,
    # which is irrelevant because each tie group shares one averaged rank.
    order = sorted(range(len(y_score)), key=lambda idx: y_score[idx])
    sorted_scores = [y_score[idx] for idx in order]

    total = len(order)
    rank_sum_pos = 0.0
    start = 0
    while start < total:
        # Grow [start, stop] over the run of equal scores.
        stop = start
        while stop + 1 < total and sorted_scores[stop + 1] == sorted_scores[stop]:
            stop += 1
        shared_rank = (start + stop) / 2 + 1  # average 1-indexed rank of the run
        for pos in range(start, stop + 1):
            if y_true[order[pos]] == 1:
                rank_sum_pos += shared_rank
        start = stop + 1

    # Mann-Whitney U for the positive class, normalized to [0, 1].
    u_pos = rank_sum_pos - positives * (positives + 1) / 2
    return u_pos / (positives * negatives)


#: Task type IDs from the OpenML schema.
#: Only these two task types are handled by compute_metrics in this module:
#: 1 = supervised classification, 2 = supervised regression.
TASK_TYPE_SUPERVISED_CLASSIFICATION = 1
TASK_TYPE_SUPERVISED_REGRESSION = 2


def compute_metrics(
    task_type_id: int,
    y_true: Sequence[str | int | float],
    y_pred: Sequence[str | int | float],
    y_score: Sequence[float] | None = None,
) -> dict[str, float]:
    """Compute all applicable metrics for the given task type.

    Returns a dict of {measure_name: value} using the same names found in
    the OpenML `math_function` table (e.g. 'predictive_accuracy',
    'area_under_roc_curve'). Unknown task types yield an empty dict.
    """
    metrics: dict[str, float] = {}

    if task_type_id == TASK_TYPE_SUPERVISED_REGRESSION:
        truths = [float(v) for v in y_true]
        preds = [float(v) for v in y_pred]
        metrics["root_mean_squared_error"] = rmse(truths, preds)
        metrics["mean_absolute_error"] = mean_absolute_error(truths, preds)
        return metrics

    if task_type_id != TASK_TYPE_SUPERVISED_CLASSIFICATION:
        return metrics

    labels_true = [str(v) for v in y_true]
    labels_pred = [str(v) for v in y_pred]
    metrics["predictive_accuracy"] = accuracy(labels_true, labels_pred)

    # AUC only when binary and scores are provided.
    distinct = set(labels_true)
    if y_score is not None and len(distinct) == 2:  # noqa: PLR2004
        # Positive class is the lexicographically larger label, matching OpenML.
        positive = max(distinct)
        binary_true = [1 if label == positive else 0 for label in labels_true]
        metrics["area_under_roc_curve"] = auc(binary_true, list(y_score))

    return metrics
113 changes: 113 additions & 0 deletions src/core/splits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import random
import re

# One row of a cross-validation split table: keys are 'repeat', 'fold',
# 'rowid' (ints) and 'type' ('TRAIN' or 'TEST').
SplitEntry = dict[str, int | str]


def generate_splits(
    n_samples: int,
    n_folds: int,
    n_repeats: int,
    *,
    seed: int = 0,
) -> list[SplitEntry]:
    """Generate cross-validation splits deterministically.

    For each repeat the rows are shuffled once with a seeded RNG; row at
    shuffled position p is the TEST row of fold p % n_folds and a TRAIN row
    of every other fold.

    Returns a flat list of dicts with keys:
        repeat, fold, rowid, type ('TRAIN' or 'TEST')
    n_samples <= 0 yields an empty list.

    Raises:
        ValueError: if n_folds or n_repeats is not positive.
    """
    if n_folds <= 0:
        raise ValueError(f"n_folds must be a positive integer, got {n_folds}")
    if n_repeats <= 0:
        raise ValueError(f"n_repeats must be a positive integer, got {n_repeats}")
    if n_samples <= 0:
        return []

    rng = random.Random(seed)  # noqa: S311
    result: list[SplitEntry] = []

    for rep in range(n_repeats):
        order = list(range(n_samples))
        rng.shuffle(order)

        for fold_no in range(n_folds):
            result.extend(
                {
                    "repeat": rep,
                    "fold": fold_no,
                    "rowid": sample,
                    "type": "TEST" if pos % n_folds == fold_no else "TRAIN",
                }
                for pos, sample in enumerate(order)
            )

    return result


# Matches the ARFF "@data" section marker in any letter case.
_ARFF_DATA_SECTION = re.compile(r"@data", re.IGNORECASE)


def parse_arff_splits(arff_content: str) -> list[SplitEntry]:
    """Parse an OpenML splits ARFF file into the same list-of-dict format.

    Expected ARFF columns (in order): type, rowid, repeat, fold
    (This is the column order used by OpenML's split ARFF files.)

    Blank lines, '%' comments, header lines before @data, short rows, and
    rows with non-integer fields are silently skipped.
    """
    entries: list[SplitEntry] = []
    seen_data_marker = False

    for raw_line in arff_content.splitlines():
        row = raw_line.strip()
        if not row or row.startswith("%"):
            continue
        if _ARFF_DATA_SECTION.match(row):
            seen_data_marker = True
            continue
        if not seen_data_marker:
            continue

        fields = [field.strip() for field in row.split(",")]
        if len(fields) < 4:  # noqa: PLR2004
            continue
        try:
            repeat_val = int(fields[2])
            fold_val = int(fields[3])
            rowid_val = int(fields[1])
        except ValueError:
            # Malformed data row; ignore it like the other skip cases.
            continue
        entries.append(
            {
                "repeat": repeat_val,
                "fold": fold_val,
                "rowid": rowid_val,
                "type": fields[0].strip("'\""),
            },
        )

    return entries


def build_fold_index(
    splits: list[SplitEntry],
    repeat: int = 0,
) -> dict[int, tuple[list[int], list[int]]]:
    """Build a dict of fold -> (train_indices, test_indices) for a given repeat.

    Entries for other repeats, and entries whose type is not TRAIN/TEST
    (case-insensitive), are ignored.
    """
    index: dict[int, tuple[list[int], list[int]]] = {}
    for item in splits:
        if item["repeat"] != repeat:
            continue
        kind = str(item["type"]).upper()
        if kind not in {"TRAIN", "TEST"}:
            continue
        train_ids, test_ids = index.setdefault(int(item["fold"]), ([], []))
        target = train_ids if kind == "TRAIN" else test_ids
        target.append(int(item["rowid"]))
    return index
80 changes: 80 additions & 0 deletions src/database/processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import datetime
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, bindparam, text


def enqueue(run_id: int, expdb: Connection) -> None:
    """Insert a new pending processing entry for the given run.

    The row starts in status 'pending' and is later claimed by a worker
    (see get_pending). This function does not commit; transaction control
    belongs to the caller.

    Args:
        run_id: Identifier of the run to schedule for processing.
        expdb: Open connection to the experiment database.
    """
    # NOTE(review): datetime.now() is timezone-naive — confirm the `date`
    # column and its readers agree on the server's local timezone.
    expdb.execute(
        text(
            """
            INSERT INTO processing_run(`run_id`, `status`, `date`)
            VALUES (:run_id, 'pending', :date)
            """,
        ),
        parameters={"run_id": run_id, "date": datetime.datetime.now()},
    )


def get_pending(expdb: Connection) -> Sequence[Row]:
    """Atomically claim all pending processing_run rows for this worker.

    The original implementation flipped every 'pending' row to 'processing'
    and then selected all rows WHERE status = 'processing'. That SELECT also
    returns rows claimed by OTHER concurrent workers (and stale rows left
    behind by crashed workers), so the same run could be processed more than
    once. Instead, lock the pending rows with SELECT ... FOR UPDATE SKIP
    LOCKED so concurrent workers each claim a disjoint set, then mark exactly
    those rows as 'processing'.

    Must be executed inside a transaction: the row locks taken by FOR UPDATE
    are held until the caller commits or rolls back.

    Returns:
        The claimed rows (run_id, status, date), ordered by date ascending.
        The returned `status` reflects the value read at claim time
        ('pending'); the database rows are set to 'processing'.
    """
    claimed = cast(
        "Sequence[Row]",
        expdb.execute(
            text(
                """
                SELECT `run_id`, `status`, `date`
                FROM processing_run
                WHERE `status` = 'pending'
                ORDER BY `date` ASC
                FOR UPDATE SKIP LOCKED
                """,
            ),
        ).all(),
    )
    if claimed:
        # Mark only the rows this worker locked; other workers' claims and
        # stale 'processing' rows are untouched.
        expdb.execute(
            text(
                """
                UPDATE processing_run
                SET `status` = 'processing'
                WHERE `run_id` IN :run_ids
                """,
            ).bindparams(bindparam("run_ids", expanding=True)),
            parameters={"run_ids": [row.run_id for row in claimed]},
        )
    return claimed


def mark_done(run_id: int, expdb: Connection) -> None:
    """Mark a processing_run entry as successfully completed.

    Sets status = 'done' for the given run. Does not commit; transaction
    control belongs to the caller. No-op if no row matches run_id.

    Args:
        run_id: Identifier of the run whose processing finished.
        expdb: Open connection to the experiment database.
    """
    expdb.execute(
        text(
            """
            UPDATE processing_run
            SET `status` = 'done'
            WHERE `run_id` = :run_id
            """,
        ),
        parameters={"run_id": run_id},
    )


def mark_error(run_id: int, error_message: str, expdb: Connection) -> None:
    """Mark a processing_run entry as failed and store the error message.

    Sets status = 'error' and writes error_message into the `error` column.
    Does not commit; transaction control belongs to the caller. No-op if no
    row matches run_id.

    Args:
        run_id: Identifier of the run whose processing failed.
        error_message: Human-readable failure description to persist.
        expdb: Open connection to the experiment database.
    """
    # NOTE(review): `error` column width is unknown here — confirm long
    # messages are truncated safely rather than failing the UPDATE.
    expdb.execute(
        text(
            """
            UPDATE processing_run
            SET `status` = 'error', `error` = :error_message
            WHERE `run_id` = :run_id
            """,
        ),
        parameters={"run_id": run_id, "error_message": error_message},
    )
Comment on lines +23 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Pending work items are not atomically claimed before execution.

get_pending on Line 23 reads all pending rows, while mark_done/mark_error on Lines 40-64 update afterward. With concurrent workers, the same run can be processed multiple times and final status can be overwritten nondeterministically.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/database/processing.py` around lines 23 - 65, get_pending currently just
reads all pending rows and mark_done/mark_error update later, allowing races;
change the flow to atomically claim a run before processing by adding a claim
step that runs inside a transaction: implement a new function (or modify
get_pending) that uses a transactional SELECT ... FOR UPDATE SKIP LOCKED (or an
UPDATE ... WHERE status='pending' RETURNING run_id) via expdb.begin() to set
status='processing' and return the claimed run(s) so no two workers get the same
run; keep mark_done and mark_error as finalizers that update the claimed run_id,
and ensure all claim/processing happens within a DB transaction using the
Connection.begin()/commit/rollback APIs.

Loading