diff --git a/app/routers/executions.py b/app/routers/executions.py index 9f22d37..b5ca1d9 100644 --- a/app/routers/executions.py +++ b/app/routers/executions.py @@ -32,16 +32,59 @@ async def list_executions( cue_id: Optional[str] = Query(None), status: Optional[str] = Query(None), + status__in: Optional[str] = Query( + None, + description="Comma-separated execution statuses. Mutex with `status`.", + ), + worker_id: Optional[str] = Query( + None, + description="Filter to executions claimed by this worker (Execution.claimed_by_worker).", + ), outcome_state: Optional[str] = Query(None), limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0), user: AuthenticatedUser = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """List executions with optional filters.""" + """List executions with optional filters. + + Adds `worker_id=` filter and `status__in=` comma-separated multi-status + filter to unblock menubar pending-exec counters that need to fetch + multiple statuses (e.g. `pending,delivering,retry_ready`) in one round + trip. Response gains `oldest_claimed_at` — the earliest `claimed_at` + over the filtered set (null when count=0) — so callers can render + "oldest pending: 5m" without a follow-up query. + """ from sqlalchemy import func as sa_func from app.services.cue_service import _execution_to_response + if status and status__in: + raise HTTPException( + status_code=400, + detail={ + "error": { + "code": "conflicting_filters", + "message": "Pass either `status` or `status__in`, not both.", + "status": 400, + } + }, + ) + + status_list: Optional[list[str]] = None + if status__in: + status_list = [s.strip() for s in status__in.split(",") if s.strip()] + if not status_list: + raise HTTPException( + status_code=400, + detail={ + "error": { + "code": "invalid_status_in", + "message": "`status__in` must contain at least one status.", + "status": 400, + } + }, + ) + base = ( select(Execution) .join(Cue, Execution.cue_id == Cue.id) @@ -52,18 +95,35 @@ async def list_executions( .join(Cue, Execution.cue_id == Cue.id) .where(Cue.user_id == user.id) ) + oldest_base = ( + select(sa_func.min(Execution.claimed_at)) + .join(Cue, Execution.cue_id == Cue.id) + .where(Cue.user_id == user.id) + ) if cue_id: base = base.where(Execution.cue_id == cue_id) count_base = count_base.where(Execution.cue_id == cue_id) + oldest_base = oldest_base.where(Execution.cue_id == cue_id) if status: base = base.where(Execution.status == status) count_base = count_base.where(Execution.status == status) + oldest_base = oldest_base.where(Execution.status == status) + if status_list: + base = base.where(Execution.status.in_(status_list)) + count_base = count_base.where(Execution.status.in_(status_list)) + oldest_base = oldest_base.where(Execution.status.in_(status_list)) + if worker_id: + base = base.where(Execution.claimed_by_worker == worker_id) + count_base = count_base.where(Execution.claimed_by_worker == worker_id) + oldest_base = oldest_base.where(Execution.claimed_by_worker == worker_id) if outcome_state: base = base.where(Execution.outcome_state == outcome_state) count_base = count_base.where(Execution.outcome_state == outcome_state) + oldest_base = oldest_base.where(Execution.outcome_state == outcome_state) total = await db.scalar(count_base) or 0 + oldest_claimed_at = await db.scalar(oldest_base) if total else None result = await db.execute( base.order_by(Execution.created_at.desc()).limit(limit).offset(offset) ) @@ -75,6 +135,7 @@ async def list_executions( "limit": limit, "offset": offset, "has_more": (offset + limit) < total, + "oldest_claimed_at": oldest_claimed_at.isoformat() if oldest_claimed_at else None, } diff --git a/tests/test_execution_parity.py b/tests/test_execution_parity.py index 817b4a7..de525db 100644 --- a/tests/test_execution_parity.py +++ b/tests/test_execution_parity.py @@ -104,6 +104,89 @@ async def test_list_requires_auth(self, client): resp = await client.get("/v1/executions") assert resp.status_code == 401 + @pytest.mark.asyncio + async def test_filter_by_worker_id(self, client, auth_headers, db_session, registered_user): + """`worker_id=` scopes to executions claimed by that worker, and + `oldest_claimed_at` reflects the earliest claimed_at over the + filtered set.""" + user_id = await _get_user_id(db_session, registered_user) + cue = await _create_worker_cue(db_session, user_id) + early = datetime.now(timezone.utc) - timedelta(minutes=10) + late = datetime.now(timezone.utc) - timedelta(minutes=2) + await _create_execution( + db_session, cue.id, status="delivering", + claimed_by_worker="worker-A", claimed_at=early, + ) + await _create_execution( + db_session, cue.id, status="delivering", + claimed_by_worker="worker-A", claimed_at=late, + ) + await _create_execution( + db_session, cue.id, status="delivering", + claimed_by_worker="worker-B", claimed_at=datetime.now(timezone.utc), + ) + + resp = await client.get("/v1/executions?worker_id=worker-A", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + # Two of the three executions are claimed by worker-A; the third + # (worker-B) must be excluded. ExecutionResponse doesn't surface + # claimed_by_worker, so we verify via the count + the + # oldest_claimed_at value. + assert data["total"] == 2 + # Earliest claimed_at over the filtered set, ISO 8601 string. The + # worker-B execution claimed "now" must NOT win — confirms scoping. + assert data["oldest_claimed_at"] is not None + assert data["oldest_claimed_at"].startswith(early.isoformat()[:19]) + + @pytest.mark.asyncio + async def test_status_in_returns_union(self, client, auth_headers, db_session, registered_user): + """`status__in=pending,delivering` returns the union; mutex with + `status=` returns 400.""" + user_id = await _get_user_id(db_session, registered_user) + cue = await _create_worker_cue(db_session, user_id) + await _create_execution(db_session, cue.id, status="pending") + await _create_execution( + db_session, cue.id, status="delivering", + claimed_at=datetime.now(timezone.utc) - timedelta(minutes=1), + ) + await _create_execution(db_session, cue.id, status="success") + + resp = await client.get( + "/v1/executions?status__in=pending,delivering", headers=auth_headers + ) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 2 + statuses = {ex["status"] for ex in data["executions"]} + assert statuses == {"pending", "delivering"} + + # Mutex check. + resp = await client.get( + "/v1/executions?status=pending&status__in=pending,delivering", + headers=auth_headers, + ) + assert resp.status_code == 400 + # Global exception handler unwraps HTTPException.detail. + assert resp.json()["error"]["code"] == "conflicting_filters" + + @pytest.mark.asyncio + async def test_oldest_claimed_at_null_when_empty( + self, client, auth_headers, db_session, registered_user + ): + """`oldest_claimed_at` is null when the filtered set is empty.""" + user_id = await _get_user_id(db_session, registered_user) + cue = await _create_worker_cue(db_session, user_id) + await _create_execution(db_session, cue.id, status="pending") + + resp = await client.get( + "/v1/executions?worker_id=does-not-exist", headers=auth_headers + ) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 0 + assert data["oldest_claimed_at"] is None + # ── 2. GET /v1/executions/{id} ──