Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion app/routers/executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,59 @@
async def list_executions(
cue_id: Optional[str] = Query(None),
status: Optional[str] = Query(None),
status__in: Optional[str] = Query(
None,
description="Comma-separated execution statuses. Mutex with `status`.",
),
worker_id: Optional[str] = Query(
None,
description="Filter to executions claimed by this worker (Execution.claimed_by_worker).",
),
outcome_state: Optional[str] = Query(None),
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0),
user: AuthenticatedUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List executions with optional filters."""
"""List executions with optional filters.

Adds `worker_id=` filter and `status__in=` comma-separated multi-status
filter to unblock menubar pending-exec counters that need to fetch
multiple statuses (e.g. `pending,delivering,retry_ready`) in one round
trip. Response gains `oldest_claimed_at` — the earliest `claimed_at`
over the filtered set (null when count=0) — so callers can render
"oldest pending: 5m" without a follow-up query.
"""
from sqlalchemy import func as sa_func
from app.services.cue_service import _execution_to_response

if status and status__in:
raise HTTPException(
status_code=400,
detail={
"error": {
"code": "conflicting_filters",
"message": "Pass either `status` or `status__in`, not both.",
"status": 400,
}
},
)

status_list: Optional[list[str]] = None
if status__in:
status_list = [s.strip() for s in status__in.split(",") if s.strip()]
if not status_list:
raise HTTPException(
status_code=400,
detail={
"error": {
"code": "invalid_status_in",
"message": "`status__in` must contain at least one status.",
"status": 400,
}
},
)

base = (
select(Execution)
.join(Cue, Execution.cue_id == Cue.id)
Expand All @@ -52,18 +95,35 @@ async def list_executions(
.join(Cue, Execution.cue_id == Cue.id)
.where(Cue.user_id == user.id)
)
oldest_base = (
select(sa_func.min(Execution.claimed_at))
.join(Cue, Execution.cue_id == Cue.id)
.where(Cue.user_id == user.id)
)

if cue_id:
base = base.where(Execution.cue_id == cue_id)
count_base = count_base.where(Execution.cue_id == cue_id)
oldest_base = oldest_base.where(Execution.cue_id == cue_id)
if status:
base = base.where(Execution.status == status)
count_base = count_base.where(Execution.status == status)
oldest_base = oldest_base.where(Execution.status == status)
if status_list:
base = base.where(Execution.status.in_(status_list))
count_base = count_base.where(Execution.status.in_(status_list))
oldest_base = oldest_base.where(Execution.status.in_(status_list))
if worker_id:
base = base.where(Execution.claimed_by_worker == worker_id)
count_base = count_base.where(Execution.claimed_by_worker == worker_id)
oldest_base = oldest_base.where(Execution.claimed_by_worker == worker_id)
if outcome_state:
base = base.where(Execution.outcome_state == outcome_state)
count_base = count_base.where(Execution.outcome_state == outcome_state)
oldest_base = oldest_base.where(Execution.outcome_state == outcome_state)

total = await db.scalar(count_base) or 0
oldest_claimed_at = await db.scalar(oldest_base) if total else None
result = await db.execute(
base.order_by(Execution.created_at.desc()).limit(limit).offset(offset)
)
Expand All @@ -75,6 +135,7 @@ async def list_executions(
"limit": limit,
"offset": offset,
"has_more": (offset + limit) < total,
"oldest_claimed_at": oldest_claimed_at.isoformat() if oldest_claimed_at else None,
}


Expand Down
83 changes: 83 additions & 0 deletions tests/test_execution_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,89 @@ async def test_list_requires_auth(self, client):
resp = await client.get("/v1/executions")
assert resp.status_code == 401

@pytest.mark.asyncio
async def test_filter_by_worker_id(self, client, auth_headers, db_session, registered_user):
"""`worker_id=` scopes to executions claimed by that worker, and
`oldest_claimed_at` reflects the earliest claimed_at over the
filtered set."""
user_id = await _get_user_id(db_session, registered_user)
cue = await _create_worker_cue(db_session, user_id)
early = datetime.now(timezone.utc) - timedelta(minutes=10)
late = datetime.now(timezone.utc) - timedelta(minutes=2)
await _create_execution(
db_session, cue.id, status="delivering",
claimed_by_worker="worker-A", claimed_at=early,
)
await _create_execution(
db_session, cue.id, status="delivering",
claimed_by_worker="worker-A", claimed_at=late,
)
await _create_execution(
db_session, cue.id, status="delivering",
claimed_by_worker="worker-B", claimed_at=datetime.now(timezone.utc),
)

resp = await client.get("/v1/executions?worker_id=worker-A", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
# Two of the three executions are claimed by worker-A; the third
# (worker-B) must be excluded. ExecutionResponse doesn't surface
# claimed_by_worker, so we verify via the count + the
# oldest_claimed_at value.
assert data["total"] == 2
# Earliest claimed_at over the filtered set, ISO 8601 string. The
# worker-B execution claimed "now" must NOT win — confirms scoping.
assert data["oldest_claimed_at"] is not None
assert data["oldest_claimed_at"].startswith(early.isoformat()[:19])

@pytest.mark.asyncio
async def test_status_in_returns_union(self, client, auth_headers, db_session, registered_user):
"""`status__in=pending,delivering` returns the union; mutex with
`status=` returns 400."""
user_id = await _get_user_id(db_session, registered_user)
cue = await _create_worker_cue(db_session, user_id)
await _create_execution(db_session, cue.id, status="pending")
await _create_execution(
db_session, cue.id, status="delivering",
claimed_at=datetime.now(timezone.utc) - timedelta(minutes=1),
)
await _create_execution(db_session, cue.id, status="success")

resp = await client.get(
"/v1/executions?status__in=pending,delivering", headers=auth_headers
)
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 2
statuses = {ex["status"] for ex in data["executions"]}
assert statuses == {"pending", "delivering"}

# Mutex check.
resp = await client.get(
"/v1/executions?status=pending&status__in=pending,delivering",
headers=auth_headers,
)
assert resp.status_code == 400
# Global exception handler unwraps HTTPException.detail.
assert resp.json()["error"]["code"] == "conflicting_filters"

@pytest.mark.asyncio
async def test_oldest_claimed_at_null_when_empty(
self, client, auth_headers, db_session, registered_user
):
"""`oldest_claimed_at` is null when the filtered set is empty."""
user_id = await _get_user_id(db_session, registered_user)
cue = await _create_worker_cue(db_session, user_id)
await _create_execution(db_session, cue.id, status="pending")

resp = await client.get(
"/v1/executions?worker_id=does-not-exist", headers=auth_headers
)
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 0
assert data["oldest_claimed_at"] is None


# ── 2. GET /v1/executions/{id} ──

Expand Down